diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml
new file mode 100644
index 0000000000..b04fb15cb2
--- /dev/null
+++ b/.github/workflows/black.yml
@@ -0,0 +1,10 @@
+name: Lint
+
+on: [push, pull_request]
+
+jobs:
+ lint:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v2
+ - uses: psf/black@stable
diff --git a/README.md b/README.md
index d1182544d4..98ba150c83 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,6 @@
# Archery
-[![star](https://gitee.com/rtttte/Archery/badge/star.svg?theme=gvp)](https://gitee.com/rtttte/Archery)
[![Django CI](https://github.com/hhyo/Archery/actions/workflows/django.yml/badge.svg)](https://github.com/hhyo/Archery/actions/workflows/django.yml)
[![Release](https://img.shields.io/github/release/hhyo/archery.svg)](https://github.com/hhyo/archery/releases/)
[![codecov](https://codecov.io/gh/hhyo/archery/branch/master/graph/badge.svg)](https://codecov.io/gh/hhyo/archery)
@@ -11,7 +10,7 @@
[![Publish Docker image](https://github.com/hhyo/Archery/actions/workflows/docker-image.yml/badge.svg)](https://github.com/hhyo/Archery/actions/workflows/docker-image.yml)
[![docker_pulls](https://img.shields.io/docker/pulls/hhyo/archery.svg)](https://hub.docker.com/r/hhyo/archery/)
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](http://github.com/hhyo/archery/blob/master/LICENSE)
-[![996.icu](https://img.shields.io/badge/link-996.icu-red.svg)](https://996.icu)
+[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[文档](https://archerydms.com/) | [FAQ](https://github.com/hhyo/archery/wiki/FAQ) | [Releases](https://github.com/hhyo/archery/releases/)
diff --git a/archery/__init__.py b/archery/__init__.py
index 980bd0c637..c1d56d39c8 100644
--- a/archery/__init__.py
+++ b/archery/__init__.py
@@ -1,2 +1,2 @@
version = (1, 9, 0)
-display_version = '.'.join(str(i) for i in version)
+display_version = ".".join(str(i) for i in version)
diff --git a/archery/asgi.py b/archery/asgi.py
index 52d4205405..656f0fa9d0 100644
--- a/archery/asgi.py
+++ b/archery/asgi.py
@@ -11,6 +11,6 @@
from django.core.asgi import get_asgi_application
-os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'archery.settings')
+os.environ.setdefault("DJANGO_SETTINGS_MODULE", "archery.settings")
application = get_asgi_application()
diff --git a/archery/settings.py b/archery/settings.py
index 16a996258f..f4da941302 100644
--- a/archery/settings.py
+++ b/archery/settings.py
@@ -9,24 +9,23 @@
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-environ.Env.read_env(os.path.join(BASE_DIR, '.env'))
+environ.Env.read_env(os.path.join(BASE_DIR, ".env"))
env = environ.Env(
DEBUG=(bool, False),
ALLOWED_HOSTS=(List[str], ["*"]),
- SECRET_KEY=(str, 'hfusaf2m4ot#7)fkw#di2bu6(cv0@opwmafx5n#6=3d%x^hpl6'),
+ SECRET_KEY=(str, "hfusaf2m4ot#7)fkw#di2bu6(cv0@opwmafx5n#6=3d%x^hpl6"),
DATABASE_URL=(str, "mysql://root:@127.0.0.1:3306/archery"),
CACHE_URL=(str, "redis://127.0.0.1:6379/0"),
DINGDING_CACHE_URL=(str, "redis://127.0.0.1:6379/1"),
ENABLE_LDAP=(bool, False),
AUTH_LDAP_ALWAYS_UPDATE_USER=(bool, True),
- AUTH_LDAP_USER_ATTR_MAP=(dict, {
- "username": "cn",
- "display": "displayname",
- "email": "mail"
- }),
+ AUTH_LDAP_USER_ATTR_MAP=(
+ dict,
+ {"username": "cn", "display": "displayname", "email": "mail"},
+ ),
Q_CLUISTER_SYNC=(bool, False), # qcluster 同步模式, debug 时可以调整为 True
# CSRF_TRUSTED_ORIGINS=subdomain.example.com,subdomain.example2.com subdomain.example.com
- CSRF_TRUSTED_ORIGINS=(list, [])
+ CSRF_TRUSTED_ORIGINS=(list, []),
)
# SECURITY WARNING: keep the secret key used in production secret!
@@ -48,59 +47,59 @@
# Application definition
INSTALLED_APPS = (
- 'django.contrib.admin',
- 'django.contrib.auth',
- 'django.contrib.contenttypes',
- 'django.contrib.sessions',
- 'django.contrib.messages',
- 'django.contrib.staticfiles',
- 'django_q',
- 'sql',
- 'sql_api',
- 'common',
- 'rest_framework',
- 'django_filters',
- 'drf_spectacular',
+ "django.contrib.admin",
+ "django.contrib.auth",
+ "django.contrib.contenttypes",
+ "django.contrib.sessions",
+ "django.contrib.messages",
+ "django.contrib.staticfiles",
+ "django_q",
+ "sql",
+ "sql_api",
+ "common",
+ "rest_framework",
+ "django_filters",
+ "drf_spectacular",
)
MIDDLEWARE = (
- 'django.contrib.sessions.middleware.SessionMiddleware',
- 'django.middleware.common.CommonMiddleware',
- 'django.middleware.csrf.CsrfViewMiddleware',
- 'django.contrib.auth.middleware.AuthenticationMiddleware',
- 'django.contrib.messages.middleware.MessageMiddleware',
- 'django.middleware.clickjacking.XFrameOptionsMiddleware',
- 'django.middleware.security.SecurityMiddleware',
- 'django.middleware.gzip.GZipMiddleware',
- 'common.middleware.check_login_middleware.CheckLoginMiddleware',
- 'common.middleware.exception_logging_middleware.ExceptionLoggingMiddleware',
+ "django.contrib.sessions.middleware.SessionMiddleware",
+ "django.middleware.common.CommonMiddleware",
+ "django.middleware.csrf.CsrfViewMiddleware",
+ "django.contrib.auth.middleware.AuthenticationMiddleware",
+ "django.contrib.messages.middleware.MessageMiddleware",
+ "django.middleware.clickjacking.XFrameOptionsMiddleware",
+ "django.middleware.security.SecurityMiddleware",
+ "django.middleware.gzip.GZipMiddleware",
+ "common.middleware.check_login_middleware.CheckLoginMiddleware",
+ "common.middleware.exception_logging_middleware.ExceptionLoggingMiddleware",
)
-ROOT_URLCONF = 'archery.urls'
+ROOT_URLCONF = "archery.urls"
TEMPLATES = [
{
- 'BACKEND': 'django.template.backends.django.DjangoTemplates',
- 'DIRS': [os.path.join(BASE_DIR, 'common/templates')],
- 'APP_DIRS': True,
- 'OPTIONS': {
- 'context_processors': [
- 'django.template.context_processors.debug',
- 'django.template.context_processors.request',
- 'django.contrib.auth.context_processors.auth',
- 'django.contrib.messages.context_processors.messages',
- 'common.utils.global_info.global_info',
+ "BACKEND": "django.template.backends.django.DjangoTemplates",
+ "DIRS": [os.path.join(BASE_DIR, "common/templates")],
+ "APP_DIRS": True,
+ "OPTIONS": {
+ "context_processors": [
+ "django.template.context_processors.debug",
+ "django.template.context_processors.request",
+ "django.contrib.auth.context_processors.auth",
+ "django.contrib.messages.context_processors.messages",
+ "common.utils.global_info.global_info",
],
},
},
]
-WSGI_APPLICATION = 'archery.wsgi.application'
+WSGI_APPLICATION = "archery.wsgi.application"
# Internationalization
-LANGUAGE_CODE = 'zh-hans'
+LANGUAGE_CODE = "zh-hans"
-TIME_ZONE = 'Asia/Shanghai'
+TIME_ZONE = "Asia/Shanghai"
USE_I18N = True
@@ -108,14 +107,16 @@
# 时间格式化
USE_L10N = False
-DATETIME_FORMAT = 'Y-m-d H:i:s'
-DATE_FORMAT = 'Y-m-d'
+DATETIME_FORMAT = "Y-m-d H:i:s"
+DATE_FORMAT = "Y-m-d"
# Static files (CSS, JavaScript, Images)
-STATIC_URL = '/static/'
-STATIC_ROOT = os.path.join(BASE_DIR, 'static')
-STATICFILES_DIRS = [os.path.join(BASE_DIR, 'common/static'), ]
-STATICFILES_STORAGE = 'common.storage.ForgivingManifestStaticFilesStorage'
+STATIC_URL = "/static/"
+STATIC_ROOT = os.path.join(BASE_DIR, "static")
+STATICFILES_DIRS = [
+ os.path.join(BASE_DIR, "common/static"),
+]
+STATICFILES_STORAGE = "common.storage.ForgivingManifestStaticFilesStorage"
# 扩展django admin里users字段用到,指定了sql/models.py里的class users
AUTH_USER_MODEL = "sql.Users"
@@ -123,19 +124,19 @@
# 密码校验
AUTH_PASSWORD_VALIDATORS = [
{
- 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
+ "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator",
},
{
- 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
- 'OPTIONS': {
- 'min_length': 9,
- }
+ "NAME": "django.contrib.auth.password_validation.MinimumLengthValidator",
+ "OPTIONS": {
+ "min_length": 9,
+ },
},
{
- 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
+ "NAME": "django.contrib.auth.password_validation.CommonPasswordValidator",
},
{
- 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
+ "NAME": "django.contrib.auth.password_validation.NumericPasswordValidator",
},
]
@@ -148,36 +149,36 @@
# 该项目本身的mysql数据库地址
DATABASES = {
- 'default': {
+ "default": {
**env.db(),
**{
- 'DEFAULT_CHARSET': 'utf8mb4',
- 'CONN_MAX_AGE': 50,
- 'OPTIONS': {
- 'init_command': "SET sql_mode='STRICT_TRANS_TABLES'",
- 'charset': 'utf8mb4'
+ "DEFAULT_CHARSET": "utf8mb4",
+ "CONN_MAX_AGE": 50,
+ "OPTIONS": {
+ "init_command": "SET sql_mode='STRICT_TRANS_TABLES'",
+ "charset": "utf8mb4",
},
- 'TEST': {
- 'NAME': 'test_archery',
- 'CHARSET': 'utf8mb4',
- }
- }
+ "TEST": {
+ "NAME": "test_archery",
+ "CHARSET": "utf8mb4",
+ },
+ },
}
}
# Django-Q
Q_CLUSTER = {
- 'name': 'archery',
- 'workers': 4,
- 'recycle': 500,
- 'timeout': 60,
- 'compress': True,
- 'cpu_affinity': 1,
- 'save_limit': 0,
- 'queue_limit': 50,
- 'label': 'Django Q',
- 'django_redis': 'default',
- 'sync': env("Q_CLUISTER_SYNC") # 本地调试可以修改为True,使用同步模式
+ "name": "archery",
+ "workers": 4,
+ "recycle": 500,
+ "timeout": 60,
+ "compress": True,
+ "cpu_affinity": 1,
+ "save_limit": 0,
+ "queue_limit": 50,
+ "label": "Django Q",
+ "django_redis": "default",
+ "sync": env("Q_CLUISTER_SYNC"), # 本地调试可以修改为True,使用同步模式
}
# 缓存配置
@@ -186,49 +187,46 @@
}
# https://docs.djangoproject.com/en/3.2/ref/settings/#std-setting-DEFAULT_AUTO_FIELD
-DEFAULT_AUTO_FIELD = 'django.db.models.AutoField'
+DEFAULT_AUTO_FIELD = "django.db.models.AutoField"
# API Framework
REST_FRAMEWORK = {
- 'DEFAULT_SCHEMA_CLASS': 'drf_spectacular.openapi.AutoSchema',
- 'DEFAULT_RENDERER_CLASSES': ('rest_framework.renderers.JSONRenderer',),
+ "DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema",
+ "DEFAULT_RENDERER_CLASSES": ("rest_framework.renderers.JSONRenderer",),
# 鉴权
- 'DEFAULT_AUTHENTICATION_CLASSES': (
- 'rest_framework_simplejwt.authentication.JWTAuthentication',
- 'rest_framework.authentication.SessionAuthentication',
+ "DEFAULT_AUTHENTICATION_CLASSES": (
+ "rest_framework_simplejwt.authentication.JWTAuthentication",
+ "rest_framework.authentication.SessionAuthentication",
),
# 权限
- 'DEFAULT_PERMISSION_CLASSES': ('sql_api.permissions.IsInUserWhitelist',),
+ "DEFAULT_PERMISSION_CLASSES": ("sql_api.permissions.IsInUserWhitelist",),
# 限速(anon:未认证用户 user:认证用户)
- 'DEFAULT_THROTTLE_CLASSES': (
- 'rest_framework.throttling.AnonRateThrottle',
- 'rest_framework.throttling.UserRateThrottle',
+ "DEFAULT_THROTTLE_CLASSES": (
+ "rest_framework.throttling.AnonRateThrottle",
+ "rest_framework.throttling.UserRateThrottle",
),
- 'DEFAULT_THROTTLE_RATES': {
- 'anon': '120/min',
- 'user': '600/min'
- },
+ "DEFAULT_THROTTLE_RATES": {"anon": "120/min", "user": "600/min"},
# 过滤
- 'DEFAULT_FILTER_BACKENDS': ('django_filters.rest_framework.DjangoFilterBackend',),
+ "DEFAULT_FILTER_BACKENDS": ("django_filters.rest_framework.DjangoFilterBackend",),
# 分页
- 'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination',
- 'PAGE_SIZE': 5,
+ "DEFAULT_PAGINATION_CLASS": "rest_framework.pagination.PageNumberPagination",
+ "PAGE_SIZE": 5,
}
# Swagger UI
SPECTACULAR_SETTINGS = {
- 'TITLE': 'Archery API',
- 'DESCRIPTION': 'OpenAPI 3.0',
- 'VERSION': '1.0.0',
+ "TITLE": "Archery API",
+ "DESCRIPTION": "OpenAPI 3.0",
+ "VERSION": "1.0.0",
}
# API Authentication
SIMPLE_JWT = {
- 'ACCESS_TOKEN_LIFETIME': timedelta(hours=4),
- 'REFRESH_TOKEN_LIFETIME': timedelta(days=3),
- 'ALGORITHM': 'HS256',
- 'SIGNING_KEY': SECRET_KEY,
- 'AUTH_HEADER_TYPES': ('Bearer',),
+ "ACCESS_TOKEN_LIFETIME": timedelta(hours=4),
+ "REFRESH_TOKEN_LIFETIME": timedelta(days=3),
+ "ALGORITHM": "HS256",
+ "SIGNING_KEY": SECRET_KEY,
+ "AUTH_HEADER_TYPES": ("Bearer",),
}
# LDAP
@@ -238,68 +236,78 @@
from django_auth_ldap.config import LDAPSearch
AUTHENTICATION_BACKENDS = (
- 'django_auth_ldap.backend.LDAPBackend', # 配置为先使用LDAP认证,如通过认证则不再使用后面的认证方式
- 'django.contrib.auth.backends.ModelBackend', # django系统中手动创建的用户也可使用,优先级靠后。注意这2行的顺序
+ "django_auth_ldap.backend.LDAPBackend", # 配置为先使用LDAP认证,如通过认证则不再使用后面的认证方式
+ "django.contrib.auth.backends.ModelBackend", # django系统中手动创建的用户也可使用,优先级靠后。注意这2行的顺序
)
AUTH_LDAP_SERVER_URI = env("AUTH_LDAP_SERVER_URI", default="ldap://xxx")
AUTH_LDAP_USER_DN_TEMPLATE = env("AUTH_LDAP_USER_DN_TEMPLATE", default=None)
if not AUTH_LDAP_USER_DN_TEMPLATE:
del AUTH_LDAP_USER_DN_TEMPLATE
- AUTH_LDAP_BIND_DN = env("AUTH_LDAP_BIND_DN", default="cn=xxx,ou=xxx,dc=xxx,dc=xxx")
+ AUTH_LDAP_BIND_DN = env(
+ "AUTH_LDAP_BIND_DN", default="cn=xxx,ou=xxx,dc=xxx,dc=xxx"
+ )
AUTH_LDAP_BIND_PASSWORD = env("AUTH_LDAP_BIND_PASSWORD", default="***********")
- AUTH_LDAP_USER_SEARCH_BASE = env("AUTH_LDAP_USER_SEARCH_BASE", default="ou=xxx,dc=xxx,dc=xxx")
- AUTH_LDAP_USER_SEARCH_FILTER = env("AUTH_LDAP_USER_SEARCH_FILTER", default='(cn=%(user)s)')
- AUTH_LDAP_USER_SEARCH = LDAPSearch(AUTH_LDAP_USER_SEARCH_BASE, ldap.SCOPE_SUBTREE, AUTH_LDAP_USER_SEARCH_FILTER)
- AUTH_LDAP_ALWAYS_UPDATE_USER = env("AUTH_LDAP_ALWAYS_UPDATE_USER", default=True) # 每次登录从ldap同步用户信息
+ AUTH_LDAP_USER_SEARCH_BASE = env(
+ "AUTH_LDAP_USER_SEARCH_BASE", default="ou=xxx,dc=xxx,dc=xxx"
+ )
+ AUTH_LDAP_USER_SEARCH_FILTER = env(
+ "AUTH_LDAP_USER_SEARCH_FILTER", default="(cn=%(user)s)"
+ )
+ AUTH_LDAP_USER_SEARCH = LDAPSearch(
+ AUTH_LDAP_USER_SEARCH_BASE, ldap.SCOPE_SUBTREE, AUTH_LDAP_USER_SEARCH_FILTER
+ )
+ AUTH_LDAP_ALWAYS_UPDATE_USER = env(
+ "AUTH_LDAP_ALWAYS_UPDATE_USER", default=True
+ ) # 每次登录从ldap同步用户信息
AUTH_LDAP_USER_ATTR_MAP = env("AUTH_LDAP_USER_ATTR_MAP")
# LOG配置
LOGGING = {
- 'version': 1,
- 'disable_existing_loggers': False,
- 'formatters': {
- 'verbose': {
- 'format': '[%(asctime)s][%(threadName)s:%(thread)d][task_id:%(name)s][%(filename)s:%(lineno)d][%(levelname)s]- %(message)s'
+ "version": 1,
+ "disable_existing_loggers": False,
+ "formatters": {
+ "verbose": {
+ "format": "[%(asctime)s][%(threadName)s:%(thread)d][task_id:%(name)s][%(filename)s:%(lineno)d][%(levelname)s]- %(message)s"
},
},
- 'handlers': {
- 'default': {
- 'level': 'DEBUG',
- 'class': 'logging.handlers.RotatingFileHandler',
- 'filename': 'logs/archery.log',
- 'maxBytes': 1024 * 1024 * 100, # 5 MB
- 'backupCount': 5,
- 'formatter': 'verbose',
+ "handlers": {
+ "default": {
+ "level": "DEBUG",
+ "class": "logging.handlers.RotatingFileHandler",
+ "filename": "logs/archery.log",
+            "maxBytes": 1024 * 1024 * 100,  # 100 MB
+ "backupCount": 5,
+ "formatter": "verbose",
+ },
+ "django-q": {
+ "level": "DEBUG",
+ "class": "logging.handlers.RotatingFileHandler",
+ "filename": "logs/qcluster.log",
+            "maxBytes": 1024 * 1024 * 100,  # 100 MB
+ "backupCount": 5,
+ "formatter": "verbose",
},
- 'django-q': {
- 'level': 'DEBUG',
- 'class': 'logging.handlers.RotatingFileHandler',
- 'filename': 'logs/qcluster.log',
- 'maxBytes': 1024 * 1024 * 100, # 5 MB
- 'backupCount': 5,
- 'formatter': 'verbose',
+ "console": {
+ "level": "DEBUG",
+ "class": "logging.StreamHandler",
+ "formatter": "verbose",
},
- 'console': {
- 'level': 'DEBUG',
- 'class': 'logging.StreamHandler',
- 'formatter': 'verbose'
- }
},
- 'loggers': {
- 'default': { # default日志
- 'handlers': ['console', 'default'],
- 'level': 'WARNING'
+ "loggers": {
+ "default": { # default日志
+ "handlers": ["console", "default"],
+ "level": "WARNING",
},
- 'django-q': { # django_q模块相关日志
- 'handlers': ['console', 'django-q'],
- 'level': 'WARNING',
- 'propagate': False
+ "django-q": { # django_q模块相关日志
+ "handlers": ["console", "django-q"],
+ "level": "WARNING",
+ "propagate": False,
},
- 'django_auth_ldap': { # django_auth_ldap模块相关日志
- 'handlers': ['console', 'default'],
- 'level': 'WARNING',
- 'propagate': False
+ "django_auth_ldap": { # django_auth_ldap模块相关日志
+ "handlers": ["console", "default"],
+ "level": "WARNING",
+ "propagate": False,
},
# 'django.db': { # 打印SQL语句,方便开发
# 'handlers': ['console', 'default'],
@@ -311,14 +319,14 @@
# 'level': 'DEBUG',
# 'propagate': False
# },
- }
+ },
}
-MEDIA_ROOT = os.path.join(BASE_DIR, 'media')
+MEDIA_ROOT = os.path.join(BASE_DIR, "media")
if not os.path.exists(MEDIA_ROOT):
os.mkdir(MEDIA_ROOT)
-PKEY_ROOT = os.path.join(MEDIA_ROOT, 'keys')
+PKEY_ROOT = os.path.join(MEDIA_ROOT, "keys")
if not os.path.exists(PKEY_ROOT):
os.mkdir(PKEY_ROOT)
diff --git a/archery/urls.py b/archery/urls.py
index e1e35c40e4..89d8caf351 100644
--- a/archery/urls.py
+++ b/archery/urls.py
@@ -3,9 +3,9 @@
from common import views
urlpatterns = [
- path('admin/', admin.site.urls),
- path('api/', include(('sql_api.urls', 'sql_api'), namespace="sql_api")),
- path('', include(('sql.urls', 'sql'), namespace="sql")),
+ path("admin/", admin.site.urls),
+ path("api/", include(("sql_api.urls", "sql_api"), namespace="sql_api")),
+ path("", include(("sql.urls", "sql"), namespace="sql")),
]
handler400 = views.bad_request
diff --git a/common/auth.py b/common/auth.py
index f26e081c84..3bb227be6a 100644
--- a/common/auth.py
+++ b/common/auth.py
@@ -15,7 +15,7 @@
from common.utils.ding_api import get_ding_user_id
from sql.models import Users, ResourceGroup, TwoFactorAuthConfig
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
def init_user(user):
@@ -25,17 +25,24 @@ def init_user(user):
:return:
"""
# 添加到默认权限组
- default_auth_group = SysConfig().get('default_auth_group', '')
+ default_auth_group = SysConfig().get("default_auth_group", "")
if default_auth_group:
- default_auth_group = default_auth_group.split(',')
- [user.groups.add(group) for group in Group.objects.filter(name__in=default_auth_group)]
+ default_auth_group = default_auth_group.split(",")
+ [
+ user.groups.add(group)
+ for group in Group.objects.filter(name__in=default_auth_group)
+ ]
# 添加到默认资源组
- default_resource_group = SysConfig().get('default_resource_group', '')
+ default_resource_group = SysConfig().get("default_resource_group", "")
if default_resource_group:
- default_resource_group = default_resource_group.split(',')
- [user.resource_group.add(group) for group in
- ResourceGroup.objects.filter(group_name__in=default_resource_group)]
+ default_resource_group = default_resource_group.split(",")
+ [
+ user.resource_group.add(group)
+ for group in ResourceGroup.objects.filter(
+ group_name__in=default_resource_group
+ )
+ ]
class ArcheryAuth(object):
@@ -55,8 +62,8 @@ def challenge(username=None, password=None):
return user
def authenticate(self):
- username = self.request.POST.get('username')
- password = self.request.POST.get('password')
+ username = self.request.POST.get("username")
+ password = self.request.POST.get("password")
# 确认用户是否已经存在
try:
user = Users.objects.get(username=username)
@@ -65,23 +72,30 @@ def authenticate(self):
if authenticated_user:
# ldap 首次登录逻辑
init_user(authenticated_user)
- return {'status': 0, 'msg': 'ok', 'data': authenticated_user}
+ return {"status": 0, "msg": "ok", "data": authenticated_user}
else:
- return {'status': 1, 'msg': '用户名或密码错误,请重新输入!', 'data': ''}
+ return {"status": 1, "msg": "用户名或密码错误,请重新输入!", "data": ""}
except:
- logger.error('验证用户密码时报错')
+ logger.error("验证用户密码时报错")
logger.error(traceback.format_exc())
- return {'status': 1, 'msg': f'服务异常,请联系管理员处理', 'data': ''}
+ return {"status": 1, "msg": f"服务异常,请联系管理员处理", "data": ""}
# 已存在用户, 验证是否在锁期间
# 读取配置文件
- lock_count = int(self.sys_config.get('lock_cnt_threshold', 5))
- lock_time = int(self.sys_config.get('lock_time_threshold', 60 * 5))
+ lock_count = int(self.sys_config.get("lock_cnt_threshold", 5))
+ lock_time = int(self.sys_config.get("lock_time_threshold", 60 * 5))
# 验证是否在锁, 分了几个if 防止代码太长
if user.failed_login_count and user.last_login_failed_at:
if user.failed_login_count >= lock_count:
now = datetime.datetime.now()
- if user.last_login_failed_at + datetime.timedelta(seconds=lock_time) > now:
- return {'status': 3, 'msg': f'登录失败超过限制,该账号已被锁定!请等候大约{lock_time}秒再试', 'data': ''}
+ if (
+ user.last_login_failed_at + datetime.timedelta(seconds=lock_time)
+ > now
+ ):
+ return {
+ "status": 3,
+ "msg": f"登录失败超过限制,该账号已被锁定!请等候大约{lock_time}秒再试",
+ "data": "",
+ }
else:
# 如果锁已超时, 重置失败次数
user.failed_login_count = 0
@@ -90,11 +104,11 @@ def authenticate(self):
if authenticated_user:
if not authenticated_user.last_login:
init_user(authenticated_user)
- return {'status': 0, 'msg': 'ok', 'data': authenticated_user}
+ return {"status": 0, "msg": "ok", "data": authenticated_user}
user.failed_login_count += 1
user.last_login_failed_at = datetime.datetime.now()
user.save()
- return {'status': 1, 'msg': '用户名或密码错误,请重新输入!', 'data': ''}
+ return {"status": 1, "msg": "用户名或密码错误,请重新输入!", "data": ""}
# ajax接口,登录页面调用,用来验证用户名密码
@@ -102,69 +116,71 @@ def authenticate_entry(request):
"""接收http请求,然后把请求中的用户名密码传给ArcherAuth去验证"""
new_auth = ArcheryAuth(request)
result = new_auth.authenticate()
- if result['status'] == 0:
- authenticated_user = result['data']
+ if result["status"] == 0:
+ authenticated_user = result["data"]
twofa_enabled = TwoFactorAuthConfig.objects.filter(user=authenticated_user)
# 是否开启全局2fa
- if SysConfig().get('enforce_2fa'):
+ if SysConfig().get("enforce_2fa"):
# 用户是否配置过2fa
if twofa_enabled:
- verify_mode = 'verify_only'
+ verify_mode = "verify_only"
else:
- verify_mode = 'verify_config'
+ verify_mode = "verify_config"
# 设置无登录状态session
s = SessionStore()
- s['user'] = authenticated_user.username
- s['verify_mode'] = verify_mode
+ s["user"] = authenticated_user.username
+ s["verify_mode"] = verify_mode
s.set_expiry(300)
s.create()
- result = {'status': 0, 'msg': 'ok', 'data': s.session_key}
+ result = {"status": 0, "msg": "ok", "data": s.session_key}
else:
# 用户是否配置过2fa
if twofa_enabled:
# 设置无登录状态session
s = SessionStore()
- s['user'] = authenticated_user.username
- s['verify_mode'] = 'verify_only'
+ s["user"] = authenticated_user.username
+ s["verify_mode"] = "verify_only"
s.set_expiry(300)
s.create()
- result = {'status': 0, 'msg': 'ok', 'data': s.session_key}
+ result = {"status": 0, "msg": "ok", "data": s.session_key}
else:
# 未设置2fa直接登录
login(request, authenticated_user)
# 从钉钉获取该用户的 dingding_id,用于单独给他发消息
- if SysConfig().get("ding_to_person") is True and "admin" not in request.POST.get('username'):
- get_ding_user_id(request.POST.get('username'))
- result = {'status': 0, 'msg': 'ok', 'data': None}
+ if SysConfig().get(
+ "ding_to_person"
+ ) is True and "admin" not in request.POST.get("username"):
+ get_ding_user_id(request.POST.get("username"))
+ result = {"status": 0, "msg": "ok", "data": None}
- return HttpResponse(json.dumps(result), content_type='application/json')
+ return HttpResponse(json.dumps(result), content_type="application/json")
# 注册用户
def sign_up(request):
- sign_up_enabled = SysConfig().get('sign_up_enabled', False)
+ sign_up_enabled = SysConfig().get("sign_up_enabled", False)
if not sign_up_enabled:
- result = {'status': 1, 'msg': '注册未启用,请联系管理员开启', 'data': None}
- return HttpResponse(json.dumps(result), content_type='application/json')
- username = request.POST.get('username')
- password = request.POST.get('password')
- password2 = request.POST.get('password2')
- display = request.POST.get('display')
- email = request.POST.get('email')
- result = {'status': 0, 'msg': 'ok', 'data': None}
+ result = {"status": 1, "msg": "注册未启用,请联系管理员开启", "data": None}
+ return HttpResponse(json.dumps(result), content_type="application/json")
+ username = request.POST.get("username")
+ password = request.POST.get("password")
+ password2 = request.POST.get("password2")
+ display = request.POST.get("display")
+ email = request.POST.get("email")
+ result = {"status": 0, "msg": "ok", "data": None}
if not (username and password):
- result['status'] = 1
- result['msg'] = '用户名和密码不能为空'
+ result["status"] = 1
+ result["msg"] = "用户名和密码不能为空"
elif len(Users.objects.filter(username=username)) > 0:
- result['status'] = 1
- result['msg'] = '用户名已存在'
+ result["status"] = 1
+ result["msg"] = "用户名已存在"
elif password != password2:
- result['status'] = 1
- result['msg'] = '两次输入密码不一致'
+ result["status"] = 1
+ result["msg"] = "两次输入密码不一致"
elif not display:
- result['status'] = 1
- result['msg'] = '请填写中文名'
+ result["status"] = 1
+ result["msg"] = "请填写中文名"
else:
# 验证密码
try:
@@ -175,15 +191,15 @@ def sign_up(request):
display=display,
email=email,
is_active=1,
- is_staff=True
+ is_staff=True,
)
except ValidationError as msg:
- result['status'] = 1
- result['msg'] = str(msg)
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["status"] = 1
+ result["msg"] = str(msg)
+ return HttpResponse(json.dumps(result), content_type="application/json")
# 退出登录
def sign_out(request):
logout(request)
- return HttpResponseRedirect(reverse('sql:login'))
+ return HttpResponseRedirect(reverse("sql:login"))
diff --git a/common/check.py b/common/check.py
index 300e17bc55..93ac84747b 100644
--- a/common/check.py
+++ b/common/check.py
@@ -11,106 +11,119 @@
from sql.models import Instance
from common.utils.sendmsg import MsgSender
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
# 检测inception配置
@superuser_required
def go_inception(request):
- result = {'status': 0, 'msg': 'ok', 'data': []}
- go_inception_host = request.POST.get('go_inception_host', '')
- go_inception_port = request.POST.get('go_inception_port', '')
- inception_remote_backup_host = request.POST.get('inception_remote_backup_host', '')
- inception_remote_backup_port = request.POST.get('inception_remote_backup_port', '')
- inception_remote_backup_user = request.POST.get('inception_remote_backup_user', '')
- inception_remote_backup_password = request.POST.get('inception_remote_backup_password', '')
+ result = {"status": 0, "msg": "ok", "data": []}
+ go_inception_host = request.POST.get("go_inception_host", "")
+ go_inception_port = request.POST.get("go_inception_port", "")
+ inception_remote_backup_host = request.POST.get("inception_remote_backup_host", "")
+ inception_remote_backup_port = request.POST.get("inception_remote_backup_port", "")
+ inception_remote_backup_user = request.POST.get("inception_remote_backup_user", "")
+ inception_remote_backup_password = request.POST.get(
+ "inception_remote_backup_password", ""
+ )
try:
- conn = MySQLdb.connect(host=go_inception_host, port=int(go_inception_port), charset='utf8mb4',
- connect_timeout=5)
+ conn = MySQLdb.connect(
+ host=go_inception_host,
+ port=int(go_inception_port),
+ charset="utf8mb4",
+ connect_timeout=5,
+ )
cur = conn.cursor()
except Exception as e:
logger.error(traceback.format_exc())
- result['status'] = 1
- result['msg'] = '无法连接goInception\n{}'.format(str(e))
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["status"] = 1
+ result["msg"] = "无法连接goInception\n{}".format(str(e))
+ return HttpResponse(json.dumps(result), content_type="application/json")
else:
cur.close()
conn.close()
try:
- conn = MySQLdb.connect(host=inception_remote_backup_host,
- port=int(inception_remote_backup_port),
- user=inception_remote_backup_user,
- password=inception_remote_backup_password,
- charset='utf8mb4',
- connect_timeout=5)
+ conn = MySQLdb.connect(
+ host=inception_remote_backup_host,
+ port=int(inception_remote_backup_port),
+ user=inception_remote_backup_user,
+ password=inception_remote_backup_password,
+ charset="utf8mb4",
+ connect_timeout=5,
+ )
cur = conn.cursor()
except Exception as e:
logger.error(traceback.format_exc())
- result['status'] = 1
- result['msg'] = '无法连接goInception备份库\n{}'.format(str(e))
+ result["status"] = 1
+ result["msg"] = "无法连接goInception备份库\n{}".format(str(e))
else:
cur.close()
conn.close()
# 返回结果
- return HttpResponse(json.dumps(result), content_type='application/json')
+ return HttpResponse(json.dumps(result), content_type="application/json")
# 检测email配置
@superuser_required
def email(request):
- result = {'status': 0, 'msg': 'ok', 'data': []}
- mail = True if request.POST.get('mail', '') == 'true' else False
- mail_ssl = True if request.POST.get('mail_ssl') == 'true' else False
- mail_smtp_server = request.POST.get('mail_smtp_server', '')
- mail_smtp_port = request.POST.get('mail_smtp_port', '')
- mail_smtp_user = request.POST.get('mail_smtp_user', '')
- mail_smtp_password = request.POST.get('mail_smtp_password', '')
+ result = {"status": 0, "msg": "ok", "data": []}
+ mail = True if request.POST.get("mail", "") == "true" else False
+ mail_ssl = True if request.POST.get("mail_ssl") == "true" else False
+ mail_smtp_server = request.POST.get("mail_smtp_server", "")
+ mail_smtp_port = request.POST.get("mail_smtp_port", "")
+ mail_smtp_user = request.POST.get("mail_smtp_user", "")
+ mail_smtp_password = request.POST.get("mail_smtp_password", "")
if not mail:
- result['status'] = 1
- result['msg'] = '请先开启邮件通知!'
+ result["status"] = 1
+ result["msg"] = "请先开启邮件通知!"
# 返回结果
- return HttpResponse(json.dumps(result), content_type='application/json')
+ return HttpResponse(json.dumps(result), content_type="application/json")
try:
mail_smtp_port = int(mail_smtp_port)
if mail_smtp_port < 0:
raise ValueError
except ValueError:
- result['status'] = 1
- result['msg'] = '端口号只能为正整数'
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["status"] = 1
+ result["msg"] = "端口号只能为正整数"
+ return HttpResponse(json.dumps(result), content_type="application/json")
if not request.user.email:
- result['status'] = 1
- result['msg'] = '请先完善当前用户邮箱信息!'
- return HttpResponse(json.dumps(result), content_type='application/json')
- bd = 'Archery 邮件发送测试...'
- subj = 'Archery 邮件发送测试'
- sender = MsgSender(server=mail_smtp_server, port=mail_smtp_port, user=mail_smtp_user,
- password=mail_smtp_password, ssl=mail_ssl)
+ result["status"] = 1
+ result["msg"] = "请先完善当前用户邮箱信息!"
+ return HttpResponse(json.dumps(result), content_type="application/json")
+ bd = "Archery 邮件发送测试..."
+ subj = "Archery 邮件发送测试"
+ sender = MsgSender(
+ server=mail_smtp_server,
+ port=mail_smtp_port,
+ user=mail_smtp_user,
+ password=mail_smtp_password,
+ ssl=mail_ssl,
+ )
sender_response = sender.send_email(subj, bd, [request.user.email])
- if sender_response != 'success':
- result['status'] = 1
- result['msg'] = sender_response
- return HttpResponse(json.dumps(result), content_type='application/json')
- return HttpResponse(json.dumps(result), content_type='application/json')
+ if sender_response != "success":
+ result["status"] = 1
+ result["msg"] = sender_response
+ return HttpResponse(json.dumps(result), content_type="application/json")
+ return HttpResponse(json.dumps(result), content_type="application/json")
# 检测实例配置
@superuser_required
def instance(request):
- result = {'status': 0, 'msg': 'ok', 'data': []}
- instance_id = request.POST.get('instance_id')
+ result = {"status": 0, "msg": "ok", "data": []}
+ instance_id = request.POST.get("instance_id")
instance = Instance.objects.get(id=instance_id)
try:
engine = get_engine(instance=instance)
test_result = engine.test_connection()
if test_result.error:
- result['status'] = 1
- result['msg'] = '无法连接实例,\n{}'.format(test_result.error)
+ result["status"] = 1
+ result["msg"] = "无法连接实例,\n{}".format(test_result.error)
except Exception as e:
- result['status'] = 1
- result['msg'] = '无法连接实例,\n{}'.format(str(e))
+ result["status"] = 1
+ result["msg"] = "无法连接实例,\n{}".format(str(e))
# 返回结果
- return HttpResponse(json.dumps(result), content_type='application/json')
+ return HttpResponse(json.dumps(result), content_type="application/json")
diff --git a/common/config.py b/common/config.py
index de4865563a..89475a8c52 100644
--- a/common/config.py
+++ b/common/config.py
@@ -9,7 +9,7 @@
from sql.models import Config
from django.db import transaction
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
class SysConfig(object):
@@ -20,14 +20,14 @@ def __init__(self):
def get_all_config(self):
try:
# 获取系统配置信息
- all_config = Config.objects.all().values('item', 'value')
+ all_config = Config.objects.all().values("item", "value")
sys_config = {}
for items in all_config:
- if items['value'] in ('true', 'True'):
- items['value'] = True
- elif items['value'] in ('false', 'False'):
- items['value'] = False
- sys_config[items['item']] = items['value']
+ if items["value"] in ("true", "True"):
+ items["value"] = True
+ elif items["value"] in ("false", "False"):
+ items["value"] = False
+ sys_config[items["item"]] = items["value"]
self.sys_config = sys_config
except Exception as m:
logger.error(f"获取系统配置信息失败:{m}{traceback.format_exc()}")
@@ -36,34 +36,41 @@ def get_all_config(self):
def get(self, key, default_value=None):
value = self.sys_config.get(key, default_value)
# 是字符串的话, 如果是空, 或者全是空格, 返回默认值
- if isinstance(value, str) and value.strip() == '':
+ if isinstance(value, str) and value.strip() == "":
return default_value
return value
def set(self, key, value):
if value is True:
- db_value = 'true'
+ db_value = "true"
elif value is False:
- db_value = 'false'
+ db_value = "false"
else:
db_value = value
- obj, created = Config.objects.update_or_create(item=key, defaults={"value": db_value})
+ obj, created = Config.objects.update_or_create(
+ item=key, defaults={"value": db_value}
+ )
if created:
self.sys_config.update({key: value})
def replace(self, configs):
- result = {'status': 0, 'msg': 'ok', 'data': []}
+ result = {"status": 0, "msg": "ok", "data": []}
# 清空并替换
try:
with transaction.atomic():
self.purge()
Config.objects.bulk_create(
- [Config(item=items['key'].strip(),
- value=str(items['value']).strip()) for items in json.loads(configs)])
+ [
+ Config(
+ item=items["key"].strip(), value=str(items["value"]).strip()
+ )
+ for items in json.loads(configs)
+ ]
+ )
except Exception as e:
logger.error(traceback.format_exc())
- result['status'] = 1
- result['msg'] = str(e)
+ result["status"] = 1
+ result["msg"] = str(e)
finally:
self.get_all_config()
return result
@@ -81,8 +88,8 @@ def purge(self):
# 修改系统配置
@superuser_required
def change_config(request):
- configs = request.POST.get('configs')
+ configs = request.POST.get("configs")
archer_config = SysConfig()
result = archer_config.replace(configs)
# 返回结果
- return HttpResponse(json.dumps(result), content_type='application/json')
+ return HttpResponse(json.dumps(result), content_type="application/json")
diff --git a/common/dashboard.py b/common/dashboard.py
index ee2d274a72..af7291a847 100644
--- a/common/dashboard.py
+++ b/common/dashboard.py
@@ -11,10 +11,10 @@
from pyecharts import options as opts
from pyecharts.charts import Pie, Bar, Line
-CurrentConfig.ONLINE_HOST = '/static/echarts/'
+CurrentConfig.ONLINE_HOST = "/static/echarts/"
-@permission_required('sql.menu_dashboard', raise_exception=True)
+@permission_required("sql.menu_dashboard", raise_exception=True)
def pyecharts(request):
"""dashboard view"""
# 工单数量统计
@@ -24,42 +24,44 @@ def pyecharts(request):
one_month_before = today - relativedelta(days=+30)
attr = chart_dao.get_date_list(one_month_before, today)
_dict = {}
- for row in data['rows']:
+ for row in data["rows"]:
_dict[row[0]] = row[1]
value = [_dict.get(day) if _dict.get(day) else 0 for day in attr]
- bar1 = Bar(init_opts=opts.InitOpts(width='600', height='380px'))
+ bar1 = Bar(init_opts=opts.InitOpts(width="600", height="380px"))
bar1.add_xaxis(attr)
bar1.add_yaxis("", value)
# 工单按组统计
data = chart_dao.workflow_by_group(30)
- attr = [row[0] for row in data['rows']]
- value = [row[1] for row in data['rows']]
- pie1 = Pie(init_opts=opts.InitOpts(width='600', height='380px'))
- pie1.set_global_opts(title_opts=opts.TitleOpts(title=''),
- legend_opts=opts.LegendOpts(
- orient="vertical", pos_top="15%", pos_left="2%", is_show=False
- ))
+ attr = [row[0] for row in data["rows"]]
+ value = [row[1] for row in data["rows"]]
+ pie1 = Pie(init_opts=opts.InitOpts(width="600", height="380px"))
+ pie1.set_global_opts(
+ title_opts=opts.TitleOpts(title=""),
+ legend_opts=opts.LegendOpts(
+ orient="vertical", pos_top="15%", pos_left="2%", is_show=False
+ ),
+ )
pie1.set_series_opts(label_opts=opts.LabelOpts(formatter="{b}: {c}"))
pie1.add("", [list(z) for z in zip(attr, value)]) if attr and data else None
# 工单按人统计
data = chart_dao.workflow_by_user(30)
- attr = [row[0] for row in data['rows']]
- value = [row[1] for row in data['rows']]
- bar2 = Bar(init_opts=opts.InitOpts(width='600', height='380px'))
+ attr = [row[0] for row in data["rows"]]
+ value = [row[1] for row in data["rows"]]
+ bar2 = Bar(init_opts=opts.InitOpts(width="600", height="380px"))
bar2.add_xaxis(attr)
bar2.add_yaxis("", value)
# SQL语句类型统计
data = chart_dao.syntax_type()
- attr = [row[0] for row in data['rows']]
- value = [row[1] for row in data['rows']]
+ attr = [row[0] for row in data["rows"]]
+ value = [row[1] for row in data["rows"]]
pie2 = Pie()
- pie2.set_global_opts(title_opts=opts.TitleOpts(title='SQL上线工单统计(类型)'),
- legend_opts=opts.LegendOpts(
- orient="vertical", pos_top="15%", pos_left="2%"
- ))
+ pie2.set_global_opts(
+ title_opts=opts.TitleOpts(title="SQL上线工单统计(类型)"),
+ legend_opts=opts.LegendOpts(orient="vertical", pos_top="15%", pos_left="2%"),
+ )
pie2.set_series_opts(label_opts=opts.LabelOpts(formatter="{b}: {c}"))
pie2.add("", [list(z) for z in zip(attr, value)]) if attr and data else None
@@ -67,65 +69,86 @@ def pyecharts(request):
attr = chart_dao.get_date_list(one_month_before, today)
effect_data = chart_dao.querylog_effect_row_by_date(30)
effect_dict = {}
- for row in effect_data['rows']:
+ for row in effect_data["rows"]:
effect_dict[row[0]] = int(row[1])
effect_value = [effect_dict.get(day) if effect_dict.get(day) else 0 for day in attr]
count_data = chart_dao.querylog_count_by_date(30)
count_dict = {}
- for row in count_data['rows']:
+ for row in count_data["rows"]:
count_dict[row[0]] = int(row[1])
count_value = [count_dict.get(day) if count_dict.get(day) else 0 for day in attr]
- line1 = Line(init_opts=opts.InitOpts(width='600', height='380px'))
- line1.set_global_opts(title_opts=opts.TitleOpts(title=''),
- legend_opts=opts.LegendOpts(selected_mode='single'))
+ line1 = Line(init_opts=opts.InitOpts(width="600", height="380px"))
+ line1.set_global_opts(
+ title_opts=opts.TitleOpts(title=""),
+ legend_opts=opts.LegendOpts(selected_mode="single"),
+ )
line1.add_xaxis(attr)
- line1.add_yaxis("检索行数", effect_value, is_smooth=True,
- markpoint_opts=opts.MarkPointOpts(data=[opts.MarkPointItem(type_="average")]))
- line1.add_yaxis("检索次数", count_value, is_smooth=True,
- markline_opts=opts.MarkLineOpts(data=[opts.MarkLineItem(type_="max"),
- opts.MarkLineItem(type_="average")]))
+ line1.add_yaxis(
+ "检索行数",
+ effect_value,
+ is_smooth=True,
+ markpoint_opts=opts.MarkPointOpts(data=[opts.MarkPointItem(type_="average")]),
+ )
+ line1.add_yaxis(
+ "检索次数",
+ count_value,
+ is_smooth=True,
+ markline_opts=opts.MarkLineOpts(
+ data=[opts.MarkLineItem(type_="max"), opts.MarkLineItem(type_="average")]
+ ),
+ )
# SQL查询统计(用户检索行数)
data = chart_dao.querylog_effect_row_by_user(30)
- attr = [row[0] for row in data['rows']]
- value = [int(row[1]) for row in data['rows']]
- pie4 = Pie(init_opts=opts.InitOpts(width='600', height='380px'))
- pie4.set_global_opts(title_opts=opts.TitleOpts(title=''),
- legend_opts=opts.LegendOpts(
- orient="vertical", pos_top="15%", pos_left="2%", is_show=False
- ))
+ attr = [row[0] for row in data["rows"]]
+ value = [int(row[1]) for row in data["rows"]]
+ pie4 = Pie(init_opts=opts.InitOpts(width="600", height="380px"))
+ pie4.set_global_opts(
+ title_opts=opts.TitleOpts(title=""),
+ legend_opts=opts.LegendOpts(
+ orient="vertical", pos_top="15%", pos_left="2%", is_show=False
+ ),
+ )
pie4.set_series_opts(label_opts=opts.LabelOpts(formatter="{b}: {c}"))
pie4.add("", [list(z) for z in zip(attr, value)]) if attr and data else None
# SQL查询统计(DB检索行数)
data = chart_dao.querylog_effect_row_by_db(30)
- attr = [row[0] for row in data['rows']]
- value = [int(row[1]) for row in data['rows']]
- pie5 = Pie(init_opts=opts.InitOpts(width='600', height='380px'))
- pie5.set_global_opts(title_opts=opts.TitleOpts(title=''),
- legend_opts=opts.LegendOpts(
- orient="vertical", pos_top="15%", pos_left="2%", is_show=False
- ))
- pie5.set_series_opts(label_opts=opts.LabelOpts(formatter="{b}: {c}", position="left"))
+ attr = [row[0] for row in data["rows"]]
+ value = [int(row[1]) for row in data["rows"]]
+ pie5 = Pie(init_opts=opts.InitOpts(width="600", height="380px"))
+ pie5.set_global_opts(
+ title_opts=opts.TitleOpts(title=""),
+ legend_opts=opts.LegendOpts(
+ orient="vertical", pos_top="15%", pos_left="2%", is_show=False
+ ),
+ )
+ pie5.set_series_opts(
+ label_opts=opts.LabelOpts(formatter="{b}: {c}", position="left")
+ )
pie5.add("", [list(z) for z in zip(attr, value)]) if attr and data else None
# 慢查询db/user维度统计(最近1天)
data = chart_dao.slow_query_count_by_db_by_user(1)
- attr = [row[0] for row in data['rows']]
- value = [int(row[1]) for row in data['rows']]
- pie3 = Pie(init_opts=opts.InitOpts(width='600',height='380px'))
- pie3.set_global_opts(title_opts=opts.TitleOpts(title=''),
- legend_opts=opts.LegendOpts(
- orient="vertical", pos_top="15%", pos_left="2%", is_show=False
- ))
- pie3.set_series_opts(label_opts=opts.LabelOpts(formatter="{b}: {c}", position="left"))
+ attr = [row[0] for row in data["rows"]]
+ value = [int(row[1]) for row in data["rows"]]
+ pie3 = Pie(init_opts=opts.InitOpts(width="600", height="380px"))
+ pie3.set_global_opts(
+ title_opts=opts.TitleOpts(title=""),
+ legend_opts=opts.LegendOpts(
+ orient="vertical", pos_top="15%", pos_left="2%", is_show=False
+ ),
+ )
+ pie3.set_series_opts(
+ label_opts=opts.LabelOpts(formatter="{b}: {c}", position="left")
+ )
pie3.add("", [list(z) for z in zip(attr, value)]) if attr and data else None
# 慢查询db维度统计(最近1天)
data = chart_dao.slow_query_count_by_db(1)
- attr = [row[0] for row in data['rows']]
- value = [row[1] for row in data['rows']]
- bar3 = Bar(init_opts=opts.InitOpts(width='600', height='380px'))
+ attr = [row[0] for row in data["rows"]]
+ value = [row[1] for row in data["rows"]]
+ bar3 = Bar(init_opts=opts.InitOpts(width="600", height="380px"))
bar3.add_xaxis(attr)
bar3.add_yaxis("", value)
@@ -147,7 +170,11 @@ def pyecharts(request):
"sql_wf_cnt": SqlWorkflow.objects.count(),
"query_wf_cnt": QueryPrivilegesApply.objects.count(),
"user_cnt": Users.objects.count(),
- "ins_cnt": Instance.objects.count()
+ "ins_cnt": Instance.objects.count(),
}
- return render(request, "dashboard.html", {"chart": chart, "count_stats": dashboard_count_stats})
+ return render(
+ request,
+ "dashboard.html",
+ {"chart": chart, "count_stats": dashboard_count_stats},
+ )
diff --git a/common/middleware/check_login_middleware.py b/common/middleware/check_login_middleware.py
index 16973b6385..59ca7876f4 100644
--- a/common/middleware/check_login_middleware.py
+++ b/common/middleware/check_login_middleware.py
@@ -3,15 +3,9 @@
from django.http import HttpResponseRedirect
from django.utils.deprecation import MiddlewareMixin
-IGNORE_URL = [
- '/login/',
- '/login/2fa/',
- '/authenticate/',
- '/signup/',
- '/api/info'
-]
+IGNORE_URL = ["/login/", "/login/2fa/", "/authenticate/", "/signup/", "/api/info"]
-IGNORE_URL_RE = r'/api/(v1|auth)/\w+'
+IGNORE_URL_RE = r"/api/(v1|auth)/\w+"
class CheckLoginMiddleware(MiddlewareMixin):
@@ -22,6 +16,12 @@ def process_request(request):
"""
if not request.user.is_authenticated:
# 以下是不用跳转到login页面的url白名单
- if request.path not in IGNORE_URL and re.match(IGNORE_URL_RE, request.path) is None \
- and not (re.match(r'/user/qrcode/\w+', request.path) and request.session.get('user')):
- return HttpResponseRedirect('/login/')
+ if (
+ request.path not in IGNORE_URL
+ and re.match(IGNORE_URL_RE, request.path) is None
+ and not (
+ re.match(r"/user/qrcode/\w+", request.path)
+ and request.session.get("user")
+ )
+ ):
+ return HttpResponseRedirect("/login/")
diff --git a/common/middleware/exception_logging_middleware.py b/common/middleware/exception_logging_middleware.py
index 835993c5e6..cfcfa46f27 100644
--- a/common/middleware/exception_logging_middleware.py
+++ b/common/middleware/exception_logging_middleware.py
@@ -2,10 +2,11 @@
import logging
from django.utils.deprecation import MiddlewareMixin
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
class ExceptionLoggingMiddleware(MiddlewareMixin):
def process_exception(self, request, exception):
import traceback
+
logger.error(traceback.format_exc())
diff --git a/common/storage.py b/common/storage.py
index 60097961a0..d4aeeb3500 100644
--- a/common/storage.py
+++ b/common/storage.py
@@ -6,7 +6,7 @@
@time: 2019/06/01
"""
-__author__ = 'hhyo'
+__author__ = "hhyo"
from django.contrib.staticfiles.storage import ManifestStaticFilesStorage
diff --git a/common/tests.py b/common/tests.py
index f32eb55714..da30292cd6 100644
--- a/common/tests.py
+++ b/common/tests.py
@@ -8,7 +8,13 @@
from common.config import SysConfig
from common.utils.sendmsg import MsgSender
from sql.engines import EngineBase, ResultSet
-from sql.models import Instance, SqlWorkflow, SqlWorkflowContent, QueryLog, ResourceGroup
+from sql.models import (
+ Instance,
+ SqlWorkflow,
+ SqlWorkflowContent,
+ QueryLog,
+ ResourceGroup,
+)
from common.utils.chart_dao import ChartDao
from common.auth import init_user
@@ -21,7 +27,7 @@ def setUp(self):
def test_purge(self):
archer_config = SysConfig()
- archer_config.set('some_key', 'some_value')
+ archer_config.set("some_key", "some_value")
archer_config.purge()
self.assertEqual({}, archer_config.sys_config)
archer_config2 = SysConfig()
@@ -30,39 +36,42 @@ def test_purge(self):
def test_replace_configs(self):
archer_config = SysConfig()
new_config = json.dumps(
- [{'key': 'numconfig', 'value': 1},
- {'key': 'strconfig', 'value': 'strconfig'},
- {'key': 'boolconfig', 'value': 'false'}])
+ [
+ {"key": "numconfig", "value": 1},
+ {"key": "strconfig", "value": "strconfig"},
+ {"key": "boolconfig", "value": "false"},
+ ]
+ )
archer_config.replace(new_config)
archer_config.get_all_config()
expected_config = {
- 'numconfig': '1',
- 'strconfig': 'strconfig',
- 'boolconfig': False
+ "numconfig": "1",
+ "strconfig": "strconfig",
+ "boolconfig": False,
}
self.assertEqual(archer_config.sys_config, expected_config)
def test_get_bool_transform(self):
- bool_config = json.dumps([{'key': 'boolconfig2', 'value': 'false'}])
+ bool_config = json.dumps([{"key": "boolconfig2", "value": "false"}])
archer_config = SysConfig()
archer_config.replace(bool_config)
- self.assertEqual(archer_config.sys_config['boolconfig2'], False)
+ self.assertEqual(archer_config.sys_config["boolconfig2"], False)
def test_set_bool_transform(self):
archer_config = SysConfig()
- archer_config.set('boolconfig3', False)
- self.assertEqual(archer_config.sys_config['boolconfig3'], False)
+ archer_config.set("boolconfig3", False)
+ self.assertEqual(archer_config.sys_config["boolconfig3"], False)
def test_get_other_data(self):
- new_config = json.dumps([{'key': 'other_config', 'value': 'testvalue'}])
+ new_config = json.dumps([{"key": "other_config", "value": "testvalue"}])
archer_config = SysConfig()
archer_config.replace(new_config)
- self.assertEqual(archer_config.sys_config['other_config'], 'testvalue')
+ self.assertEqual(archer_config.sys_config["other_config"], "testvalue")
def test_set_other_data(self):
archer_config = SysConfig()
- archer_config.set('other_config', 'testvalue3')
- self.assertEqual(archer_config.sys_config['other_config'], 'testvalue3')
+ archer_config.set("other_config", "testvalue3")
+ self.assertEqual(archer_config.sys_config["other_config"], "testvalue3")
class SendMessageTest(TestCase):
@@ -70,74 +79,74 @@ class SendMessageTest(TestCase):
def setUp(self):
archer_config = SysConfig()
- self.smtp_server = 'test_smtp_server'
- self.smtp_user = 'test_smtp_user'
- self.smtp_password = 'some_str'
+ self.smtp_server = "test_smtp_server"
+ self.smtp_user = "test_smtp_user"
+ self.smtp_password = "some_str"
self.smtp_port = 1234
self.smtp_ssl = True
- archer_config.set('mail_smtp_server', self.smtp_server)
- archer_config.set('mail_smtp_user', self.smtp_user)
- archer_config.set('mail_smtp_password', self.smtp_password)
- archer_config.set('mail_smtp_port', self.smtp_port)
- archer_config.set('mail_ssl', self.smtp_ssl)
+ archer_config.set("mail_smtp_server", self.smtp_server)
+ archer_config.set("mail_smtp_user", self.smtp_user)
+ archer_config.set("mail_smtp_password", self.smtp_password)
+ archer_config.set("mail_smtp_port", self.smtp_port)
+ archer_config.set("mail_ssl", self.smtp_ssl)
def testSenderInit(self):
sender = MsgSender()
self.assertEqual(sender.MAIL_REVIEW_SMTP_PORT, self.smtp_port)
archer_config = SysConfig()
- archer_config.set('mail_smtp_port', '')
+ archer_config.set("mail_smtp_port", "")
sender = MsgSender()
self.assertEqual(sender.MAIL_REVIEW_SMTP_PORT, 465)
- archer_config.set('mail_ssl', False)
+ archer_config.set("mail_ssl", False)
sender = MsgSender()
self.assertEqual(sender.MAIL_REVIEW_SMTP_PORT, 25)
- @patch.object(smtplib.SMTP, '__init__', return_value=None)
- @patch.object(smtplib.SMTP, 'login')
- @patch.object(smtplib.SMTP, 'sendmail')
- @patch.object(smtplib.SMTP, 'quit')
+ @patch.object(smtplib.SMTP, "__init__", return_value=None)
+ @patch.object(smtplib.SMTP, "login")
+ @patch.object(smtplib.SMTP, "sendmail")
+ @patch.object(smtplib.SMTP, "quit")
def testNoPasswordSendMail(self, _quit, sendmail, login, _):
"""无密码测试"""
- some_sub = 'test_subject'
- some_body = 'mail_body'
- some_to = ['mail_to']
+ some_sub = "test_subject"
+ some_body = "mail_body"
+ some_to = ["mail_to"]
archer_config = SysConfig()
- archer_config.set('mail_ssl', '')
+ archer_config.set("mail_ssl", "")
- archer_config.set('mail_smtp_password', '')
+ archer_config.set("mail_smtp_password", "")
sender2 = MsgSender()
sender2.send_email(some_sub, some_body, some_to)
login.assert_not_called()
- @patch.object(smtplib.SMTP, '__init__', return_value=None)
- @patch.object(smtplib.SMTP, 'login')
- @patch.object(smtplib.SMTP, 'sendmail')
- @patch.object(smtplib.SMTP, 'quit')
+ @patch.object(smtplib.SMTP, "__init__", return_value=None)
+ @patch.object(smtplib.SMTP, "login")
+ @patch.object(smtplib.SMTP, "sendmail")
+ @patch.object(smtplib.SMTP, "quit")
def testSendMail(self, _quit, sendmail, login, _):
"""有密码测试"""
- some_sub = 'test_subject'
- some_body = 'mail_body'
- some_to = ['mail_to']
+ some_sub = "test_subject"
+ some_body = "mail_body"
+ some_to = ["mail_to"]
archer_config = SysConfig()
- archer_config.set('mail_ssl', '')
- archer_config.set('mail_smtp_password', self.smtp_password)
+ archer_config.set("mail_ssl", "")
+ archer_config.set("mail_smtp_password", self.smtp_password)
sender = MsgSender()
sender.send_email(some_sub, some_body, some_to)
login.assert_called_once()
sendmail.assert_called_with(self.smtp_user, some_to, ANY)
_quit.assert_called_once()
- @patch.object(smtplib.SMTP, '__init__', return_value=None)
- @patch.object(smtplib.SMTP, 'login')
- @patch.object(smtplib.SMTP, 'sendmail')
- @patch.object(smtplib.SMTP, 'quit')
+ @patch.object(smtplib.SMTP, "__init__", return_value=None)
+ @patch.object(smtplib.SMTP, "login")
+ @patch.object(smtplib.SMTP, "sendmail")
+ @patch.object(smtplib.SMTP, "quit")
def testSSLSendMail(self, _quit, sendmail, login, _):
"""SSL 测试"""
- some_sub = 'test_subject'
- some_body = 'mail_body'
- some_to = ['mail_to']
+ some_sub = "test_subject"
+ some_body = "mail_body"
+ some_to = ["mail_to"]
archer_config = SysConfig()
- archer_config.set('mail_ssl', True)
+ archer_config.set("mail_ssl", True)
sender = MsgSender()
sender.send_email(some_sub, some_body, some_to)
sendmail.assert_called_with(self.smtp_user, some_to, ANY)
@@ -145,36 +154,33 @@ def testSSLSendMail(self, _quit, sendmail, login, _):
def tearDown(self):
archer_config = SysConfig()
- archer_config.set('mail_smtp_server', '')
- archer_config.set('mail_smtp_user', '')
- archer_config.set('mail_smtp_password', '')
- archer_config.set('mail_smtp_port', '')
- archer_config.set('mail_ssl', '')
+ archer_config.set("mail_smtp_server", "")
+ archer_config.set("mail_smtp_user", "")
+ archer_config.set("mail_smtp_password", "")
+ archer_config.set("mail_smtp_port", "")
+ archer_config.set("mail_ssl", "")
class DingTest(TestCase):
-
def setUp(self):
- self.url = 'some_url'
- self.content = 'some_content'
+ self.url = "some_url"
+ self.content = "some_content"
- @patch('requests.post')
+ @patch("requests.post")
def testDing(self, post):
sender = MsgSender()
- post.return_value.json.return_value = {'errcode': 0}
- with self.assertLogs('default', level='DEBUG') as lg:
+ post.return_value.json.return_value = {"errcode": 0}
+ with self.assertLogs("default", level="DEBUG") as lg:
sender.send_ding(self.url, self.content)
- post.assert_called_once_with(url=self.url, json={
- 'msgtype': 'text',
- 'text': {
- 'content': self.content
- }
- })
- self.assertIn('钉钉Webhook推送成功', lg.output[0])
- post.return_value.json.return_value = {'errcode': 1, 'errmsg': 'test_error'}
- with self.assertLogs('default', level='ERROR') as lg:
+ post.assert_called_once_with(
+ url=self.url,
+ json={"msgtype": "text", "text": {"content": self.content}},
+ )
+ self.assertIn("钉钉Webhook推送成功", lg.output[0])
+ post.return_value.json.return_value = {"errcode": 1, "errmsg": "test_error"}
+ with self.assertLogs("default", level="ERROR") as lg:
sender.send_ding(self.url, self.content)
- self.assertIn('test_error', lg.output[0])
+ self.assertIn("test_error", lg.output[0])
def tearDown(self):
pass
@@ -182,26 +188,26 @@ def tearDown(self):
class GlobalInfoTest(TestCase):
def setUp(self):
- self.u1 = User(username='test_user', display='中文显示', is_active=True)
+ self.u1 = User(username="test_user", display="中文显示", is_active=True)
self.u1.save()
- @patch('sql.utils.workflow_audit.Audit.todo')
+ @patch("sql.utils.workflow_audit.Audit.todo")
def testGlobalInfo(self, todo):
"""测试"""
c = Client()
- r = c.get('/', follow=True)
+ r = c.get("/", follow=True)
todo.assert_not_called()
- self.assertEqual(r.context['todo'], 0)
+ self.assertEqual(r.context["todo"], 0)
# 已登录用户
c.force_login(self.u1)
todo.return_value = 3
- r = c.get('/', follow=True)
+ r = c.get("/", follow=True)
todo.assert_called_once_with(self.u1)
- self.assertEqual(r.context['todo'], 3)
+ self.assertEqual(r.context["todo"], 3)
# 报异常
- todo.side_effect = NameError('some exception')
- r = c.get('/', follow=True)
- self.assertEqual(r.context['todo'], 0)
+ todo.side_effect = NameError("some exception")
+ r = c.get("/", follow=True)
+ self.assertEqual(r.context["todo"], 0)
def tearDown(self):
self.u1.delete()
@@ -211,120 +217,153 @@ class CheckTest(TestCase):
"""检查功能测试"""
def setUp(self):
- self.superuser1 = User(username='test_user', display='中文显示', is_active=True, is_superuser=True,
- email='XXX@xxx.com')
+ self.superuser1 = User(
+ username="test_user",
+ display="中文显示",
+ is_active=True,
+ is_superuser=True,
+ email="XXX@xxx.com",
+ )
self.superuser1.save()
- self.slave1 = Instance(instance_name='some_name', host='some_host', type='slave', db_type='mysql',
- user='some_user', port=1234, password='some_str')
+ self.slave1 = Instance(
+ instance_name="some_name",
+ host="some_host",
+ type="slave",
+ db_type="mysql",
+ user="some_user",
+ port=1234,
+ password="some_str",
+ )
self.slave1.save()
def tearDown(self):
self.superuser1.delete()
- @patch.object(MsgSender, '__init__', return_value=None)
- @patch.object(MsgSender, 'send_email')
+ @patch.object(MsgSender, "__init__", return_value=None)
+ @patch.object(MsgSender, "send_email")
def testEmailCheck(self, send_email, mailsender):
"""邮箱配置检查"""
- mail_switch = 'true'
- smtp_ssl = 'false'
- smtp_server = 'some_server'
- smtp_port = '1234'
- smtp_user = 'some_user'
- smtp_pass = 'some_str'
+ mail_switch = "true"
+ smtp_ssl = "false"
+ smtp_server = "some_server"
+ smtp_port = "1234"
+ smtp_user = "some_user"
+ smtp_pass = "some_str"
# 略过superuser校验
# 未开启mail开关
- mail_switch = 'false'
+ mail_switch = "false"
c = Client()
c.force_login(self.superuser1)
- r = c.post('/check/email/', data={
- 'mail': mail_switch,
- 'mail_ssl': smtp_ssl,
- 'mail_smtp_server': smtp_server,
- 'mail_smtp_port': smtp_port,
- 'mail_smtp_user': smtp_user,
- 'mail_smtp_password': smtp_pass
- })
+ r = c.post(
+ "/check/email/",
+ data={
+ "mail": mail_switch,
+ "mail_ssl": smtp_ssl,
+ "mail_smtp_server": smtp_server,
+ "mail_smtp_port": smtp_port,
+ "mail_smtp_user": smtp_user,
+ "mail_smtp_password": smtp_pass,
+ },
+ )
r_json = r.json()
- self.assertEqual(r_json['status'], 1)
- self.assertEqual(r_json['msg'], '请先开启邮件通知!')
- mail_switch = 'true'
+ self.assertEqual(r_json["status"], 1)
+ self.assertEqual(r_json["msg"], "请先开启邮件通知!")
+ mail_switch = "true"
# 填写非正整数端口号
- smtp_port = '-3'
- r = c.post('/check/email/', data={
- 'mail': mail_switch,
- 'mail_ssl': smtp_ssl,
- 'mail_smtp_server': smtp_server,
- 'mail_smtp_port': smtp_port,
- 'mail_smtp_user': smtp_user,
- 'mail_smtp_password': smtp_pass
- })
+ smtp_port = "-3"
+ r = c.post(
+ "/check/email/",
+ data={
+ "mail": mail_switch,
+ "mail_ssl": smtp_ssl,
+ "mail_smtp_server": smtp_server,
+ "mail_smtp_port": smtp_port,
+ "mail_smtp_user": smtp_user,
+ "mail_smtp_password": smtp_pass,
+ },
+ )
r_json = r.json()
- self.assertEqual(r_json['status'], 1)
- self.assertEqual(r_json['msg'], '端口号只能为正整数')
- smtp_port = '1234'
+ self.assertEqual(r_json["status"], 1)
+ self.assertEqual(r_json["msg"], "端口号只能为正整数")
+ smtp_port = "1234"
# 未填写用户邮箱
- self.superuser1.email = ''
+ self.superuser1.email = ""
self.superuser1.save()
- r = c.post('/check/email/', data={
- 'mail': mail_switch,
- 'mail_ssl': smtp_ssl,
- 'mail_smtp_server': smtp_server,
- 'mail_smtp_port': smtp_port,
- 'mail_smtp_user': smtp_user,
- 'mail_smtp_password': smtp_pass
- })
+ r = c.post(
+ "/check/email/",
+ data={
+ "mail": mail_switch,
+ "mail_ssl": smtp_ssl,
+ "mail_smtp_server": smtp_server,
+ "mail_smtp_port": smtp_port,
+ "mail_smtp_user": smtp_user,
+ "mail_smtp_password": smtp_pass,
+ },
+ )
r_json = r.json()
- self.assertEqual(r_json['status'], 1)
- self.assertEqual(r_json['msg'], '请先完善当前用户邮箱信息!')
- self.superuser1.email = 'XXX@xxx.com'
+ self.assertEqual(r_json["status"], 1)
+ self.assertEqual(r_json["msg"], "请先完善当前用户邮箱信息!")
+ self.superuser1.email = "XXX@xxx.com"
self.superuser1.save()
# 发送失败, 显示traceback
- send_email.return_value = 'some traceback'
- r = c.post('/check/email/', data={
- 'mail': mail_switch,
- 'mail_ssl': smtp_ssl,
- 'mail_smtp_server': smtp_server,
- 'mail_smtp_port': smtp_port,
- 'mail_smtp_user': smtp_user,
- 'mail_smtp_password': smtp_pass
- })
+ send_email.return_value = "some traceback"
+ r = c.post(
+ "/check/email/",
+ data={
+ "mail": mail_switch,
+ "mail_ssl": smtp_ssl,
+ "mail_smtp_server": smtp_server,
+ "mail_smtp_port": smtp_port,
+ "mail_smtp_user": smtp_user,
+ "mail_smtp_password": smtp_pass,
+ },
+ )
r_json = r.json()
- self.assertEqual(r_json['status'], 1)
- self.assertIn('some traceback', r_json['msg'])
+ self.assertEqual(r_json["status"], 1)
+ self.assertIn("some traceback", r_json["msg"])
send_email.reset_mock() # 重置``Mock``的调用计数
mailsender.reset_mock()
# 发送成功
- send_email.return_value = 'success'
- r = c.post('/check/email/', data={
- 'mail': mail_switch,
- 'mail_ssl': smtp_ssl,
- 'mail_smtp_server': smtp_server,
- 'mail_smtp_port': smtp_port,
- 'mail_smtp_user': smtp_user,
- 'mail_smtp_password': smtp_pass
- })
+ send_email.return_value = "success"
+ r = c.post(
+ "/check/email/",
+ data={
+ "mail": mail_switch,
+ "mail_ssl": smtp_ssl,
+ "mail_smtp_server": smtp_server,
+ "mail_smtp_port": smtp_port,
+ "mail_smtp_user": smtp_user,
+ "mail_smtp_password": smtp_pass,
+ },
+ )
r_json = r.json()
- mailsender.assert_called_once_with(server=smtp_server, port=int(smtp_port), user=smtp_user,
- password=smtp_pass, ssl=False)
- send_email.called_once_with('Archery 邮件发送测试', 'Archery 邮件发送测试...',
- [self.superuser1.email])
- self.assertEqual(r_json['status'], 0)
- self.assertEqual(r_json['msg'], 'ok')
-
- @patch('MySQLdb.connect')
- @patch('common.check.get_engine')
+ mailsender.assert_called_once_with(
+ server=smtp_server,
+ port=int(smtp_port),
+ user=smtp_user,
+ password=smtp_pass,
+ ssl=False,
+ )
+ send_email.called_once_with(
+ "Archery 邮件发送测试", "Archery 邮件发送测试...", [self.superuser1.email]
+ )
+ self.assertEqual(r_json["status"], 0)
+ self.assertEqual(r_json["msg"], "ok")
+
+ @patch("MySQLdb.connect")
+ @patch("common.check.get_engine")
def testInstanceCheck(self, _get_engine, _conn):
_get_engine.return_value.get_connection = _conn
- _get_engine.return_value.get_all_databases.return_value.rows.return_value = ResultSet(
- rows=((),),
- error='Wrong password')
+ _get_engine.return_value.get_all_databases.return_value.rows.return_value = (
+ ResultSet(rows=((),), error="Wrong password")
+ )
c = Client()
c.force_login(self.superuser1)
- r = c.post('/check/instance/', data={'instance_id': self.slave1.id})
+ r = c.post("/check/instance/", data={"instance_id": self.slave1.id})
r_json = r.json()
- self.assertEqual(r_json['status'], 1)
+ self.assertEqual(r_json["status"], 1)
- @patch('MySQLdb.connect')
+ @patch("MySQLdb.connect")
def test_go_inception_check(self, _conn):
c = Client()
c.force_login(self.superuser1)
@@ -334,11 +373,11 @@ def test_go_inception_check(self, _conn):
"inception_remote_backup_host": "mysql",
"inception_remote_backup_port": 3306,
"inception_remote_backup_user": "mysql",
- "inception_remote_backup_password": "123456"
+ "inception_remote_backup_password": "123456",
}
- r = c.post('/check/go_inception/', data=data)
+ r = c.post("/check/go_inception/", data=data)
r_json = r.json()
- self.assertEqual(r_json['status'], 0)
+ self.assertEqual(r_json["status"], 0)
class ChartTest(TestCase):
@@ -346,57 +385,78 @@ class ChartTest(TestCase):
@classmethod
def setUpClass(cls):
- cls.u1 = User(username='some_user', display='用户1')
+ cls.u1 = User(username="some_user", display="用户1")
cls.u1.save()
- cls.u2 = User(username='some_other_user', display='用户2')
+ cls.u2 = User(username="some_other_user", display="用户2")
cls.u2.save()
- cls.superuser1 = User(username='super1', is_superuser=True)
+ cls.superuser1 = User(username="super1", is_superuser=True)
cls.superuser1.save()
cls.now = datetime.datetime.now()
- cls.slave1 = Instance(instance_name='test_slave_instance', type='slave', db_type='mysql',
- host='testhost', port=3306, user='mysql_user', password='mysql_password')
+ cls.slave1 = Instance(
+ instance_name="test_slave_instance",
+ type="slave",
+ db_type="mysql",
+ host="testhost",
+ port=3306,
+ user="mysql_user",
+ password="mysql_password",
+ )
cls.slave1.save()
# 批量创建数据 ddl ,u1 ,g1, yesterday 组, 2 个数据
- ddl_workflow = [SqlWorkflow(
- workflow_name='ddl %s' % i,
- group_id=1,
- group_name='g1',
- engineer=cls.u1.username,
- engineer_display=cls.u1.display,
- audit_auth_groups='some_group',
- create_time=cls.now - datetime.timedelta(days=1),
- status='workflow_finish',
- is_backup=True,
- instance=cls.slave1,
- db_name='some_db',
- syntax_type=1
- ) for i in range(2)]
+ ddl_workflow = [
+ SqlWorkflow(
+ workflow_name="ddl %s" % i,
+ group_id=1,
+ group_name="g1",
+ engineer=cls.u1.username,
+ engineer_display=cls.u1.display,
+ audit_auth_groups="some_group",
+ create_time=cls.now - datetime.timedelta(days=1),
+ status="workflow_finish",
+ is_backup=True,
+ instance=cls.slave1,
+ db_name="some_db",
+ syntax_type=1,
+ )
+ for i in range(2)
+ ]
# 批量创建数据 dml ,u1 ,g2, the day before yesterday 组, 3 个数据
- dml_workflow = [SqlWorkflow(
- workflow_name='Test %s' % i,
- group_id=2,
- group_name='g2',
- engineer=cls.u2.username,
- engineer_display=cls.u2.display,
- audit_auth_groups='some_group',
- create_time=cls.now - datetime.timedelta(days=2),
- status='workflow_finish',
- is_backup=True,
- instance=cls.slave1,
- db_name='some_db',
- syntax_type=2
- ) for i in range(3)]
+ dml_workflow = [
+ SqlWorkflow(
+ workflow_name="Test %s" % i,
+ group_id=2,
+ group_name="g2",
+ engineer=cls.u2.username,
+ engineer_display=cls.u2.display,
+ audit_auth_groups="some_group",
+ create_time=cls.now - datetime.timedelta(days=2),
+ status="workflow_finish",
+ is_backup=True,
+ instance=cls.slave1,
+ db_name="some_db",
+ syntax_type=2,
+ )
+ for i in range(3)
+ ]
SqlWorkflow.objects.bulk_create(ddl_workflow + dml_workflow)
# 保存内容数据
- ddl_workflow_content = [SqlWorkflowContent(
- workflow=SqlWorkflow.objects.get(workflow_name='ddl %s' % i),
- sql_content='some_sql',
- ) for i in range(2)]
- dml_workflow_content = [SqlWorkflowContent(
- workflow=SqlWorkflow.objects.get(workflow_name='Test %s' % i),
- sql_content='some_sql',
- ) for i in range(3)]
- SqlWorkflowContent.objects.bulk_create(ddl_workflow_content + dml_workflow_content)
+ ddl_workflow_content = [
+ SqlWorkflowContent(
+ workflow=SqlWorkflow.objects.get(workflow_name="ddl %s" % i),
+ sql_content="some_sql",
+ )
+ for i in range(2)
+ ]
+ dml_workflow_content = [
+ SqlWorkflowContent(
+ workflow=SqlWorkflow.objects.get(workflow_name="Test %s" % i),
+ sql_content="some_sql",
+ )
+ for i in range(3)
+ ]
+ SqlWorkflowContent.objects.bulk_create(
+ ddl_workflow_content + dml_workflow_content
+ )
# query_logs = [QueryLog(
# instance_name = 'some_instance',
@@ -419,35 +479,35 @@ def testGetDateList(self):
begin = end - datetime.timedelta(days=3)
result = dao.get_date_list(begin, end)
self.assertEqual(len(result), 4)
- self.assertEqual(result[0], begin.strftime('%Y-%m-%d'))
- self.assertEqual(result[-1], end.strftime('%Y-%m-%d'))
+ self.assertEqual(result[0], begin.strftime("%Y-%m-%d"))
+ self.assertEqual(result[-1], end.strftime("%Y-%m-%d"))
def testSyntaxList(self):
"""工单以语法类型分组"""
dao = ChartDao()
- expected_rows = (('DDL', 2), ('DML', 3))
+ expected_rows = (("DDL", 2), ("DML", 3))
result = dao.syntax_type()
- self.assertEqual(result['rows'], expected_rows)
+ self.assertEqual(result["rows"], expected_rows)
def testWorkflowByDate(self):
"""TODO 按日分组工单数量统计测试"""
dao = ChartDao()
result = dao.workflow_by_date(30)
- self.assertEqual(len(result['rows'][0]), 2)
+ self.assertEqual(len(result["rows"][0]), 2)
def testWorkflowByGroup(self):
"""按组统计测试"""
dao = ChartDao()
result = dao.workflow_by_group(30)
- expected_rows = (('g2', 3), ('g1', 2))
- self.assertEqual(result['rows'], expected_rows)
+ expected_rows = (("g2", 3), ("g1", 2))
+ self.assertEqual(result["rows"], expected_rows)
def testWorkflowByUser(self):
"""按用户统计测试"""
dao = ChartDao()
result = dao.workflow_by_user(30)
expected_rows = ((self.u2.display, 3), (self.u1.display, 2))
- self.assertEqual(result['rows'], expected_rows)
+ self.assertEqual(result["rows"], expected_rows)
def testDashboard(self):
"""Dashboard测试"""
@@ -455,20 +515,19 @@ def testDashboard(self):
# TODO 需要具体查看pyecharst有没有被调用, 以及调用的参数
c = Client()
c.force_login(self.superuser1)
- r = c.get('/dashboard/')
+ r = c.get("/dashboard/")
self.assertEqual(r.status_code, 200)
class AuthTest(TestCase):
-
def setUp(self):
- self.username = 'some_user'
- self.password = 'some_str'
- self.u1 = User(username=self.username, password=self.password, display='用户1')
+ self.username = "some_user"
+ self.password = "some_str"
+ self.u1 = User(username=self.username, password=self.password, display="用户1")
self.u1.save()
- self.resource_group1 = ResourceGroup.objects.create(group_name='some_group')
+ self.resource_group1 = ResourceGroup.objects.create(group_name="some_group")
sys_config = SysConfig()
- sys_config.set('default_resource_group', self.resource_group1.group_name)
+ sys_config.set("default_resource_group", self.resource_group1.group_name)
def tearDown(self):
self.u1.delete()
diff --git a/common/twofa/__init__.py b/common/twofa/__init__.py
index 58b8e75af0..4dfdcdd6f8 100644
--- a/common/twofa/__init__.py
+++ b/common/twofa/__init__.py
@@ -2,7 +2,6 @@
class TwoFactorAuthBase:
-
def __init__(self, user=None):
self.user = user
@@ -14,16 +13,18 @@ def verify(self, otp):
def enable(self):
"""启用"""
- result = {'status': 1, 'msg': 'failed'}
+ result = {"status": 1, "msg": "failed"}
return result
def disable(self, auth_type):
"""禁用"""
- result = {'status': 0, 'msg': 'ok'}
+ result = {"status": 0, "msg": "ok"}
try:
- TwoFactorAuthConfig.objects.get(user=self.user, auth_type=auth_type).delete()
+ TwoFactorAuthConfig.objects.get(
+ user=self.user, auth_type=auth_type
+ ).delete()
except TwoFactorAuthConfig.DoesNotExist as e:
- result = {'status': 0, 'msg': str(e)}
+ result = {"status": 0, "msg": str(e)}
return result
@property
diff --git a/common/twofa/sms.py b/common/twofa/sms.py
index 63e01cdeac..7e69f05415 100644
--- a/common/twofa/sms.py
+++ b/common/twofa/sms.py
@@ -9,7 +9,7 @@
import json
import time
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
class SMS(TwoFactorAuthBase):
@@ -19,61 +19,65 @@ def __init__(self, user=None):
super(SMS, self).__init__(user=user)
self.user = user
- sms_provider = SysConfig().get('sms_provider', 'disabled')
- if sms_provider == 'aliyun':
+ sms_provider = SysConfig().get("sms_provider", "disabled")
+ if sms_provider == "aliyun":
from common.utils.aliyun_sms import AliyunSMS
+
self.client = AliyunSMS()
- elif sms_provider == 'tencent':
+ elif sms_provider == "tencent":
from common.utils.tencent_sms import TencentSMS
+
self.client = TencentSMS()
else:
self.client = None
def get_captcha(self, **kwargs):
"""获取验证码"""
- result = {'status': 0, 'msg': 'ok'}
- r = get_redis_connection('default')
+ result = {"status": 0, "msg": "ok"}
+ r = get_redis_connection("default")
data = r.get(f"captcha-{kwargs['phone']}")
if data:
- captcha = json.loads(data.decode('utf8'))
- if int(time.time()) - captcha['update_time'] > 60:
+ captcha = json.loads(data.decode("utf8"))
+ if int(time.time()) - captcha["update_time"] > 60:
if self.client:
result = self.client.send_code(**kwargs)
else:
- result = {'status': 1, 'msg': '系统未配置短信服务商!'}
+ result = {"status": 1, "msg": "系统未配置短信服务商!"}
else:
- result['status'] = 1
- result['msg'] = f"获取验证码太频繁,请于{captcha['update_time'] - int(time.time()) + 60}秒后再试"
+ result["status"] = 1
+ result[
+ "msg"
+ ] = f"获取验证码太频繁,请于{captcha['update_time'] - int(time.time()) + 60}秒后再试"
else:
if self.client:
result = self.client.send_code(**kwargs)
else:
- result = {'status': 1, 'msg': '系统未配置短信服务商!'}
+ result = {"status": 1, "msg": "系统未配置短信服务商!"}
return result
def verify(self, otp, phone=None):
"""校验验证码"""
- result = {'status': 0, 'msg': 'ok'}
+ result = {"status": 0, "msg": "ok"}
if phone:
phone = phone
else:
phone = TwoFactorAuthConfig.objects.get(username=self.user.username).phone
- r = get_redis_connection('default')
- data = r.get(f'captcha-{phone}')
+ r = get_redis_connection("default")
+ data = r.get(f"captcha-{phone}")
if not data:
- result['status'] = 1
- result['msg'] = '未获取验证码或验证码已过期!'
+ result["status"] = 1
+ result["msg"] = "未获取验证码或验证码已过期!"
else:
- captcha = json.loads(data.decode('utf8'))
- if otp != captcha['otp']:
- result['status'] = 1
- result['msg'] = '验证码不正确!'
+ captcha = json.loads(data.decode("utf8"))
+ if otp != captcha["otp"]:
+ result["status"] = 1
+ result["msg"] = "验证码不正确!"
return result
def save(self, phone):
"""保存2fa配置"""
- result = {'status': 0, 'msg': 'ok'}
+ result = {"status": 0, "msg": "ok"}
try:
with transaction.atomic():
@@ -84,11 +88,11 @@ def save(self, phone):
username=self.user.username,
auth_type=self.auth_type,
phone=phone,
- user=self.user
+ user=self.user,
)
except Exception as msg:
- result['status'] = 1
- result['msg'] = str(msg)
+ result["status"] = 1
+ result["msg"] = str(msg)
logger.error(traceback.format_exc())
return result
diff --git a/common/twofa/totp.py b/common/twofa/totp.py
index 25a4aad2fb..c3a4f88bb1 100644
--- a/common/twofa/totp.py
+++ b/common/twofa/totp.py
@@ -8,7 +8,7 @@
import logging
import pyotp
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
class TOTP(TwoFactorAuthBase):
@@ -20,34 +20,32 @@ def __init__(self, user=None):
def verify(self, otp, key=None):
"""校验一次性验证码"""
- result = {'status': 0, 'msg': 'ok'}
+ result = {"status": 0, "msg": "ok"}
if key:
secret_key = key
else:
- secret_key = TwoFactorAuthConfig.objects.get(username=self.user.username,
- auth_type=self.auth_type).secret_key
+ secret_key = TwoFactorAuthConfig.objects.get(
+ username=self.user.username, auth_type=self.auth_type
+ ).secret_key
t = pyotp.TOTP(secret_key)
status = t.verify(otp)
- result['status'] = 0 if status else 1
- result['msg'] = 'ok' if status else '验证码不正确!'
+ result["status"] = 0 if status else 1
+ result["msg"] = "ok" if status else "验证码不正确!"
return result
def generate_key(self):
"""生成secret key"""
- result = {'status': 0, 'msg': 'ok', 'data': {}}
+ result = {"status": 0, "msg": "ok", "data": {}}
# 生成用户secret_key
secret_key = pyotp.random_base32(32)
- result['data'] = {
- 'auth_type': self.auth_type,
- 'key': secret_key
- }
+ result["data"] = {"auth_type": self.auth_type, "key": secret_key}
return result
def save(self, secret_key):
"""保存2fa配置"""
- result = {'status': 0, 'msg': 'ok'}
+ result = {"status": 0, "msg": "ok"}
try:
with transaction.atomic():
@@ -58,11 +56,11 @@ def save(self, secret_key):
username=self.user.username,
auth_type=self.auth_type,
secret_key=secret_key,
- user=self.user
+ user=self.user,
)
except Exception as msg:
- result['status'] = 1
- result['msg'] = str(msg)
+ result["status"] = 1
+ result["msg"] = str(msg)
logger.error(traceback.format_exc())
return result
@@ -77,13 +75,16 @@ def generate_qrcode(request, data):
"""生成并返回二维码图片流"""
user = request.user
- username = user.username if user.is_authenticated else request.session.get('user')
+ username = user.username if user.is_authenticated else request.session.get("user")
secret_key = data
# 生成二维码
- qr_data = pyotp.totp.TOTP(secret_key).provisioning_uri(username, issuer_name="Archery")
- qrcode = QRCode(version=1, error_correction=constants.ERROR_CORRECT_L,
- box_size=6, border=4)
+ qr_data = pyotp.totp.TOTP(secret_key).provisioning_uri(
+ username, issuer_name="Archery"
+ )
+ qrcode = QRCode(
+ version=1, error_correction=constants.ERROR_CORRECT_L, box_size=6, border=4
+ )
try:
qrcode.add_data(qr_data)
qrcode.make(fit=True)
diff --git a/common/utils/aes_decryptor.py b/common/utils/aes_decryptor.py
index d51aff3579..10f9f6c040 100644
--- a/common/utils/aes_decryptor.py
+++ b/common/utils/aes_decryptor.py
@@ -2,42 +2,42 @@
from binascii import b2a_hex, a2b_hex
-class Prpcrypt():
+class Prpcrypt:
def __init__(self):
- self.key = 'eCcGFZQj6PNoSSma31LR39rTzTbLkU8E'.encode('utf-8')
+ self.key = "eCcGFZQj6PNoSSma31LR39rTzTbLkU8E".encode("utf-8")
self.mode = AES.MODE_CBC
# 加密函数,如果text不足16位就用空格补足为16位,
# 如果大于16当时不是16的倍数,那就补足为16的倍数。
def encrypt(self, text):
- cryptor = AES.new(self.key, self.mode, b'0000000000000000')
+ cryptor = AES.new(self.key, self.mode, b"0000000000000000")
# 这里密钥key 长度必须为16(AES-128),
# 24(AES-192),或者32 (AES-256)Bytes 长度
# 目前AES-128 足够目前使用
length = 16
count = len(text)
if count < length:
- add = (length - count)
+ add = length - count
# \0 backspace
- text = text + ('\0' * add)
+ text = text + ("\0" * add)
elif count > length:
- add = (length - (count % length))
- text = text + ('\0' * add)
- self.ciphertext = cryptor.encrypt(text.encode('utf-8'))
+ add = length - (count % length)
+ text = text + ("\0" * add)
+ self.ciphertext = cryptor.encrypt(text.encode("utf-8"))
# 因为AES加密时候得到的字符串不一定是ascii字符集的,输出到终端或者保存时候可能存在问题
# 所以这里统一把加密后的字符串转化为16进制字符串
- return b2a_hex(self.ciphertext).decode(encoding='utf-8')
+ return b2a_hex(self.ciphertext).decode(encoding="utf-8")
# 解密后,去掉补足的空格用strip() 去掉
def decrypt(self, text):
- cryptor = AES.new(self.key, self.mode, b'0000000000000000')
+ cryptor = AES.new(self.key, self.mode, b"0000000000000000")
plain_text = cryptor.decrypt(a2b_hex(text))
- return plain_text.decode().rstrip('\0')
+ return plain_text.decode().rstrip("\0")
-if __name__ == '__main__':
+if __name__ == "__main__":
pc = Prpcrypt() # 初始化密钥
- e = pc.encrypt('123456') # 加密
+ e = pc.encrypt("123456") # 加密
d = pc.decrypt(e) # 解密
print("加密:", str(e))
print("解密:", str(d))
diff --git a/common/utils/aliyun_sdk.py b/common/utils/aliyun_sdk.py
index 60d61fc299..1d1edb1051 100644
--- a/common/utils/aliyun_sdk.py
+++ b/common/utils/aliyun_sdk.py
@@ -3,12 +3,15 @@
import traceback
from aliyunsdkcore.client import AcsClient
-from aliyunsdkrds.request.v20140815 import DescribeSlowLogsRequest, DescribeSlowLogRecordsRequest, \
- RequestServiceOfCloudDBARequest
+from aliyunsdkrds.request.v20140815 import (
+ DescribeSlowLogsRequest,
+ DescribeSlowLogRecordsRequest,
+ RequestServiceOfCloudDBARequest,
+)
import simplejson as json
import logging
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
class Aliyun(object):
@@ -19,16 +22,21 @@ def __init__(self, rds):
secret = rds.ak.raw_key_secret
self.clt = AcsClient(ak=ak, secret=secret)
except Exception as m:
- raise Exception(f'阿里云认证失败:{m}{traceback.format_exc()}')
+ raise Exception(f"阿里云认证失败:{m}{traceback.format_exc()}")
def request_api(self, request, *values):
if values:
for value in values:
for k, v in value.items():
request.add_query_param(k, v)
- request.set_accept_format('json')
+ request.set_accept_format("json")
result = self.clt.do_action_with_exception(request)
- return json.dumps(json.loads(result.decode('utf-8')), indent=4, sort_keys=False, ensure_ascii=False)
+ return json.dumps(
+ json.loads(result.decode("utf-8")),
+ indent=4,
+ sort_keys=False,
+ ensure_ascii=False,
+ )
# 阿里云UTC时间转换为本地时区时间
@staticmethod
@@ -42,8 +50,13 @@ def utc2local(utc, utc_format):
def DescribeSlowLogs(self, StartTime, EndTime, **kwargs):
"""获取实例慢日志列表DBName,SortKey、PageSize、PageNumber"""
request = DescribeSlowLogsRequest.DescribeSlowLogsRequest()
- values = {"action_name": "DescribeSlowLogs", "DBInstanceId": self.DBInstanceId,
- "StartTime": StartTime, "EndTime": EndTime, "SortKey": "TotalExecutionCounts"}
+ values = {
+ "action_name": "DescribeSlowLogs",
+ "DBInstanceId": self.DBInstanceId,
+ "StartTime": StartTime,
+ "EndTime": EndTime,
+ "SortKey": "TotalExecutionCounts",
+ }
values = dict(values, **kwargs)
result = self.request_api(request, values)
return result
@@ -51,14 +64,19 @@ def DescribeSlowLogs(self, StartTime, EndTime, **kwargs):
def DescribeSlowLogRecords(self, StartTime, EndTime, **kwargs):
"""查看慢日志明细SQLId,DBName、PageSize、PageNumber"""
request = DescribeSlowLogRecordsRequest.DescribeSlowLogRecordsRequest()
- values = {"action_name": "DescribeSlowLogRecords", "DBInstanceId": self.DBInstanceId,
- "StartTime": StartTime, "EndTime": EndTime}
+ values = {
+ "action_name": "DescribeSlowLogRecords",
+ "DBInstanceId": self.DBInstanceId,
+ "StartTime": StartTime,
+ "EndTime": EndTime,
+ }
values = dict(values, **kwargs)
result = self.request_api(request, values)
return result
- def RequestServiceOfCloudDBA(self, ServiceRequestType, ServiceRequestParam,
- **kwargs):
+ def RequestServiceOfCloudDBA(
+ self, ServiceRequestType, ServiceRequestParam, **kwargs
+ ):
"""
获取统计信息:'GetTimedMonData',{"Language":"zh","KeyGroup":"mem_cpu_usage","KeyName":"","StartTime":"2018-01-15T04:03:26Z","EndTime":"2018-01-15T05:03:26Z"}
mem_cpu_usage、iops_usage、detailed_disk_space
@@ -68,8 +86,12 @@ def RequestServiceOfCloudDBA(self, ServiceRequestType, ServiceRequestParam,
获取资源利用信息:'GetResourceUsage',{"Language":"zh"}
"""
request = RequestServiceOfCloudDBARequest.RequestServiceOfCloudDBARequest()
- values = {"action_name": "RequestServiceOfCloudDBA", "DBInstanceId": self.DBInstanceId,
- "ServiceRequestType": ServiceRequestType, "ServiceRequestParam": ServiceRequestParam}
+ values = {
+ "action_name": "RequestServiceOfCloudDBA",
+ "DBInstanceId": self.DBInstanceId,
+ "ServiceRequestType": ServiceRequestType,
+ "ServiceRequestParam": ServiceRequestParam,
+ }
values = dict(values, **kwargs)
result = self.request_api(request, values)
return result
diff --git a/common/utils/aliyun_sms.py b/common/utils/aliyun_sms.py
index 162b039fcf..92458ed8cf 100644
--- a/common/utils/aliyun_sms.py
+++ b/common/utils/aliyun_sms.py
@@ -7,45 +7,44 @@
import traceback
import logging
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
class AliyunSMS:
def __init__(self):
all_config = SysConfig()
- self.access_key_id = all_config.get('aliyun_access_key_id', '')
- self.access_key_secret = all_config.get('aliyun_access_key_secret', '')
- self.sign_name = all_config.get('aliyun_sign_name', '')
- self.template_code = all_config.get('aliyun_template_code', '')
- self.variable_name = all_config.get('aliyun_variable_name', 'code')
+ self.access_key_id = all_config.get("aliyun_access_key_id", "")
+ self.access_key_secret = all_config.get("aliyun_access_key_secret", "")
+ self.sign_name = all_config.get("aliyun_sign_name", "")
+ self.template_code = all_config.get("aliyun_template_code", "")
+ self.variable_name = all_config.get("aliyun_variable_name", "code")
def create_client(self):
config = open_api_models.Config(
- access_key_id=self.access_key_id,
- access_key_secret=self.access_key_secret
+ access_key_id=self.access_key_id, access_key_secret=self.access_key_secret
)
- config.endpoint = f'dysmsapi.aliyuncs.com'
+    config.endpoint = "dysmsapi.aliyuncs.com"
return Dysmsapi20170525Client(config)
def send_code(self, **kwargs):
- result = {'status': 0, 'msg': 'ok'}
+ result = {"status": 0, "msg": "ok"}
client = AliyunSMS.create_client(self)
send_sms_request = dysmsapi_20170525_models.SendSmsRequest(
- phone_numbers=kwargs['phone'],
+ phone_numbers=kwargs["phone"],
sign_name=self.sign_name,
template_code=self.template_code,
- template_param=f"{{{self.variable_name}: \'{kwargs['otp']}\'}}"
+ template_param=f"{{{self.variable_name}: '{kwargs['otp']}'}}",
)
runtime = util_models.RuntimeOptions()
try:
response = client.send_sms_with_options(send_sms_request, runtime)
- if response.body.code != 'OK':
- result['status'] = 1
- result['msg'] = response.body.message
+ if response.body.code != "OK":
+ result["status"] = 1
+ result["msg"] = response.body.message
except Exception as e:
- result['status'] = 1
- result['msg'] = str(e)
+ result["status"] = 1
+ result["msg"] = str(e)
logger.error(str(e))
logger.error(traceback.format_exc())
diff --git a/common/utils/chart_dao.py b/common/utils/chart_dao.py
index 2e4b250ee0..c4858326e0 100644
--- a/common/utils/chart_dao.py
+++ b/common/utils/chart_dao.py
@@ -16,10 +16,7 @@ def __query(sql):
if fields:
for i in fields:
column_list.append(i[0])
- return {
- 'column_list': column_list,
- 'rows': rows
- }
+ return {"column_list": column_list, "rows": rows}
# 获取连续时间
@staticmethod
@@ -33,7 +30,7 @@ def get_date_list(begin_date, end_date):
# 语法类型
def syntax_type(self):
- sql = '''
+ sql = """
select
case when syntax_type = 1
then 'DDL'
@@ -43,73 +40,83 @@ def syntax_type(self):
end as syntax_type,
count(*)
from sql_workflow
- group by syntax_type;'''
+ group by syntax_type;"""
return self.__query(sql)
# 工单数量统计
def workflow_by_date(self, cycle):
- sql = '''
+ sql = """
select
date_format(create_time, '%Y-%m-%d'),
count(*)
from sql_workflow
where create_time >= date_add(now(), interval -{} day)
group by date_format(create_time, '%Y-%m-%d')
- order by 1 asc;'''.format(cycle)
+ order by 1 asc;""".format(
+ cycle
+ )
return self.__query(sql)
# 工单按组统计
def workflow_by_group(self, cycle):
- sql = '''
+ sql = """
select
group_name,
count(*)
from sql_workflow
where create_time >= date_add(now(), interval -{} day )
group by group_id
- order by count(*) desc;'''.format(cycle)
+ order by count(*) desc;""".format(
+ cycle
+ )
return self.__query(sql)
def workflow_by_user(self, cycle):
"""工单按人统计"""
# TODO select 的对象应该为engineer ID, 查询时应作联合查询查出用户中文名
- sql = '''
+ sql = """
select
engineer_display,
count(*)
from sql_workflow
where create_time >= date_add(now(), interval -{} day)
group by engineer_display
- order by count(*) desc;'''.format(cycle)
+ order by count(*) desc;""".format(
+ cycle
+ )
return self.__query(sql)
# SQL查询统计(每日检索行数)
def querylog_effect_row_by_date(self, cycle):
- sql = '''
+ sql = """
select
date_format(create_time, '%Y-%m-%d'),
sum(effect_row)
from query_log
where create_time >= date_add(now(), interval -{} day )
group by date_format(create_time, '%Y-%m-%d')
- order by sum(effect_row) desc;'''.format(cycle)
+ order by sum(effect_row) desc;""".format(
+ cycle
+ )
return self.__query(sql)
# SQL查询统计(每日检索次数)
def querylog_count_by_date(self, cycle):
- sql = '''
+ sql = """
select
date_format(create_time, '%Y-%m-%d'),
count(*)
from query_log
where create_time >= date_add(now(), interval -{} day )
group by date_format(create_time, '%Y-%m-%d')
- order by count(*) desc;'''.format(cycle)
+ order by count(*) desc;""".format(
+ cycle
+ )
return self.__query(sql)
# SQL查询统计(用户检索行数)
def querylog_effect_row_by_user(self, cycle):
- sql = '''
+ sql = """
select
user_display,
sum(effect_row)
@@ -117,12 +124,14 @@ def querylog_effect_row_by_user(self, cycle):
where create_time >= date_add(now(), interval -{} day)
group by user_display
order by sum(effect_row) desc
- limit 10;'''.format(cycle)
+ limit 10;""".format(
+ cycle
+ )
return self.__query(sql)
# SQL查询统计(DB检索行数)
def querylog_effect_row_by_db(self, cycle):
- sql = '''
+ sql = """
select
db_name,
sum(effect_row)
@@ -130,7 +139,9 @@ def querylog_effect_row_by_db(self, cycle):
where create_time >= date_add(now(), interval -{} day)
group by db_name
order by sum(effect_row) desc
- limit 10;'''.format(cycle)
+ limit 10;""".format(
+ cycle
+ )
return self.__query(sql)
# 慢日志历史趋势图(按次数)
@@ -151,24 +162,28 @@ def slow_query_review_history_by_pct_95_time(self, checksum):
# 慢日志db/user维度统计
def slow_query_count_by_db_by_user(self, cycle):
- sql = '''
+ sql = """
select
concat(db_max,' user: ' ,user_max),
sum(ts_cnt)
from mysql_slow_query_review_history
where ts_min >= date_sub(now(), interval {} day)
group by db_max,user_max order by sum(ts_cnt) desc limit 50;
- '''.format(cycle)
+ """.format(
+ cycle
+ )
return self.__query(sql)
# 慢日志db维度统计
def slow_query_count_by_db(self, cycle):
- sql = '''
+ sql = """
select
db_max,
sum(ts_cnt)
from mysql_slow_query_review_history
where ts_min >= date_sub(now(), interval {} day)
group by db_max order by sum(ts_cnt) desc limit 50;
- '''.format(cycle)
- return self.__query(sql)
\ No newline at end of file
+ """.format(
+ cycle
+ )
+ return self.__query(sql)
diff --git a/common/utils/const.py b/common/utils/const.py
index dff832d080..96d4c363a5 100644
--- a/common/utils/const.py
+++ b/common/utils/const.py
@@ -1,76 +1,77 @@
# -*- coding: UTF-8 -*-
+
class Const(object):
# 定时任务id的前缀
workflowJobprefix = {
- 'query': 'query',
- 'sqlreview': 'sqlreview',
- 'archive': 'archive'
+ "query": "query",
+ "sqlreview": "sqlreview",
+ "archive": "archive",
}
class WorkflowDict:
# 工作流申请类型,1.query,2.SQL上线申请
workflow_type = {
- 'query': 1,
- 'query_display': '查询权限申请',
- 'sqlreview': 2,
- 'sqlreview_display': 'SQL上线申请',
- 'archive': 3,
- 'archive_display': '数据归档申请',
+ "query": 1,
+ "query_display": "查询权限申请",
+ "sqlreview": 2,
+ "sqlreview_display": "SQL上线申请",
+ "archive": 3,
+ "archive_display": "数据归档申请",
}
# 工作流状态,0.待审核 1.审核通过 2.审核不通过 3.审核取消
workflow_status = {
- 'audit_wait': 0,
- 'audit_wait_display': '待审核',
- 'audit_success': 1,
- 'audit_success_display': '审核通过',
- 'audit_reject': 2,
- 'audit_reject_display': '审核不通过',
- 'audit_abort': 3,
- 'audit_abort_display': '审核取消',
+ "audit_wait": 0,
+ "audit_wait_display": "待审核",
+ "audit_success": 1,
+ "audit_success_display": "审核通过",
+ "audit_reject": 2,
+ "audit_reject_display": "审核不通过",
+ "audit_abort": 3,
+ "audit_abort_display": "审核取消",
}
class SQLTuning:
SYS_PARM_FILTER = [
- 'BINLOG_CACHE_SIZE',
- 'BULK_INSERT_BUFFER_SIZE',
- 'HAVE_PARTITION_ENGINE',
- 'HAVE_QUERY_CACHE',
- 'INTERACTIVE_TIMEOUT',
- 'JOIN_BUFFER_SIZE',
- 'KEY_BUFFER_SIZE',
- 'KEY_CACHE_AGE_THRESHOLD',
- 'KEY_CACHE_BLOCK_SIZE',
- 'KEY_CACHE_DIVISION_LIMIT',
- 'LARGE_PAGES',
- 'LOCKED_IN_MEMORY',
- 'LONG_QUERY_TIME',
- 'MAX_ALLOWED_PACKET',
- 'MAX_BINLOG_CACHE_SIZE',
- 'MAX_BINLOG_SIZE',
- 'MAX_CONNECT_ERRORS',
- 'MAX_CONNECTIONS',
- 'MAX_JOIN_SIZE',
- 'MAX_LENGTH_FOR_SORT_DATA',
- 'MAX_SEEKS_FOR_KEY',
- 'MAX_SORT_LENGTH',
- 'MAX_TMP_TABLES',
- 'MAX_USER_CONNECTIONS',
- 'OPTIMIZER_PRUNE_LEVEL',
- 'OPTIMIZER_SEARCH_DEPTH',
- 'QUERY_CACHE_SIZE',
- 'QUERY_CACHE_TYPE',
- 'QUERY_PREALLOC_SIZE',
- 'RANGE_ALLOC_BLOCK_SIZE',
- 'READ_BUFFER_SIZE',
- 'READ_RND_BUFFER_SIZE',
- 'SORT_BUFFER_SIZE',
- 'SQL_MODE',
- 'TABLE_CACHE',
- 'THREAD_CACHE_SIZE',
- 'TMP_TABLE_SIZE',
- 'WAIT_TIMEOUT'
+ "BINLOG_CACHE_SIZE",
+ "BULK_INSERT_BUFFER_SIZE",
+ "HAVE_PARTITION_ENGINE",
+ "HAVE_QUERY_CACHE",
+ "INTERACTIVE_TIMEOUT",
+ "JOIN_BUFFER_SIZE",
+ "KEY_BUFFER_SIZE",
+ "KEY_CACHE_AGE_THRESHOLD",
+ "KEY_CACHE_BLOCK_SIZE",
+ "KEY_CACHE_DIVISION_LIMIT",
+ "LARGE_PAGES",
+ "LOCKED_IN_MEMORY",
+ "LONG_QUERY_TIME",
+ "MAX_ALLOWED_PACKET",
+ "MAX_BINLOG_CACHE_SIZE",
+ "MAX_BINLOG_SIZE",
+ "MAX_CONNECT_ERRORS",
+ "MAX_CONNECTIONS",
+ "MAX_JOIN_SIZE",
+ "MAX_LENGTH_FOR_SORT_DATA",
+ "MAX_SEEKS_FOR_KEY",
+ "MAX_SORT_LENGTH",
+ "MAX_TMP_TABLES",
+ "MAX_USER_CONNECTIONS",
+ "OPTIMIZER_PRUNE_LEVEL",
+ "OPTIMIZER_SEARCH_DEPTH",
+ "QUERY_CACHE_SIZE",
+ "QUERY_CACHE_TYPE",
+ "QUERY_PREALLOC_SIZE",
+ "RANGE_ALLOC_BLOCK_SIZE",
+ "READ_BUFFER_SIZE",
+ "READ_RND_BUFFER_SIZE",
+ "SORT_BUFFER_SIZE",
+ "SQL_MODE",
+ "TABLE_CACHE",
+ "THREAD_CACHE_SIZE",
+ "TMP_TABLE_SIZE",
+ "WAIT_TIMEOUT",
]
diff --git a/common/utils/convert.py b/common/utils/convert.py
index d464e78c82..9fb4e0eace 100644
--- a/common/utils/convert.py
+++ b/common/utils/convert.py
@@ -9,9 +9,11 @@ class Convert(Func):
"""
def __init__(self, expression, transcoding_name, **extra):
- super(Convert, self).__init__(expression=expression, transcoding_name=transcoding_name, **extra)
+ super(Convert, self).__init__(
+ expression=expression, transcoding_name=transcoding_name, **extra
+ )
def as_mysql(self, compiler, connection):
- self.function = 'CONVERT'
- self.template = '%(function)s(%(expression)s USING %(transcoding_name)s)'
+ self.function = "CONVERT"
+ self.template = "%(function)s(%(expression)s USING %(transcoding_name)s)"
return super(Convert, self).as_sql(compiler, connection)
diff --git a/common/utils/ding_api.py b/common/utils/ding_api.py
index 58bdef4195..6160204dad 100644
--- a/common/utils/ding_api.py
+++ b/common/utils/ding_api.py
@@ -10,8 +10,8 @@
from sql.models import Users
from sql.utils.tasks import add_sync_ding_user_schedule
-logger = logging.getLogger('default')
-rs = get_redis_connection('default')
+logger = logging.getLogger("default")
+rs = get_redis_connection("default")
def get_access_token():
@@ -26,13 +26,13 @@ def get_access_token():
return access_token.decode()
# 请求钉钉接口获取
sys_config = SysConfig()
- app_key = sys_config.get('ding_app_key')
- app_secret = sys_config.get('ding_app_secret')
+ app_key = sys_config.get("ding_app_key")
+ app_secret = sys_config.get("ding_app_secret")
url = f"https://oapi.dingtalk.com/gettoken?appkey={app_key}&appsecret={app_secret}"
resp = requests.get(url, timeout=3).json()
- if resp.get('errcode') == 0:
- access_token = resp.get('access_token')
- expires_in = resp.get('expires_in')
+ if resp.get("errcode") == 0:
+ access_token = resp.get("access_token")
+ expires_in = resp.get("expires_in")
rs.execute_command(f"SETEX ding_access_token {expires_in-60} {access_token}")
return access_token
else:
@@ -43,12 +43,12 @@ def get_access_token():
def get_ding_user_id(username):
"""更新用户ding_user_id"""
try:
- ding_user_id = rs.execute_command('GET {}'.format(username.lower()))
+ ding_user_id = rs.execute_command("GET {}".format(username.lower()))
if ding_user_id:
user = Users.objects.get(username=username)
if user.ding_user_id != str(ding_user_id, encoding="utf8"):
user.ding_user_id = str(ding_user_id, encoding="utf8")
- user.save(update_fields=['ding_user_id'])
+ user.save(update_fields=["ding_user_id"])
except Exception as e:
logger.error(f"更新用户ding_user_id失败:{e}")
@@ -56,9 +56,13 @@ def get_ding_user_id(username):
def get_dept_list_id_fetch_child(token, parent_dept_id):
"""获取所有子部门列表"""
ids = [int(parent_dept_id)]
- url = 'https://oapi.dingtalk.com/department/list_ids?id={0}&access_token={1}'.format(parent_dept_id, token)
+ url = (
+ "https://oapi.dingtalk.com/department/list_ids?id={0}&access_token={1}".format(
+ parent_dept_id, token
+ )
+ )
resp = requests.get(url, timeout=3).json()
- if resp.get('errcode') == 0:
+ if resp.get("errcode") == 0:
for dept_id in resp.get("sub_dept_id_list"):
ids.extend(get_dept_list_id_fetch_child(token, dept_id))
return list(set(ids))
@@ -70,40 +74,46 @@ def sync_ding_user_id():
所以可根据钉钉中 jobnumber 查到该用户的 ding_user_id。
"""
sys_config = SysConfig()
- ding_dept_ids = sys_config.get('ding_dept_ids', '')
- username2ding = sys_config.get('ding_archery_username')
+ ding_dept_ids = sys_config.get("ding_dept_ids", "")
+ username2ding = sys_config.get("ding_archery_username")
token = get_access_token()
if not token:
return False
# 获取全部部门列表
sub_dept_id_list = []
- for dept_id in list(set(ding_dept_ids.split(','))):
+ for dept_id in list(set(ding_dept_ids.split(","))):
sub_dept_id_list.extend(get_dept_list_id_fetch_child(token, dept_id))
# 遍历部门下的用户
user_ids = []
for sdi in sub_dept_id_list:
- url = f'https://oapi.dingtalk.com/user/getDeptMember?access_token={token}&deptId={sdi}'
+ url = f"https://oapi.dingtalk.com/user/getDeptMember?access_token={token}&deptId={sdi}"
try:
resp = requests.get(url, timeout=3).json()
- if resp.get('errcode') == 0:
- user_ids.extend(resp.get('userIds'))
+ if resp.get("errcode") == 0:
+ user_ids.extend(resp.get("userIds"))
else:
- raise Exception(f'获取部门用户出错:{resp}')
+ raise Exception(f"获取部门用户出错:{resp}")
except Exception as e:
- raise Exception(f'获取部门用户出错:{e}')
+ raise Exception(f"获取部门用户出错:{e}")
# 获取所有用户信息并缓存
for user_id in list(set(user_ids)):
- url = f'https://oapi.dingtalk.com/user/get?access_token={token}&userid={user_id}'
+ url = (
+ f"https://oapi.dingtalk.com/user/get?access_token={token}&userid={user_id}"
+ )
try:
resp = requests.get(url, timeout=3).json()
- if resp.get('errcode') == 0:
+ if resp.get("errcode") == 0:
if not resp.get(username2ding):
- raise Exception(f'钉钉用户信息不包含{username2ding}字段,无法获取id信息,请确认ding_archery_username配置{resp}')
- rs.execute_command(f"SETEX {resp.get(username2ding).lower()} 86400 {resp.get('userid')}")
+ raise Exception(
+ f"钉钉用户信息不包含{username2ding}字段,无法获取id信息,请确认ding_archery_username配置{resp}"
+ )
+ rs.execute_command(
+ f"SETEX {resp.get(username2ding).lower()} 86400 {resp.get('userid')}"
+ )
else:
- raise Exception(f'获取用户信息出错:{resp}')
+ raise Exception(f"获取用户信息出错:{resp}")
except Exception as e:
- raise Exception(f'获取用户信息出错:{e}')
+ raise Exception(f"获取用户信息出错:{e}")
return True
diff --git a/common/utils/extend_json_encoder.py b/common/utils/extend_json_encoder.py
index 8f8837ad1c..3232ccd393 100644
--- a/common/utils/extend_json_encoder.py
+++ b/common/utils/extend_json_encoder.py
@@ -13,17 +13,17 @@
@singledispatch
def convert(o):
- raise TypeError('can not convert type')
+ raise TypeError("can not convert type")
@convert.register(datetime)
def _(o):
- return o.strftime('%Y-%m-%d %H:%M:%S')
+ return o.strftime("%Y-%m-%d %H:%M:%S")
@convert.register(date)
def _(o):
- return o.strftime('%Y-%m-%d')
+ return o.strftime("%Y-%m-%d")
@convert.register(timedelta)
@@ -80,11 +80,10 @@ def default(self, obj):
class ExtendJSONEncoderFTime(json.JSONEncoder):
-
def default(self, obj):
try:
if isinstance(obj, datetime):
- return obj.isoformat(' ')
+ return obj.isoformat(" ")
else:
return convert(obj)
except TypeError:
@@ -93,19 +92,20 @@ def default(self, obj):
# 使用simplejson处理形如 b'\xaa' 的bytes类型数据会失败,但使用json模块构造这个对象时不能使用bigint_as_string方法
import json
+
+
class ExtendJSONEncoderBytes(json.JSONEncoder):
- def default(self, obj):
+ def default(self, obj):
try:
# 使用convert.register处理会报错 ValueError: Circular reference detected
# 不是utf-8格式的bytes格式需要先进行base64编码转换
if isinstance(obj, bytes):
try:
- return o.decode('utf-8')
+                    return obj.decode("utf-8")
except:
- return base64.b64encode(obj).decode('utf-8')
+ return base64.b64encode(obj).decode("utf-8")
else:
return convert(obj)
except TypeError:
print(type(obj))
return super(ExtendJSONEncoderBytes, self).default(obj)
-
diff --git a/common/utils/feishu_api.py b/common/utils/feishu_api.py
index 27e977070d..b765d8e696 100644
--- a/common/utils/feishu_api.py
+++ b/common/utils/feishu_api.py
@@ -5,13 +5,13 @@
from common.config import SysConfig
from django.core.cache import cache
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
def get_feishu_access_token():
# 优先获取缓存
try:
- token = cache.get('feishu_access_token')
+ token = cache.get("feishu_access_token")
except Exception as e:
logger.error(f"获取飞书token缓存出错:{e}")
token = None
@@ -19,20 +19,17 @@ def get_feishu_access_token():
return token
# 请求飞书接口获取
sys_config = SysConfig()
- app_id = sys_config.get('feishu_appid')
- app_secret = sys_config.get('feishu_app_secret')
+ app_id = sys_config.get("feishu_appid")
+ app_secret = sys_config.get("feishu_app_secret")
url = f"https://open.feishu.cn/open-apis/auth/v3/app_access_token/internal/"
- data = {
- "app_id": app_id,
- "app_secret": app_secret
- }
+ data = {"app_id": app_id, "app_secret": app_secret}
resp = requests.post(url, json=data, timeout=5).json()
# resp = requests.get(url, timeout=3).json()
logger.info(f"获取飞书access_token信息成功:{resp}")
- if resp.get('code') == 0:
- access_token = resp.get('app_access_token')
- expires_in = resp.get('expire')
- cache.set('feishu_access_token', access_token, timeout=expires_in - 60)
+ if resp.get("code") == 0:
+ access_token = resp.get("app_access_token")
+ expires_in = resp.get("expire")
+ cache.set("feishu_access_token", access_token, timeout=expires_in - 60)
return access_token
else:
logger.error(f"获取飞书access_token出错:{resp}")
@@ -40,11 +37,15 @@ def get_feishu_access_token():
def get_feishu_open_id(email):
- url = 'https://open.feishu.cn/open-apis/user/v1/batch_get_id?'
- resp = requests.get(url, timeout=3, headers={'Authorization': "Bearer " + get_feishu_access_token()},
- params={"emails": email}).json()
+ url = "https://open.feishu.cn/open-apis/user/v1/batch_get_id?"
+ resp = requests.get(
+ url,
+ timeout=3,
+ headers={"Authorization": "Bearer " + get_feishu_access_token()},
+ params={"emails": email},
+ ).json()
result = []
- if resp.get('code') == 0:
- for key in resp.get('data').get('email_users').values():
+ if resp.get("code") == 0:
+ for key in resp.get("data").get("email_users").values():
result.append(key[0][f"open_id"])
return result
diff --git a/common/utils/global_info.py b/common/utils/global_info.py
index 1994f1af75..5dc400a6c8 100644
--- a/common/utils/global_info.py
+++ b/common/utils/global_info.py
@@ -8,7 +8,7 @@
def global_info(request):
"""存放用户,菜单信息等."""
user = request.user
- twofa_type = 'disabled'
+ twofa_type = "disabled"
if user and user.is_authenticated:
# 获取待办数量
try:
@@ -20,17 +20,15 @@ def global_info(request):
if twofa_config:
twofa_type = twofa_config[0].auth_type
else:
- twofa_type = 'disabled'
+ twofa_type = "disabled"
else:
todo = 0
- watermark_enabled = SysConfig().get('watermark_enabled', False)
-
-
+ watermark_enabled = SysConfig().get("watermark_enabled", False)
return {
- 'todo': todo,
- 'archery_version': display_version,
- 'watermark_enabled': watermark_enabled,
- 'twofa_type': twofa_type
+ "todo": todo,
+ "archery_version": display_version,
+ "watermark_enabled": watermark_enabled,
+ "twofa_type": twofa_type,
}
diff --git a/common/utils/permission.py b/common/utils/permission.py
index cc334be1ab..0b8c10228b 100644
--- a/common/utils/permission.py
+++ b/common/utils/permission.py
@@ -12,10 +12,10 @@ def wrapper(request, *args, **kw):
if user.is_superuser is False:
if request.is_ajax():
- result = {'status': 1, 'msg': '您无权操作,请联系管理员', 'data': []}
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result = {"status": 1, "msg": "您无权操作,请联系管理员", "data": []}
+ return HttpResponse(json.dumps(result), content_type="application/json")
else:
- context = {'errMsg': "您无权操作,请联系管理员"}
+ context = {"errMsg": "您无权操作,请联系管理员"}
return render(request, "error.html", context)
return func(request, *args, **kw)
@@ -31,10 +31,12 @@ def wrapper(request, *args, **kw):
user = request.user
if user.role not in roles and user.is_superuser is False:
if request.is_ajax():
- result = {'status': 1, 'msg': '您无权操作,请联系管理员', 'data': []}
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result = {"status": 1, "msg": "您无权操作,请联系管理员", "data": []}
+ return HttpResponse(
+ json.dumps(result), content_type="application/json"
+ )
else:
- context = {'errMsg': "您无权操作,请联系管理员"}
+ context = {"errMsg": "您无权操作,请联系管理员"}
return render(request, "error.html", context)
return func(request, *args, **kw)
diff --git a/common/utils/sendmsg.py b/common/utils/sendmsg.py
index 4da8b3ebef..db551a2069 100755
--- a/common/utils/sendmsg.py
+++ b/common/utils/sendmsg.py
@@ -15,33 +15,32 @@
from common.utils.wx_api import get_wx_access_token
from common.utils.feishu_api import *
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
class MsgSender(object):
-
def __init__(self, **kwargs):
if kwargs:
- self.MAIL_REVIEW_SMTP_SERVER = kwargs.get('server')
- self.MAIL_REVIEW_SMTP_PORT = kwargs.get('port', 0)
- self.MAIL_REVIEW_FROM_ADDR = kwargs.get('user')
- self.MAIL_REVIEW_FROM_PASSWORD = kwargs.get('password')
- self.MAIL_SSL = kwargs.get('ssl')
+ self.MAIL_REVIEW_SMTP_SERVER = kwargs.get("server")
+ self.MAIL_REVIEW_SMTP_PORT = kwargs.get("port", 0)
+ self.MAIL_REVIEW_FROM_ADDR = kwargs.get("user")
+ self.MAIL_REVIEW_FROM_PASSWORD = kwargs.get("password")
+ self.MAIL_SSL = kwargs.get("ssl")
else:
sys_config = SysConfig()
# email信息
- self.MAIL_REVIEW_SMTP_SERVER = sys_config.get('mail_smtp_server')
- self.MAIL_REVIEW_SMTP_PORT = sys_config.get('mail_smtp_port', 0)
- self.MAIL_SSL = sys_config.get('mail_ssl')
- self.MAIL_REVIEW_FROM_ADDR = sys_config.get('mail_smtp_user')
- self.MAIL_REVIEW_FROM_PASSWORD = sys_config.get('mail_smtp_password')
+ self.MAIL_REVIEW_SMTP_SERVER = sys_config.get("mail_smtp_server")
+ self.MAIL_REVIEW_SMTP_PORT = sys_config.get("mail_smtp_port", 0)
+ self.MAIL_SSL = sys_config.get("mail_ssl")
+ self.MAIL_REVIEW_FROM_ADDR = sys_config.get("mail_smtp_user")
+ self.MAIL_REVIEW_FROM_PASSWORD = sys_config.get("mail_smtp_password")
# 钉钉信息
- self.ding_agent_id = sys_config.get('ding_agent_id')
+ self.ding_agent_id = sys_config.get("ding_agent_id")
# 企业微信信息
- self.wx_agent_id = sys_config.get('wx_agent_id')
+ self.wx_agent_id = sys_config.get("wx_agent_id")
# 飞书信息
- self.feishu_appid = sys_config.get('feishu_appid')
- self.feishu_app_secret = sys_config.get('feishu_app_secret')
+ self.feishu_appid = sys_config.get("feishu_appid")
+ self.feishu_app_secret = sys_config.get("feishu_app_secret")
if self.MAIL_REVIEW_SMTP_PORT:
self.MAIL_REVIEW_SMTP_PORT = int(self.MAIL_REVIEW_SMTP_PORT)
@@ -57,11 +56,14 @@ def _add_attachment(filename):
:param filename:
:return:
"""
- file_msg = email.mime.base.MIMEBase('application', 'octet-stream')
- file_msg.set_payload(open(filename, 'rb').read())
+ file_msg = email.mime.base.MIMEBase("application", "octet-stream")
+ file_msg.set_payload(open(filename, "rb").read())
# 附件如果有中文会出现乱码问题,加入gbk
- file_msg.add_header('Content-Disposition', 'attachment', filename=('gbk', '',
- filename.split('/')[-1]))
+ file_msg.add_header(
+ "Content-Disposition",
+ "attachment",
+ filename=("gbk", "", filename.split("/")[-1]),
+ )
encoders.encode_base64(file_msg)
return file_msg
@@ -79,49 +81,55 @@ def send_email(self, subject, body, to, **kwargs):
try:
if not to:
- logger.warning('收件人为空,无法发送邮件')
+ logger.warning("收件人为空,无法发送邮件")
return
if not isinstance(to, list):
- raise TypeError('收件人需要为列表')
- list_cc = kwargs.get('list_cc_addr', [])
+ raise TypeError("收件人需要为列表")
+ list_cc = kwargs.get("list_cc_addr", [])
if not isinstance(list_cc, list):
- raise TypeError('抄送人需要为列表')
+ raise TypeError("抄送人需要为列表")
# 构造MIMEMultipart对象做为根容器
main_msg = email.mime.multipart.MIMEMultipart()
# 添加文本内容
- text_msg = email.mime.text.MIMEText(body, 'plain', 'utf-8')
+ text_msg = email.mime.text.MIMEText(body, "plain", "utf-8")
main_msg.attach(text_msg)
# 添加附件
- filename_list = kwargs.get('filename_list')
+ filename_list = kwargs.get("filename_list")
if filename_list:
- for filename in kwargs['filename_list']:
+ for filename in kwargs["filename_list"]:
file_msg = self._add_attachment(filename)
main_msg.attach(file_msg)
# 消息内容:
- main_msg['Subject'] = Header(subject, "utf-8").encode()
- main_msg['From'] = formataddr(["Archery 通知", self.MAIL_REVIEW_FROM_ADDR])
- main_msg['To'] = ','.join(list(set(to)))
- main_msg['Cc'] = ', '.join(str(cc) for cc in list(set(list_cc)))
- main_msg['Date'] = email.utils.formatdate()
+ main_msg["Subject"] = Header(subject, "utf-8").encode()
+ main_msg["From"] = formataddr(["Archery 通知", self.MAIL_REVIEW_FROM_ADDR])
+ main_msg["To"] = ",".join(list(set(to)))
+ main_msg["Cc"] = ", ".join(str(cc) for cc in list(set(list_cc)))
+ main_msg["Date"] = email.utils.formatdate()
if self.MAIL_SSL:
- server = smtplib.SMTP_SSL(self.MAIL_REVIEW_SMTP_SERVER, self.MAIL_REVIEW_SMTP_PORT, timeout=3)
+ server = smtplib.SMTP_SSL(
+ self.MAIL_REVIEW_SMTP_SERVER, self.MAIL_REVIEW_SMTP_PORT, timeout=3
+ )
else:
- server = smtplib.SMTP(self.MAIL_REVIEW_SMTP_SERVER, self.MAIL_REVIEW_SMTP_PORT, timeout=3)
+ server = smtplib.SMTP(
+ self.MAIL_REVIEW_SMTP_SERVER, self.MAIL_REVIEW_SMTP_PORT, timeout=3
+ )
# 如果提供的密码为空,则不需要登录
if self.MAIL_REVIEW_FROM_PASSWORD:
server.login(self.MAIL_REVIEW_FROM_ADDR, self.MAIL_REVIEW_FROM_PASSWORD)
- server.sendmail(self.MAIL_REVIEW_FROM_ADDR, to + list_cc, main_msg.as_string())
+ server.sendmail(
+ self.MAIL_REVIEW_FROM_ADDR, to + list_cc, main_msg.as_string()
+ )
server.quit()
- logger.debug(f'邮件推送成功\n消息标题:{subject}\n通知对象:{to + list_cc}\n消息内容:{body}')
- return 'success'
+ logger.debug(f"邮件推送成功\n消息标题:{subject}\n通知对象:{to + list_cc}\n消息内容:{body}")
+ return "success"
except Exception:
- errmsg = '邮件推送失败\n{}'.format(traceback.format_exc())
+ errmsg = "邮件推送失败\n{}".format(traceback.format_exc())
logger.error(errmsg)
return errmsg
@@ -135,14 +143,12 @@ def send_ding(url, content):
"""
data = {
"msgtype": "text",
- "text": {
- "content": "{}".format(content)
- },
+ "text": {"content": "{}".format(content)},
}
r = requests.post(url=url, json=data)
r_json = r.json()
- if r_json['errcode'] == 0:
- logger.debug(f'钉钉Webhook推送成功\n通知对象:{url}\n消息内容:{content}')
+ if r_json["errcode"] == 0:
+ logger.debug(f"钉钉Webhook推送成功\n通知对象:{url}\n消息内容:{content}")
else:
logger.error(f"钉钉Webhook推送失败错误码\n请求url:{url}\n请求data:{data}\n请求响应:{r_json}")
@@ -156,89 +162,83 @@ def send_ding2user(self, userid_list, content):
access_token = get_access_token()
send_url = f"https://oapi.dingtalk.com/topapi/message/corpconversation/asyncsend_v2?access_token={access_token}"
data = {
- "userid_list": ','.join(list(set(userid_list))),
+ "userid_list": ",".join(list(set(userid_list))),
"agent_id": self.ding_agent_id,
"msg": {"msgtype": "text", "text": {"content": f"{content}"}},
}
r = requests.post(url=send_url, json=data, timeout=5)
r_json = r.json()
- if r_json['errcode'] == 0:
- logger.debug(f'钉钉推送成功\n通知对象:{userid_list}\n消息内容:{content}')
+ if r_json["errcode"] == 0:
+ logger.debug(f"钉钉推送成功\n通知对象:{userid_list}\n消息内容:{content}")
else:
- logger.error(f'钉钉推送失败\n请求连接:{send_url}\n请求参数:{data}\n请求响应:{r_json}')
+ logger.error(f"钉钉推送失败\n请求连接:{send_url}\n请求参数:{data}\n请求响应:{r_json}")
def send_wx2user(self, msg, user_list):
if not user_list:
- logger.error(f'企业微信推送失败,无法获取到推送的用户.')
+ logger.error(f"企业微信推送失败,无法获取到推送的用户.")
return
- to_user = '|'.join(list(set(user_list)))
+ to_user = "|".join(list(set(user_list)))
access_token = get_wx_access_token()
- send_url = f'https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token={access_token}'
+ send_url = f"https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token={access_token}"
data = {
"touser": to_user,
"msgtype": "text",
"agentid": self.wx_agent_id,
- "text": {
- "content": msg
- },
+ "text": {"content": msg},
}
res = requests.post(url=send_url, json=data, timeout=5)
r_json = res.json()
- if r_json['errcode'] == 0:
- logger.debug(f'企业微信推送成功\n通知对象:{to_user}')
+ if r_json["errcode"] == 0:
+ logger.debug(f"企业微信推送成功\n通知对象:{to_user}")
else:
- logger.error(f'企业微信推送失败\n请求连接:{send_url}\n请求参数:{data}\n请求响应:{r_json}')
+ logger.error(f"企业微信推送失败\n请求连接:{send_url}\n请求参数:{data}\n请求响应:{r_json}")
- def send_qywx_webhook(self,qywx_webhook, msg):
+ def send_qywx_webhook(self, qywx_webhook, msg):
send_url = qywx_webhook
# 对链接进行转换
- _msg = re.findall('https://.+(?=\n)|http://.+(?=\n)', msg)
+ _msg = re.findall("https://.+(?=\n)|http://.+(?=\n)", msg)
for url in _msg:
# 防止如 [xxx](http://www.a.com)\n 的字符串被再次替换
if url.strip()[-1] != ")":
- msg=msg.replace(url,'[请点击链接](%s)' % url)
+ msg = msg.replace(url, "[请点击链接](%s)" % url)
data = {
"msgtype": "markdown",
- "markdown": {
- "content": msg
- },
+ "markdown": {"content": msg},
}
res = requests.post(url=send_url, json=data, timeout=5)
r_json = res.json()
- if r_json['errcode'] == 0:
- logger.debug(f'企业微信机器人推送成功\n通知对象:机器人')
+ if r_json["errcode"] == 0:
+ logger.debug(f"企业微信机器人推送成功\n通知对象:机器人")
else:
- logger.error(f'企业微信机器人推送失败\n请求连接:{send_url}\n请求参数:{data}\n请求响应:{r_json}')
+ logger.error(f"企业微信机器人推送失败\n请求连接:{send_url}\n请求参数:{data}\n请求响应:{r_json}")
@staticmethod
def send_feishu_webhook(url, title, content):
- data = {
- "title": title,
- "text": content
- }
- if '/v2/' in url:
+ data = {"title": title, "text": content}
+ if "/v2/" in url:
data = {
- 'msg_type': 'post',
- 'content': {
- 'post': {
- 'zh_cn': {
- 'title': title,
- 'content': [[{
- 'tag': 'text',
- 'text': content
- }]]
+ "msg_type": "post",
+ "content": {
+ "post": {
+ "zh_cn": {
+ "title": title,
+ "content": [[{"tag": "text", "text": content}]],
}
}
- }
+ },
}
r = requests.post(url=url, json=data)
r_json = r.json()
- if 'ok' in r_json or ('StatusCode' in r_json and r_json['StatusCode'] == 0) or ('code' in r_json and r_json['code'] == 0):
- logger.debug(f'飞书Webhook推送成功\n通知对象:{url}\n消息内容:{content}')
+ if (
+ "ok" in r_json
+ or ("StatusCode" in r_json and r_json["StatusCode"] == 0)
+ or ("code" in r_json and r_json["code"] == 0)
+ ):
+ logger.debug(f"飞书Webhook推送成功\n通知对象:{url}\n消息内容:{content}")
else:
logger.error(f"飞书Webhook推送失败错误码\n请求url:{url}\n请求data:{data}\n请求响应:{r_json}")
@@ -252,12 +252,14 @@ def send_feishu_user(title, content, open_id, user_mail):
data = {
"open_ids": open_id,
"msg_type": "text",
- "content": {
- "text": f'{title}\n{content}'
- }
+ "content": {"text": f"{title}\n{content}"},
}
- r = requests.post(url=url, json=data, headers={'Authorization': "Bearer " + get_feishu_access_token()}).json()
- if r['code'] == 0:
- logger.debug(f'飞书单推推送成功\n通知对象:{url}\n消息内容:{content}')
+ r = requests.post(
+ url=url,
+ json=data,
+ headers={"Authorization": "Bearer " + get_feishu_access_token()},
+ ).json()
+ if r["code"] == 0:
+ logger.debug(f"飞书单推推送成功\n通知对象:{url}\n消息内容:{content}")
else:
logger.error(f"飞书单推推送失败错误码\n请求url:{url}\n请求data:{data}\n请求响应:{r}")
diff --git a/common/utils/tencent_sms.py b/common/utils/tencent_sms.py
index 9cf224aa9e..0cec8539e2 100644
--- a/common/utils/tencent_sms.py
+++ b/common/utils/tencent_sms.py
@@ -1,33 +1,32 @@
# -*- coding: utf-8 -*-
from tencentcloud.common import credential
-from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
+from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
+ TencentCloudSDKException,
+)
from tencentcloud.sms.v20210111 import sms_client, models
from common.config import SysConfig
import traceback
import logging
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
class TencentSMS:
def __init__(self):
all_config = SysConfig()
- self.secret_id = all_config.get('tencent_secret_id', '')
- self.secret_key = all_config.get('tencent_secret_key', '')
- self.sign_name = all_config.get('tencent_sign_name', '')
- self.template_id = all_config.get('tencent_template_id', '')
- self.sdk_appid = all_config.get('tencent_sdk_appid', '')
+ self.secret_id = all_config.get("tencent_secret_id", "")
+ self.secret_key = all_config.get("tencent_secret_key", "")
+ self.sign_name = all_config.get("tencent_sign_name", "")
+ self.template_id = all_config.get("tencent_template_id", "")
+ self.sdk_appid = all_config.get("tencent_sdk_appid", "")
def create_client(self):
- cred = credential.Credential(
- self.secret_id,
- self.secret_key
- )
+ cred = credential.Credential(self.secret_id, self.secret_key)
client = sms_client.SmsClient(cred, "ap-guangzhou")
return client
def send_code(self, **kwargs):
- result = {'status': 0, 'msg': 'ok'}
+ result = {"status": 0, "msg": "ok"}
client = TencentSMS.create_client(self)
try:
@@ -35,16 +34,16 @@ def send_code(self, **kwargs):
req.SmsSdkAppId = self.sdk_appid
req.SignName = self.sign_name
req.TemplateId = self.template_id
- req.TemplateParamSet = [kwargs['otp']]
- req.PhoneNumberSet = [kwargs['phone']]
+ req.TemplateParamSet = [kwargs["otp"]]
+ req.PhoneNumberSet = [kwargs["phone"]]
resp = client.SendSms(req)
- if resp.SendStatusSet[0].Code != 'Ok':
- result['status'] = 1
- result['msg'] = resp.SendStatusSet[0].Message
+ if resp.SendStatusSet[0].Code != "Ok":
+ result["status"] = 1
+ result["msg"] = resp.SendStatusSet[0].Message
except TencentCloudSDKException as e:
- result['status'] = 1
- result['msg'] = str(e)
+ result["status"] = 1
+ result["msg"] = str(e)
logger.error(str(e))
logger.error(traceback.format_exc())
diff --git a/common/utils/timer.py b/common/utils/timer.py
index cea0f91e5b..c5cc11c17d 100644
--- a/common/utils/timer.py
+++ b/common/utils/timer.py
@@ -7,7 +7,7 @@
"""
import datetime
-__author__ = 'hhyo'
+__author__ = "hhyo"
class FuncTimer(object):
diff --git a/common/utils/wx_api.py b/common/utils/wx_api.py
index 6f8bc5d5d6..3975004249 100644
--- a/common/utils/wx_api.py
+++ b/common/utils/wx_api.py
@@ -5,13 +5,13 @@
from common.config import SysConfig
from django.core.cache import cache
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
def get_wx_access_token():
# 优先获取缓存
try:
- token = cache.get('wx_access_token')
+ token = cache.get("wx_access_token")
except Exception as e:
logger.error(f"获取企业微信token缓存出错:{e}")
token = None
@@ -19,14 +19,14 @@ def get_wx_access_token():
return token
# 请求企业微信接口获取
sys_config = SysConfig()
- corp_id = sys_config.get('wx_corpid')
- corp_secret = sys_config.get('wx_app_secret')
+ corp_id = sys_config.get("wx_corpid")
+ corp_secret = sys_config.get("wx_app_secret")
url = f"https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid={corp_id}&corpsecret={corp_secret}"
resp = requests.get(url, timeout=3).json()
- if resp.get('errcode') == 0:
- access_token = resp.get('access_token')
- expires_in = resp.get('expires_in')
- cache.set('wx_access_token', access_token, timeout=expires_in - 60)
+ if resp.get("errcode") == 0:
+ access_token = resp.get("access_token")
+ expires_in = resp.get("expires_in")
+ cache.set("wx_access_token", access_token, timeout=expires_in - 60)
return access_token
else:
logger.error(f"获取企业微信access_token出错:{resp}")
diff --git a/common/views.py b/common/views.py
index 5fdfac752e..cbf8a1d704 100644
--- a/common/views.py
+++ b/common/views.py
@@ -7,30 +7,32 @@
"""
from django.shortcuts import render
-__author__ = 'hhyo'
+__author__ = "hhyo"
from django.http import (
- HttpResponseBadRequest, HttpResponseForbidden, HttpResponseNotFound,
+ HttpResponseBadRequest,
+ HttpResponseForbidden,
+ HttpResponseNotFound,
HttpResponseServerError,
)
from django.views.decorators.csrf import requires_csrf_token
@requires_csrf_token
-def bad_request(request, exception, template_name='errors/400.html'):
+def bad_request(request, exception, template_name="errors/400.html"):
return HttpResponseBadRequest(render(request, template_name))
@requires_csrf_token
-def permission_denied(request, exception, template_name='errors/403.html'):
+def permission_denied(request, exception, template_name="errors/403.html"):
return HttpResponseForbidden(render(request, template_name))
@requires_csrf_token
-def page_not_found(request, exception, template_name='errors/404.html'):
+def page_not_found(request, exception, template_name="errors/404.html"):
return HttpResponseNotFound(render(request, template_name))
@requires_csrf_token
-def server_error(request, template_name='errors/500.html'):
+def server_error(request, template_name="errors/500.html"):
return HttpResponseServerError(render(request, template_name))
diff --git a/common/workflow.py b/common/workflow.py
index dd7988711a..2ca3edcfbc 100644
--- a/common/workflow.py
+++ b/common/workflow.py
@@ -13,11 +13,11 @@ def lists(request):
# 获取用户信息
user = request.user
- limit = int(request.POST.get('limit'))
- offset = int(request.POST.get('offset'))
- workflow_type = int(request.POST.get('workflow_type'))
+ limit = int(request.POST.get("limit"))
+ offset = int(request.POST.get("offset"))
+ workflow_type = int(request.POST.get("workflow_type"))
limit = offset + limit
- search = request.POST.get('search', '')
+ search = request.POST.get("search", "")
# 先获取用户所在资源组列表
group_list = user_groups(user)
@@ -31,43 +31,56 @@ def lists(request):
# 只返回所在资源组当前待自己审核的数据
workflow_audit = WorkflowAudit.objects.filter(
workflow_title__icontains=search,
- current_status=WorkflowDict.workflow_status['audit_wait'],
+ current_status=WorkflowDict.workflow_status["audit_wait"],
group_id__in=group_ids,
- current_audit__in=auth_group_ids
+ current_audit__in=auth_group_ids,
)
# 过滤工单类型
if workflow_type != 0:
workflow_audit = workflow_audit.filter(workflow_type=workflow_type)
audit_list_count = workflow_audit.count()
- audit_list = workflow_audit.order_by('-audit_id')[offset:limit].values(
- 'audit_id', 'workflow_type',
- 'workflow_title', 'create_user_display',
- 'create_time', 'current_status',
- 'audit_auth_groups',
- 'current_audit',
- 'group_name')
+ audit_list = workflow_audit.order_by("-audit_id")[offset:limit].values(
+ "audit_id",
+ "workflow_type",
+ "workflow_title",
+ "create_user_display",
+ "create_time",
+ "current_status",
+ "audit_auth_groups",
+ "current_audit",
+ "group_name",
+ )
# QuerySet 序列化
rows = [row for row in audit_list]
result = {"total": audit_list_count, "rows": rows}
# 返回查询结果
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
# 获取工单日志
def log(request):
- workflow_id = request.POST.get('workflow_id')
- workflow_type = request.POST.get('workflow_type')
+ workflow_id = request.POST.get("workflow_id")
+ workflow_type = request.POST.get("workflow_type")
try:
- audit_id = WorkflowAudit.objects.get(workflow_id=workflow_id, workflow_type=workflow_type).audit_id
- workflow_logs = WorkflowLog.objects.filter(audit_id=audit_id).order_by('-id').values(
- 'operation_type_desc',
- 'operation_info',
- 'operator_display',
- 'operation_time')
+ audit_id = WorkflowAudit.objects.get(
+ workflow_id=workflow_id, workflow_type=workflow_type
+ ).audit_id
+ workflow_logs = (
+ WorkflowLog.objects.filter(audit_id=audit_id)
+ .order_by("-id")
+ .values(
+ "operation_type_desc",
+ "operation_info",
+ "operator_display",
+ "operation_time",
+ )
+ )
count = WorkflowLog.objects.filter(audit_id=audit_id).count()
except Exception:
workflow_logs = []
@@ -77,5 +90,7 @@ def log(request):
rows = [row for row in workflow_logs]
result = {"total": count, "rows": rows}
# 返回查询结果
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoderFTime, bigint_as_string=True),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoderFTime, bigint_as_string=True),
+ content_type="application/json",
+ )
diff --git a/sql/admin.py b/sql/admin.py
index 36c9a6b713..13e80bb1c5 100755
--- a/sql/admin.py
+++ b/sql/admin.py
@@ -5,11 +5,31 @@
# Register your models here.
from django.forms import PasswordInput
-from .models import Users, Instance, SqlWorkflow, SqlWorkflowContent, QueryLog, DataMaskingColumns, DataMaskingRules, \
- AliyunRdsConfig, CloudAccessKey, ResourceGroup, QueryPrivilegesApply, \
- QueryPrivileges, InstanceAccount, InstanceDatabase, ArchiveConfig, \
- WorkflowAudit, WorkflowLog, ParamTemplate, ParamHistory, InstanceTag, \
- Tunnel, AuditEntry, TwoFactorAuthConfig
+from .models import (
+ Users,
+ Instance,
+ SqlWorkflow,
+ SqlWorkflowContent,
+ QueryLog,
+ DataMaskingColumns,
+ DataMaskingRules,
+ AliyunRdsConfig,
+ CloudAccessKey,
+ ResourceGroup,
+ QueryPrivilegesApply,
+ QueryPrivileges,
+ InstanceAccount,
+ InstanceDatabase,
+ ArchiveConfig,
+ WorkflowAudit,
+ WorkflowLog,
+ ParamTemplate,
+ ParamHistory,
+ InstanceTag,
+ Tunnel,
+ AuditEntry,
+ TwoFactorAuthConfig,
+)
from sql.form import TunnelForm, InstanceForm
@@ -17,65 +37,150 @@
# 用户管理
@admin.register(Users)
class UsersAdmin(UserAdmin):
- list_display = ('id', 'username', 'display', 'email', 'is_superuser', 'is_staff', 'is_active')
- search_fields = ('id', 'username', 'display', 'email')
- list_display_links = ('id', 'username',)
- ordering = ('id',)
+ list_display = (
+ "id",
+ "username",
+ "display",
+ "email",
+ "is_superuser",
+ "is_staff",
+ "is_active",
+ )
+ search_fields = ("id", "username", "display", "email")
+ list_display_links = (
+ "id",
+ "username",
+ )
+ ordering = ("id",)
# 编辑页显示内容
fieldsets = (
- ('认证信息', {'fields': ('username', 'password')}),
- ('个人信息', {'fields': ('display', 'email', 'ding_user_id', 'wx_user_id', 'feishu_open_id')}),
- ('权限信息', {'fields': ('is_superuser', 'is_active', 'is_staff', 'groups', 'user_permissions')}),
- ('资源组', {'fields': ('resource_group',)}),
- ('其他信息', {'fields': ('date_joined',)}),
+ ("认证信息", {"fields": ("username", "password")}),
+ (
+ "个人信息",
+ {
+ "fields": (
+ "display",
+ "email",
+ "ding_user_id",
+ "wx_user_id",
+ "feishu_open_id",
+ )
+ },
+ ),
+ (
+ "权限信息",
+ {
+ "fields": (
+ "is_superuser",
+ "is_active",
+ "is_staff",
+ "groups",
+ "user_permissions",
+ )
+ },
+ ),
+ ("资源组", {"fields": ("resource_group",)}),
+ ("其他信息", {"fields": ("date_joined",)}),
)
# 添加页显示内容
add_fieldsets = (
- ('认证信息', {'fields': ('username', 'password1', 'password2')}),
- ('个人信息', {'fields': ('display', 'email', 'ding_user_id', 'wx_user_id', 'feishu_open_id')}),
- ('权限信息', {'fields': ('is_superuser', 'is_active', 'is_staff', 'groups', 'user_permissions')}),
- ('资源组', {'fields': ('resource_group',)}),
+ ("认证信息", {"fields": ("username", "password1", "password2")}),
+ (
+ "个人信息",
+ {
+ "fields": (
+ "display",
+ "email",
+ "ding_user_id",
+ "wx_user_id",
+ "feishu_open_id",
+ )
+ },
+ ),
+ (
+ "权限信息",
+ {
+ "fields": (
+ "is_superuser",
+ "is_active",
+ "is_staff",
+ "groups",
+ "user_permissions",
+ )
+ },
+ ),
+ ("资源组", {"fields": ("resource_group",)}),
)
- filter_horizontal = ('groups', 'user_permissions', 'resource_group')
- list_filter = ('is_staff', 'is_superuser', 'is_active', 'groups', 'resource_group')
+ filter_horizontal = ("groups", "user_permissions", "resource_group")
+ list_filter = ("is_staff", "is_superuser", "is_active", "groups", "resource_group")
# 用户2fa管理
@admin.register(TwoFactorAuthConfig)
class TwoFactorAuthConfigAdmin(admin.ModelAdmin):
- list_display = ('id', 'username', 'auth_type', 'phone', 'secret_key', 'user_id')
+ list_display = ("id", "username", "auth_type", "phone", "secret_key", "user_id")
# 资源组管理
@admin.register(ResourceGroup)
class ResourceGroupAdmin(admin.ModelAdmin):
- list_display = ('group_id', 'group_name', 'ding_webhook', 'feishu_webhook', 'qywx_webhook', 'is_deleted')
- exclude = ('group_parent_id', 'group_sort', 'group_level',)
+ list_display = (
+ "group_id",
+ "group_name",
+ "ding_webhook",
+ "feishu_webhook",
+ "qywx_webhook",
+ "is_deleted",
+ )
+ exclude = (
+ "group_parent_id",
+ "group_sort",
+ "group_level",
+ )
# 实例标签配置
@admin.register(InstanceTag)
class InstanceTagAdmin(admin.ModelAdmin):
- list_display = ('id', 'tag_code', 'tag_name', 'active', 'create_time')
- list_display_links = ('id', 'tag_code',)
- fieldsets = (None, {'fields': ('tag_code', 'tag_name', 'active'), }),
+ list_display = ("id", "tag_code", "tag_name", "active", "create_time")
+ list_display_links = (
+ "id",
+ "tag_code",
+ )
+ fieldsets = (
+ (
+ None,
+ {
+ "fields": ("tag_code", "tag_name", "active"),
+ },
+ ),
+ )
# 不支持修改标签代码
def get_readonly_fields(self, request, obj=None):
- return ('tag_code',) if obj else ()
+ return ("tag_code",) if obj else ()
# 实例管理
@admin.register(Instance)
class InstanceAdmin(admin.ModelAdmin):
form = InstanceForm
- list_display = ('id', 'instance_name', 'db_type', 'type', 'host', 'port', 'user', 'create_time')
- search_fields = ['instance_name', 'host', 'port', 'user']
- list_filter = ('db_type', 'type', 'instance_tag')
+ list_display = (
+ "id",
+ "instance_name",
+ "db_type",
+ "type",
+ "host",
+ "port",
+ "user",
+ "create_time",
+ )
+ search_fields = ["instance_name", "host", "port", "user"]
+ list_filter = ("db_type", "type", "instance_tag")
def formfield_for_dbfield(self, db_field, **kwargs):
- if db_field.name == 'password':
- kwargs['widget'] = PasswordInput(render_value=True)
+ if db_field.name == "password":
+ kwargs["widget"] = PasswordInput(render_value=True)
return super(InstanceAdmin, self).formfield_for_dbfield(db_field, **kwargs)
# 阿里云实例关系配置
@@ -83,7 +188,10 @@ class AliRdsConfigInline(admin.TabularInline):
model = AliyunRdsConfig
# 实例资源组关联配置
- filter_horizontal = ('resource_group', 'instance_tag',)
+ filter_horizontal = (
+ "resource_group",
+ "instance_tag",
+ )
inlines = [AliRdsConfigInline]
@@ -91,28 +199,48 @@ class AliRdsConfigInline(admin.TabularInline):
# SSH隧道
@admin.register(Tunnel)
class TunnelAdmin(admin.ModelAdmin):
- list_display = ('id', 'tunnel_name', 'host', 'port', 'create_time')
- list_display_links = ('id', 'tunnel_name',)
- search_fields = ('id', 'tunnel_name')
+ list_display = ("id", "tunnel_name", "host", "port", "create_time")
+ list_display_links = (
+ "id",
+ "tunnel_name",
+ )
+ search_fields = ("id", "tunnel_name")
fieldsets = (
- None,
- {'fields': ('tunnel_name', 'host', 'port', 'user', 'password', 'pkey_path', 'pkey_password', 'pkey'), }),
- ordering = ('id',)
+ (
+ None,
+ {
+ "fields": (
+ "tunnel_name",
+ "host",
+ "port",
+ "user",
+ "password",
+ "pkey_path",
+ "pkey_password",
+ "pkey",
+ ),
+ },
+ ),
+ )
+ ordering = ("id",)
# 添加页显示内容
add_fieldsets = (
- ('隧道信息', {'fields': ('tunnel_name', 'host', 'port')}),
- ('连接信息', {'fields': ('user', 'password', 'pkey_path', 'pkey_password', 'pkey')}),
+ ("隧道信息", {"fields": ("tunnel_name", "host", "port")}),
+ (
+ "连接信息",
+ {"fields": ("user", "password", "pkey_path", "pkey_password", "pkey")},
+ ),
)
form = TunnelForm
def formfield_for_dbfield(self, db_field, **kwargs):
- if db_field.name in ['password', 'pkey_password']:
- kwargs['widget'] = PasswordInput(render_value=True)
+ if db_field.name in ["password", "pkey_password"]:
+ kwargs["widget"] = PasswordInput(render_value=True)
return super(TunnelAdmin, self).formfield_for_dbfield(db_field, **kwargs)
# 不支持修改标签代码
def get_readonly_fields(self, request, obj=None):
- return ('id',) if obj else ()
+ return ("id",) if obj else ()
# SQL工单内容
@@ -124,10 +252,31 @@ class SqlWorkflowContentInline(admin.TabularInline):
@admin.register(SqlWorkflow)
class SqlWorkflowAdmin(admin.ModelAdmin):
list_display = (
- 'id', 'workflow_name', 'group_name', 'instance', 'engineer_display', 'create_time', 'status', 'is_backup')
- search_fields = ['id', 'workflow_name', 'engineer_display', 'sqlworkflowcontent__sql_content']
- list_filter = ('group_name', 'instance__instance_name', 'status', 'syntax_type',)
- list_display_links = ('id', 'workflow_name',)
+ "id",
+ "workflow_name",
+ "group_name",
+ "instance",
+ "engineer_display",
+ "create_time",
+ "status",
+ "is_backup",
+ )
+ search_fields = [
+ "id",
+ "workflow_name",
+ "engineer_display",
+ "sqlworkflowcontent__sql_content",
+ ]
+ list_filter = (
+ "group_name",
+ "instance__instance_name",
+ "status",
+ "syntax_type",
+ )
+ list_display_links = (
+ "id",
+ "workflow_name",
+ )
inlines = [SqlWorkflowContentInline]
@@ -135,130 +284,247 @@ class SqlWorkflowAdmin(admin.ModelAdmin):
@admin.register(QueryLog)
class QueryLogAdmin(admin.ModelAdmin):
list_display = (
- 'instance_name', 'db_name', 'sqllog', 'effect_row', 'cost_time', 'user_display', 'create_time')
- search_fields = ['sqllog', 'user_display']
- list_filter = ('instance_name', 'db_name', 'user_display', 'priv_check', 'hit_rule', 'masking',)
+ "instance_name",
+ "db_name",
+ "sqllog",
+ "effect_row",
+ "cost_time",
+ "user_display",
+ "create_time",
+ )
+ search_fields = ["sqllog", "user_display"]
+ list_filter = (
+ "instance_name",
+ "db_name",
+ "user_display",
+ "priv_check",
+ "hit_rule",
+ "masking",
+ )
# 查询权限列表
@admin.register(QueryPrivileges)
class QueryPrivilegesAdmin(admin.ModelAdmin):
- list_display = ('privilege_id', 'user_display', 'instance', 'db_name', 'table_name',
- 'valid_date', 'limit_num', 'create_time')
- search_fields = ['user_display', 'instance__instance_name']
- list_filter = ('user_display', 'instance', 'db_name', 'table_name',)
+ list_display = (
+ "privilege_id",
+ "user_display",
+ "instance",
+ "db_name",
+ "table_name",
+ "valid_date",
+ "limit_num",
+ "create_time",
+ )
+ search_fields = ["user_display", "instance__instance_name"]
+ list_filter = (
+ "user_display",
+ "instance",
+ "db_name",
+ "table_name",
+ )
# 查询权限申请记录
@admin.register(QueryPrivilegesApply)
class QueryPrivilegesApplyAdmin(admin.ModelAdmin):
- list_display = ('apply_id', 'user_display', 'group_name', 'instance', 'valid_date', 'limit_num', 'create_time')
- search_fields = ['user_display', 'instance__instance_name', 'db_list', 'table_list']
- list_filter = ('user_display', 'group_name', 'instance')
+ list_display = (
+ "apply_id",
+ "user_display",
+ "group_name",
+ "instance",
+ "valid_date",
+ "limit_num",
+ "create_time",
+ )
+ search_fields = ["user_display", "instance__instance_name", "db_list", "table_list"]
+ list_filter = ("user_display", "group_name", "instance")
# 脱敏字段页面定义
@admin.register(DataMaskingColumns)
class DataMaskingColumnsAdmin(admin.ModelAdmin):
list_display = (
- 'column_id', 'rule_type', 'active', 'instance', 'table_schema', 'table_name', 'column_name', 'column_comment',
- 'create_time',)
- search_fields = ['table_name', 'column_name']
- list_filter = ('rule_type', 'active', 'instance__instance_name')
+ "column_id",
+ "rule_type",
+ "active",
+ "instance",
+ "table_schema",
+ "table_name",
+ "column_name",
+ "column_comment",
+ "create_time",
+ )
+ search_fields = ["table_name", "column_name"]
+ list_filter = ("rule_type", "active", "instance__instance_name")
# 脱敏规则页面定义
@admin.register(DataMaskingRules)
class DataMaskingRulesAdmin(admin.ModelAdmin):
list_display = (
- 'rule_type', 'rule_regex', 'hide_group', 'rule_desc', 'sys_time',)
+ "rule_type",
+ "rule_regex",
+ "hide_group",
+ "rule_desc",
+ "sys_time",
+ )
# 工作流审批列表
@admin.register(WorkflowAudit)
class WorkflowAuditAdmin(admin.ModelAdmin):
list_display = (
- 'workflow_title', 'group_name', 'workflow_type', 'current_status', 'create_user_display', 'create_time')
- search_fields = ['workflow_title', 'create_user_display']
- list_filter = ('create_user_display', 'group_name', 'workflow_type', 'current_status')
+ "workflow_title",
+ "group_name",
+ "workflow_type",
+ "current_status",
+ "create_user_display",
+ "create_time",
+ )
+ search_fields = ["workflow_title", "create_user_display"]
+ list_filter = (
+ "create_user_display",
+ "group_name",
+ "workflow_type",
+ "current_status",
+ )
# 工作流日志表
@admin.register(WorkflowLog)
class WorkflowLogAdmin(admin.ModelAdmin):
list_display = (
- 'operation_type_desc', 'operation_info', 'operator_display', 'operation_time',)
- list_filter = ('operation_type_desc', 'operator_display')
+ "operation_type_desc",
+ "operation_info",
+ "operator_display",
+ "operation_time",
+ )
+ list_filter = ("operation_type_desc", "operator_display")
# 实例数据库列表
@admin.register(InstanceDatabase)
class InstanceDatabaseAdmin(admin.ModelAdmin):
- list_display = ('db_name', 'owner_display', 'instance', 'remark')
- search_fields = ('db_name',)
- list_filter = ('instance', 'owner_display')
- list_display_links = ('db_name',)
+ list_display = ("db_name", "owner_display", "instance", "remark")
+ search_fields = ("db_name",)
+ list_filter = ("instance", "owner_display")
+ list_display_links = ("db_name",)
# 仅支持修改备注
def get_readonly_fields(self, request, obj=None):
- return ('instance', 'owner', 'owner_display') if obj else ()
+ return ("instance", "owner", "owner_display") if obj else ()
# 实例用户列表
@admin.register(InstanceAccount)
class InstanceAccountAdmin(admin.ModelAdmin):
- list_display = ('user', 'host', 'password', 'instance', 'remark')
- search_fields = ('user', 'host')
- list_filter = ('instance', 'host')
- list_display_links = ('user',)
+ list_display = ("user", "host", "password", "instance", "remark")
+ search_fields = ("user", "host")
+ list_filter = ("instance", "host")
+ list_display_links = ("user",)
# 仅支持修改备注
def get_readonly_fields(self, request, obj=None):
- return ('user', 'host', 'instance',) if obj else ()
+ return (
+ (
+ "user",
+ "host",
+ "instance",
+ )
+ if obj
+ else ()
+ )
# 实例参数配置表
@admin.register(ParamTemplate)
class ParamTemplateAdmin(admin.ModelAdmin):
- list_display = ('variable_name', 'db_type', 'default_value', 'editable', 'valid_values')
- search_fields = ('variable_name',)
- list_filter = ('db_type', 'editable')
- list_display_links = ('variable_name',)
+ list_display = (
+ "variable_name",
+ "db_type",
+ "default_value",
+ "editable",
+ "valid_values",
+ )
+ search_fields = ("variable_name",)
+ list_filter = ("db_type", "editable")
+ list_display_links = ("variable_name",)
# 实例参数修改历史
@admin.register(ParamHistory)
class ParamHistoryAdmin(admin.ModelAdmin):
- list_display = ('variable_name', 'instance', 'old_var', 'new_var', 'user_display', 'create_time')
- search_fields = ('variable_name',)
- list_filter = ('instance', 'user_display')
+ list_display = (
+ "variable_name",
+ "instance",
+ "old_var",
+ "new_var",
+ "user_display",
+ "create_time",
+ )
+ search_fields = ("variable_name",)
+ list_filter = ("instance", "user_display")
# 归档配置
@admin.register(ArchiveConfig)
class ArchiveConfigAdmin(admin.ModelAdmin):
list_display = (
- 'id', 'title', 'src_instance', 'src_db_name', 'src_table_name',
- 'dest_instance', 'dest_db_name', 'dest_table_name',
- 'mode', 'no_delete', 'status', 'state', 'user_display', 'create_time', 'resource_group')
- search_fields = ('title', 'src_table_name')
- list_display_links = ('id', 'title')
- list_filter = ('src_instance', 'src_db_name', 'mode', 'no_delete', 'state')
+ "id",
+ "title",
+ "src_instance",
+ "src_db_name",
+ "src_table_name",
+ "dest_instance",
+ "dest_db_name",
+ "dest_table_name",
+ "mode",
+ "no_delete",
+ "status",
+ "state",
+ "user_display",
+ "create_time",
+ "resource_group",
+ )
+ search_fields = ("title", "src_table_name")
+ list_display_links = ("id", "title")
+ list_filter = ("src_instance", "src_db_name", "mode", "no_delete", "state")
# 编辑页显示内容
- fields = ('title', 'resource_group', 'src_instance', 'src_db_name', 'src_table_name',
- 'dest_instance', 'dest_db_name', 'dest_table_name',
- 'mode', 'condition', 'sleep', 'no_delete', 'state', 'user_name', 'user_display')
+ fields = (
+ "title",
+ "resource_group",
+ "src_instance",
+ "src_db_name",
+ "src_table_name",
+ "dest_instance",
+ "dest_db_name",
+ "dest_table_name",
+ "mode",
+ "condition",
+ "sleep",
+ "no_delete",
+ "state",
+ "user_name",
+ "user_display",
+ )
# 云服务认证信息配置
@admin.register(CloudAccessKey)
class CloudAccessKeyAdmin(admin.ModelAdmin):
- list_display = ('type', 'key_id', 'key_secret', 'remark')
+ list_display = ("type", "key_id", "key_secret", "remark")
# 登录审计日志
@admin.register(AuditEntry)
class AuditEntryAdmin(admin.ModelAdmin):
- list_display = ('user_id', 'user_name', 'user_display', 'action', 'extra_info', 'action_time')
- list_filter = ('user_id', 'user_name', 'user_display', 'action', 'extra_info')
-
+ list_display = (
+ "user_id",
+ "user_name",
+ "user_display",
+ "action",
+ "extra_info",
+ "action_time",
+ )
+ list_filter = ("user_id", "user_name", "user_display", "action", "extra_info")
diff --git a/sql/aliyun_rds.py b/sql/aliyun_rds.py
index d94552ae23..e2a9fd0416 100644
--- a/sql/aliyun_rds.py
+++ b/sql/aliyun_rds.py
@@ -9,23 +9,23 @@
# 获取SQL慢日志统计
def slowquery_review(request):
- instance_name = request.POST.get('instance_name')
- db_name = request.POST.get('db_name')
- start_time = request.POST.get('StartTime')
- end_time = request.POST.get('EndTime')
- limit = request.POST.get('limit')
- offset = request.POST.get('offset')
+ instance_name = request.POST.get("instance_name")
+ db_name = request.POST.get("db_name")
+ start_time = request.POST.get("StartTime")
+ end_time = request.POST.get("EndTime")
+ limit = request.POST.get("limit")
+ offset = request.POST.get("offset")
# 计算页数
page_number = (int(offset) + int(limit)) / int(limit)
values = {"PageSize": int(limit), "PageNumber": int(page_number)}
# DBName非必传
if db_name:
- values['DBName'] = db_name
+ values["DBName"] = db_name
# UTC时间转化成阿里云需求的时间格式
- start_time = '%sZ' % start_time
- end_time = '%sZ' % end_time
+ start_time = "%sZ" % start_time
+ end_time = "%sZ" % end_time
# 通过实例名称获取关联的rds实例id
instance_info = AliyunRdsConfig.objects.get(instance__instance_name=instance_name)
@@ -33,55 +33,70 @@ def slowquery_review(request):
slowsql = Aliyun(rds=instance_info).DescribeSlowLogs(start_time, end_time, **values)
# 解决table数据丢失精度、格式化时间
- sql_slow_log = json.loads(slowsql)['Items']['SQLSlowLog']
+ sql_slow_log = json.loads(slowsql)["Items"]["SQLSlowLog"]
for SlowLog in sql_slow_log:
- SlowLog['SQLId'] = str(SlowLog['SQLHASH'])
- SlowLog['CreateTime'] = Aliyun.utc2local(SlowLog['CreateTime'], utc_format="%Y-%m-%dZ")
-
- result = {"total": json.loads(slowsql)['TotalRecordCount'], "rows": sql_slow_log,
- "PageSize": json.loads(slowsql)['PageRecordCount'], "PageNumber": json.loads(slowsql)['PageNumber']}
+ SlowLog["SQLId"] = str(SlowLog["SQLHASH"])
+ SlowLog["CreateTime"] = Aliyun.utc2local(
+ SlowLog["CreateTime"], utc_format="%Y-%m-%dZ"
+ )
+
+ result = {
+ "total": json.loads(slowsql)["TotalRecordCount"],
+ "rows": sql_slow_log,
+ "PageSize": json.loads(slowsql)["PageRecordCount"],
+ "PageNumber": json.loads(slowsql)["PageNumber"],
+ }
# 返回查询结果
return result
# 获取SQL慢日志明细
def slowquery_review_history(request):
- instance_name = request.POST.get('instance_name')
- start_time = request.POST.get('StartTime')
- end_time = request.POST.get('EndTime')
- db_name = request.POST.get('db_name')
- sql_id = request.POST.get('SQLId')
- limit = request.POST.get('limit')
- offset = request.POST.get('offset')
+ instance_name = request.POST.get("instance_name")
+ start_time = request.POST.get("StartTime")
+ end_time = request.POST.get("EndTime")
+ db_name = request.POST.get("db_name")
+ sql_id = request.POST.get("SQLId")
+ limit = request.POST.get("limit")
+ offset = request.POST.get("offset")
# 计算页数
page_number = (int(offset) + int(limit)) / int(limit)
values = {"PageSize": int(limit), "PageNumber": int(page_number)}
# SQLId、DBName非必传
if sql_id:
- values['SQLHASH'] = sql_id
+ values["SQLHASH"] = sql_id
if db_name:
- values['DBName'] = db_name
+ values["DBName"] = db_name
# UTC时间转化成阿里云需求的时间格式
- start_time = datetime.datetime.strptime(start_time, "%Y-%m-%d").date() - datetime.timedelta(days=1)
- start_time = '%sT16:00Z' % start_time
- end_time = '%sT15:59Z' % end_time
+ start_time = datetime.datetime.strptime(
+ start_time, "%Y-%m-%d"
+ ).date() - datetime.timedelta(days=1)
+ start_time = "%sT16:00Z" % start_time
+ end_time = "%sT15:59Z" % end_time
# 通过实例名称获取关联的rds实例id
instance_info = AliyunRdsConfig.objects.get(instance__instance_name=instance_name)
# 调用aliyun接口获取SQL慢日志统计
- slowsql = Aliyun(rds=instance_info).DescribeSlowLogRecords(start_time, end_time, **values)
+ slowsql = Aliyun(rds=instance_info).DescribeSlowLogRecords(
+ start_time, end_time, **values
+ )
# 格式化时间\过滤HostAddress
- sql_slow_record = json.loads(slowsql)['Items']['SQLSlowRecord']
+ sql_slow_record = json.loads(slowsql)["Items"]["SQLSlowRecord"]
for SlowRecord in sql_slow_record:
- SlowRecord['ExecutionStartTime'] = Aliyun.utc2local(SlowRecord['ExecutionStartTime'],
- utc_format='%Y-%m-%dT%H:%M:%SZ')
- SlowRecord['HostAddress'] = SlowRecord['HostAddress'].split('[')[0]
-
- result = {"total": json.loads(slowsql)['TotalRecordCount'], "rows": sql_slow_record,
- "PageSize": json.loads(slowsql)['PageRecordCount'], "PageNumber": json.loads(slowsql)['PageNumber']}
+ SlowRecord["ExecutionStartTime"] = Aliyun.utc2local(
+ SlowRecord["ExecutionStartTime"], utc_format="%Y-%m-%dT%H:%M:%SZ"
+ )
+ SlowRecord["HostAddress"] = SlowRecord["HostAddress"].split("[")[0]
+
+ result = {
+ "total": json.loads(slowsql)["TotalRecordCount"],
+ "rows": sql_slow_record,
+ "PageSize": json.loads(slowsql)["PageRecordCount"],
+ "PageNumber": json.loads(slowsql)["PageNumber"],
+ }
# 返回查询结果
return result
@@ -89,23 +104,24 @@ def slowquery_review_history(request):
# 问题诊断--进程列表
def process_status(request):
- instance_name = request.POST.get('instance_name')
- command_type = request.POST.get('command_type')
+ instance_name = request.POST.get("instance_name")
+ command_type = request.POST.get("command_type")
- if command_type is None or command_type == '':
- command_type = 'Query'
+ if command_type is None or command_type == "":
+ command_type = "Query"
# 通过实例名称获取关联的rds实例id
instance_info = AliyunRdsConfig.objects.get(instance__instance_name=instance_name)
# 调用aliyun接口获取进程数据
process_info = Aliyun(rds=instance_info).RequestServiceOfCloudDBA(
- 'ShowProcessList', {"Language": "zh", "Command": command_type})
+ "ShowProcessList", {"Language": "zh", "Command": command_type}
+ )
# 提取进程列表
- process_list = json.loads(process_info)['AttrData']
- process_list = json.loads(process_list)['ProcessList']
+ process_list = json.loads(process_info)["AttrData"]
+ process_list = json.loads(process_list)["ProcessList"]
- result = {'status': 0, 'msg': 'ok', 'rows': process_list}
+ result = {"status": 0, "msg": "ok", "rows": process_list}
# 返回查询结果
return result
@@ -113,20 +129,22 @@ def process_status(request):
# 问题诊断--通过进程id构建请求id
def create_kill_session(request):
- instance_name = request.POST.get('instance_name')
- thread_ids = request.POST.get('ThreadIDs')
+ instance_name = request.POST.get("instance_name")
+ thread_ids = request.POST.get("ThreadIDs")
- result = {'status': 0, 'msg': 'ok', 'data': []}
+ result = {"status": 0, "msg": "ok", "data": []}
# 通过实例名称获取关联的rds实例id
instance_info = AliyunRdsConfig.objects.get(instance__instance_name=instance_name)
# 调用aliyun接口获取进程数据
request_info = Aliyun(rds=instance_info).RequestServiceOfCloudDBA(
- 'CreateKillSessionRequest', {"Language": "zh", "ThreadIDs": json.loads(thread_ids)})
+ "CreateKillSessionRequest",
+ {"Language": "zh", "ThreadIDs": json.loads(thread_ids)},
+ )
# 提取进程列表
- request_list = json.loads(request_info)['AttrData']
+ request_list = json.loads(request_info)["AttrData"]
- result['data'] = request_list
+ result["data"] = request_list
# 返回处理结果
return result
@@ -134,21 +152,23 @@ def create_kill_session(request):
# 问题诊断--终止会话
def kill_session(request):
- instance_name = request.POST.get('instance_name')
- request_params = request.POST.get('request_params')
+ instance_name = request.POST.get("instance_name")
+ request_params = request.POST.get("request_params")
- result = {'status': 0, 'msg': 'ok', 'data': []}
+ result = {"status": 0, "msg": "ok", "data": []}
# 通过实例名称获取关联的rds实例id
instance_info = AliyunRdsConfig.objects.get(instance__instance_name=instance_name)
# 调用aliyun接口获取终止进程
request_params = json.loads(request_params)
service_request_param = dict({"Language": "zh"}, **request_params)
- kill_result = Aliyun(rds=instance_info).RequestServiceOfCloudDBA('ConfirmKillSessionRequest', service_request_param)
+ kill_result = Aliyun(rds=instance_info).RequestServiceOfCloudDBA(
+ "ConfirmKillSessionRequest", service_request_param
+ )
# 获取处理结果
- kill_result = json.loads(kill_result)['AttrData']
+ kill_result = json.loads(kill_result)["AttrData"]
- result['data'] = kill_result
+ result["data"] = kill_result
# 返回查询结果
return result
@@ -156,22 +176,23 @@ def kill_session(request):
# 问题诊断--空间列表
def sapce_status(request):
- instance_name = request.POST.get('instance_name')
+ instance_name = request.POST.get("instance_name")
# 通过实例名称获取关联的rds实例id
instance_info = AliyunRdsConfig.objects.get(instance__instance_name=instance_name)
# 调用aliyun接口获取进程数据
space_info = Aliyun(rds=instance_info).RequestServiceOfCloudDBA(
- 'GetSpaceStatForTables', {"Language": "zh", "OrderType": "Data"})
+ "GetSpaceStatForTables", {"Language": "zh", "OrderType": "Data"}
+ )
# 提取进程列表
- space_list = json.loads(space_info)['ListData']
+ space_list = json.loads(space_info)["ListData"]
if space_list:
space_list = json.loads(space_list)
else:
space_list = []
- result = {'status': 0, 'msg': 'ok', 'rows': space_list}
+ result = {"status": 0, "msg": "ok", "rows": space_list}
# 返回查询结果
return result
diff --git a/sql/archiver.py b/sql/archiver.py
index 063aed1eaf..9189149e57 100644
--- a/sql/archiver.py
+++ b/sql/archiver.py
@@ -32,11 +32,11 @@
from sql.models import ArchiveConfig, ArchiveLog, Instance, ResourceGroup
from sql.utils.workflow_audit import Audit
-logger = logging.getLogger('default')
-__author__ = 'hhyo'
+logger = logging.getLogger("default")
+__author__ = "hhyo"
-@permission_required('sql.menu_archive', raise_exception=True)
+@permission_required("sql.menu_archive", raise_exception=True)
def archive_list(request):
"""
获取归档申请列表
@@ -44,47 +44,62 @@ def archive_list(request):
:return:
"""
user = request.user
- filter_instance_id = request.GET.get('filter_instance_id')
- state = request.GET.get('state')
- limit = int(request.GET.get('limit', 0))
- offset = int(request.GET.get('offset', 0))
+ filter_instance_id = request.GET.get("filter_instance_id")
+ state = request.GET.get("state")
+ limit = int(request.GET.get("limit", 0))
+ offset = int(request.GET.get("offset", 0))
limit = offset + limit
- search = request.GET.get('search', '')
+ search = request.GET.get("search", "")
# 组合筛选项
filter_dict = dict()
if filter_instance_id:
- filter_dict['src_instance'] = filter_instance_id
- if state == 'true':
- filter_dict['state'] = True
- elif state == 'false':
- filter_dict['state'] = False
+ filter_dict["src_instance"] = filter_instance_id
+ if state == "true":
+ filter_dict["state"] = True
+ elif state == "false":
+ filter_dict["state"] = False
# 管理员可以看到全部数据
if user.is_superuser:
pass
# 拥有审核权限、可以查看组内所有工单
- elif user.has_perm('sql.archive_review'):
+ elif user.has_perm("sql.archive_review"):
# 先获取用户所在资源组列表
group_list = user_groups(user)
group_ids = [group.group_id for group in group_list]
- filter_dict['resource_group__in'] = group_ids
+ filter_dict["resource_group__in"] = group_ids
# 其他人只能看到自己提交的工单
else:
- filter_dict['user_name'] = user.username
+ filter_dict["user_name"] = user.username
# 过滤组合筛选项
archive_config = ArchiveConfig.objects.filter(**filter_dict)
# 过滤搜索项,支持模糊搜索标题、用户
if search:
- archive_config = archive_config.filter(Q(title__icontains=search) | Q(user_display__icontains=search))
+ archive_config = archive_config.filter(
+ Q(title__icontains=search) | Q(user_display__icontains=search)
+ )
count = archive_config.count()
- lists = archive_config.order_by('-id')[offset:limit].values(
- 'id', 'title', 'src_instance__instance_name', 'src_db_name', 'src_table_name',
- 'dest_instance__instance_name', 'dest_db_name', 'dest_table_name', 'sleep',
- 'mode', 'no_delete', 'status', 'state', 'user_display', 'create_time', 'resource_group__group_name'
+ lists = archive_config.order_by("-id")[offset:limit].values(
+ "id",
+ "title",
+ "src_instance__instance_name",
+ "src_db_name",
+ "src_table_name",
+ "dest_instance__instance_name",
+ "dest_db_name",
+ "dest_table_name",
+ "sleep",
+ "mode",
+ "no_delete",
+ "status",
+ "state",
+ "user_display",
+ "create_time",
+ "resource_group__group_name",
)
# QuerySet 序列化
@@ -92,55 +107,75 @@ def archive_list(request):
result = {"total": count, "rows": rows}
# 返回查询结果
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
-@permission_required('sql.archive_apply', raise_exception=True)
+@permission_required("sql.archive_apply", raise_exception=True)
def archive_apply(request):
"""申请归档实例数据"""
user = request.user
- title = request.POST.get('title')
- group_name = request.POST.get('group_name')
- src_instance_name = request.POST.get('src_instance_name')
- src_db_name = request.POST.get('src_db_name')
- src_table_name = request.POST.get('src_table_name')
- mode = request.POST.get('mode')
- dest_instance_name = request.POST.get('dest_instance_name')
- dest_db_name = request.POST.get('dest_db_name')
- dest_table_name = request.POST.get('dest_table_name')
- condition = request.POST.get('condition')
- no_delete = True if request.POST.get('no_delete') == 'true' else False
- sleep = request.POST.get('sleep') or 0
- result = {'status': 0, 'msg': 'ok', 'data': {}}
+ title = request.POST.get("title")
+ group_name = request.POST.get("group_name")
+ src_instance_name = request.POST.get("src_instance_name")
+ src_db_name = request.POST.get("src_db_name")
+ src_table_name = request.POST.get("src_table_name")
+ mode = request.POST.get("mode")
+ dest_instance_name = request.POST.get("dest_instance_name")
+ dest_db_name = request.POST.get("dest_db_name")
+ dest_table_name = request.POST.get("dest_table_name")
+ condition = request.POST.get("condition")
+ no_delete = True if request.POST.get("no_delete") == "true" else False
+ sleep = request.POST.get("sleep") or 0
+ result = {"status": 0, "msg": "ok", "data": {}}
# 参数校验
- if not all(
- [title, group_name, src_instance_name, src_db_name, src_table_name, mode, condition]) or no_delete is None:
- return JsonResponse({'status': 1, 'msg': '请填写完整!', 'data': {}})
- if mode == 'dest' and not all([dest_instance_name, dest_db_name, dest_table_name]):
- return JsonResponse({'status': 1, 'msg': '归档到实例时目标实例信息必选!', 'data': {}})
+ if (
+ not all(
+ [
+ title,
+ group_name,
+ src_instance_name,
+ src_db_name,
+ src_table_name,
+ mode,
+ condition,
+ ]
+ )
+ or no_delete is None
+ ):
+ return JsonResponse({"status": 1, "msg": "请填写完整!", "data": {}})
+ if mode == "dest" and not all([dest_instance_name, dest_db_name, dest_table_name]):
+ return JsonResponse({"status": 1, "msg": "归档到实例时目标实例信息必选!", "data": {}})
# 获取源实例信息
try:
- s_ins = user_instances(request.user, db_type=['mysql']).get(instance_name=src_instance_name)
+ s_ins = user_instances(request.user, db_type=["mysql"]).get(
+ instance_name=src_instance_name
+ )
except Instance.DoesNotExist:
- return JsonResponse({'status': 1, 'msg': '你所在组未关联该实例!', 'data': {}})
+ return JsonResponse({"status": 1, "msg": "你所在组未关联该实例!", "data": {}})
# 获取目标实例信息
- if mode == 'dest':
+ if mode == "dest":
try:
- d_ins = user_instances(request.user, db_type=['mysql']).get(instance_name=dest_instance_name)
+ d_ins = user_instances(request.user, db_type=["mysql"]).get(
+ instance_name=dest_instance_name
+ )
except Instance.DoesNotExist:
- return JsonResponse({'status': 1, 'msg': '你所在组未关联该实例!', 'data': {}})
+ return JsonResponse({"status": 1, "msg": "你所在组未关联该实例!", "data": {}})
else:
d_ins = None
# 获取资源组和审批信息
res_group = ResourceGroup.objects.get(group_name=group_name)
- audit_auth_groups = Audit.settings(res_group.group_id, WorkflowDict.workflow_type['archive'])
+ audit_auth_groups = Audit.settings(
+ res_group.group_id, WorkflowDict.workflow_type["archive"]
+ )
if not audit_auth_groups:
- return JsonResponse({'status': 1, 'msg': '审批流程不能为空,请先配置审批流程', 'data': {}})
+ return JsonResponse({"status": 1, "msg": "审批流程不能为空,请先配置审批流程", "data": {}})
# 使用事务保持数据一致性
try:
@@ -160,28 +195,34 @@ def archive_apply(request):
mode=mode,
no_delete=no_delete,
sleep=sleep,
- status=WorkflowDict.workflow_status['audit_wait'],
+ status=WorkflowDict.workflow_status["audit_wait"],
state=False,
user_name=user.username,
user_display=user.display,
)
archive_id = archive_info.id
# 调用工作流插入审核信息
- audit_result = Audit.add(WorkflowDict.workflow_type['archive'], archive_id)
+ audit_result = Audit.add(WorkflowDict.workflow_type["archive"], archive_id)
except Exception as msg:
logger.error(traceback.format_exc())
- result['status'] = 1
- result['msg'] = str(msg)
+ result["status"] = 1
+ result["msg"] = str(msg)
else:
result = audit_result
# 消息通知
- audit_id = Audit.detail_by_workflow_id(workflow_id=archive_id,
- workflow_type=WorkflowDict.workflow_type['archive']).audit_id
- async_task(notify_for_audit, audit_id=audit_id, timeout=60, task_name=f'archive-apply-{archive_id}')
- return HttpResponse(json.dumps(result), content_type='application/json')
-
-
-@permission_required('sql.archive_review', raise_exception=True)
+ audit_id = Audit.detail_by_workflow_id(
+ workflow_id=archive_id, workflow_type=WorkflowDict.workflow_type["archive"]
+ ).audit_id
+ async_task(
+ notify_for_audit,
+ audit_id=audit_id,
+ timeout=60,
+ task_name=f"archive-apply-{archive_id}",
+ )
+ return HttpResponse(json.dumps(result), content_type="application/json")
+
+
+@permission_required("sql.archive_review", raise_exception=True)
def archive_audit(request):
"""
审核数据归档申请
@@ -190,39 +231,51 @@ def archive_audit(request):
"""
# 获取用户信息
user = request.user
- archive_id = int(request.POST['archive_id'])
- audit_status = int(request.POST['audit_status'])
- audit_remark = request.POST.get('audit_remark')
+ archive_id = int(request.POST["archive_id"])
+ audit_status = int(request.POST["audit_status"])
+ audit_remark = request.POST.get("audit_remark")
if audit_remark is None:
- audit_remark = ''
+ audit_remark = ""
if Audit.can_review(request.user, archive_id, 3) is False:
- context = {'errMsg': '你无权操作当前工单!'}
- return render(request, 'error.html', context)
+ context = {"errMsg": "你无权操作当前工单!"}
+ return render(request, "error.html", context)
# 使用事务保持数据一致性
try:
with transaction.atomic():
- audit_id = Audit.detail_by_workflow_id(workflow_id=archive_id,
- workflow_type=WorkflowDict.workflow_type['archive']).audit_id
+ audit_id = Audit.detail_by_workflow_id(
+ workflow_id=archive_id,
+ workflow_type=WorkflowDict.workflow_type["archive"],
+ ).audit_id
# 调用工作流插入审核信息,更新业务表审核状态
- audit_status = Audit.audit(audit_id, audit_status, user.username, audit_remark)['data']['workflow_status']
- ArchiveConfig(id=archive_id,
- status=audit_status,
- state=True if audit_status == WorkflowDict.workflow_status['audit_success'] else False
- ).save(update_fields=['status', 'state'])
+ audit_status = Audit.audit(
+ audit_id, audit_status, user.username, audit_remark
+ )["data"]["workflow_status"]
+ ArchiveConfig(
+ id=archive_id,
+ status=audit_status,
+ state=True
+ if audit_status == WorkflowDict.workflow_status["audit_success"]
+ else False,
+ ).save(update_fields=["status", "state"])
except Exception as msg:
logger.error(traceback.format_exc())
- context = {'errMsg': msg}
- return render(request, 'error.html', context)
+ context = {"errMsg": msg}
+ return render(request, "error.html", context)
else:
# 消息通知
- async_task(notify_for_audit, audit_id=audit_id, audit_remark=audit_remark, timeout=60,
- task_name=f'archive-audit-{archive_id}')
+ async_task(
+ notify_for_audit,
+ audit_id=audit_id,
+ audit_remark=audit_remark,
+ timeout=60,
+ task_name=f"archive-audit-{archive_id}",
+ )
- return HttpResponseRedirect(reverse('sql:archive_detail', args=(archive_id,)))
+ return HttpResponseRedirect(reverse("sql:archive_detail", args=(archive_id,)))
def add_archive_task(archive_ids=None):
@@ -237,18 +290,25 @@ def add_archive_task(archive_ids=None):
# 没有传archive_id代表全部归档任务统一调度
if archive_ids:
archive_cnf_list = ArchiveConfig.objects.filter(
- id__in=archive_ids, state=True, status=WorkflowDict.workflow_status['audit_success'])
+ id__in=archive_ids,
+ state=True,
+ status=WorkflowDict.workflow_status["audit_success"],
+ )
else:
archive_cnf_list = ArchiveConfig.objects.filter(
- state=True, status=WorkflowDict.workflow_status['audit_success'])
+ state=True, status=WorkflowDict.workflow_status["audit_success"]
+ )
# 添加task任务
for archive_info in archive_cnf_list:
archive_id = archive_info.id
- async_task('sql.archiver.archive',
- archive_id,
- group=f'archive-{time.strftime("%Y-%m-%d %H:%M:%S ")}', timeout=-1,
- task_name=f'archive-{archive_id}')
+ async_task(
+ "sql.archiver.archive",
+ archive_id,
+ group=f'archive-{time.strftime("%Y-%m-%d %H:%M:%S ")}',
+ timeout=-1,
+ task_name=f"archive-{archive_id}",
+ )
def archive(archive_id):
@@ -269,28 +329,30 @@ def archive(archive_id):
s_engine = get_engine(s_ins)
s_db = s_engine.schema_object.databases[src_db_name]
s_tb = s_db.tables[src_table_name]
- s_charset = s_tb.options['charset'].value
+ s_charset = s_tb.options["charset"].value
if s_charset is None:
- s_charset = s_db.options['charset'].value
+ s_charset = s_db.options["charset"].value
pt_archiver = PtArchiver()
# 准备参数
- source = fr"h={s_ins.host},u={s_ins.user},p='{s_ins.password}'," \
- fr"P={s_ins.port},D={src_db_name},t={src_table_name},A={s_charset}"
+ source = (
+ rf"h={s_ins.host},u={s_ins.user},p='{s_ins.password}',"
+ rf"P={s_ins.port},D={src_db_name},t={src_table_name},A={s_charset}"
+ )
args = {
"no-version-check": True,
"source": source,
"where": condition,
"progress": 5000,
"statistics": True,
- "charset": 'utf8',
+ "charset": "utf8",
"limit": 10000,
"txn-size": 1000,
- "sleep": sleep
+ "sleep": sleep,
}
# 归档到目标实例
- if mode == 'dest':
+ if mode == "dest":
d_ins = archive_info.dest_instance
dest_db_name = archive_info.dest_db_name
dest_table_name = archive_info.dest_table_name
@@ -298,27 +360,31 @@ def archive(archive_id):
d_engine = get_engine(d_ins)
d_db = d_engine.schema_object.databases[dest_db_name]
d_tb = d_db.tables[dest_table_name]
- d_charset = d_tb.options['charset'].value
+ d_charset = d_tb.options["charset"].value
if d_charset is None:
- d_charset = d_db.options['charset'].value
+ d_charset = d_db.options["charset"].value
# dest
- dest = fr"h={d_ins.host},u={d_ins.user},p={d_ins.password},P={d_ins.port}," \
- fr"D={dest_db_name},t={dest_table_name},A={d_charset}"
- args['dest'] = dest
+ dest = (
+ rf"h={d_ins.host},u={d_ins.user},p={d_ins.password},P={d_ins.port},"
+ rf"D={dest_db_name},t={dest_table_name},A={d_charset}"
+ )
+ args["dest"] = dest
if no_delete:
- args['no-delete'] = True
- elif mode == 'file':
- output_directory = os.path.join(settings.BASE_DIR, 'downloads/archiver')
+ args["no-delete"] = True
+ elif mode == "file":
+ output_directory = os.path.join(settings.BASE_DIR, "downloads/archiver")
os.makedirs(output_directory, exist_ok=True)
- args['file'] = f'{output_directory}/{s_ins.instance_name}-{src_db_name}-{src_table_name}.txt'
+ args[
+ "file"
+ ] = f"{output_directory}/{s_ins.instance_name}-{src_db_name}-{src_table_name}.txt"
if no_delete:
- args['no-delete'] = True
- elif mode == 'purge':
- args['purge'] = True
+ args["no-delete"] = True
+ elif mode == "purge":
+ args["purge"] = True
# 参数检查
args_check_result = pt_archiver.check_args(args)
- if args_check_result['status'] == 1:
+ if args_check_result["status"] == 1:
return JsonResponse(args_check_result)
# 参数转换
cmd_args = pt_archiver.generate_args2cmd(args, shell=True)
@@ -328,15 +394,15 @@ def archive(archive_id):
delete_cnt = 0
with FuncTimer() as t:
p = pt_archiver.execute_cmd(cmd_args, shell=True)
- stdout = ''
- for line in iter(p.stdout.readline, ''):
- if re.match(r'^SELECT\s(\d+)$', line, re.I):
- select_cnt = re.findall(r'^SELECT\s(\d+)$', line)
- elif re.match(r'^INSERT\s(\d+)$', line, re.I):
- insert_cnt = re.findall(r'^INSERT\s(\d+)$', line)
- elif re.match(r'^DELETE\s(\d+)$', line, re.I):
- delete_cnt = re.findall(r'^DELETE\s(\d+)$', line)
- stdout += f'{line}\n'
+ stdout = ""
+ for line in iter(p.stdout.readline, ""):
+ if re.match(r"^SELECT\s(\d+)$", line, re.I):
+ select_cnt = re.findall(r"^SELECT\s(\d+)$", line)
+ elif re.match(r"^INSERT\s(\d+)$", line, re.I):
+ insert_cnt = re.findall(r"^INSERT\s(\d+)$", line)
+ elif re.match(r"^DELETE\s(\d+)$", line, re.I):
+ delete_cnt = re.findall(r"^DELETE\s(\d+)$", line)
+ stdout += f"{line}\n"
statistics = stdout
# 获取异常信息
stderr = p.stderr.read()
@@ -347,22 +413,22 @@ def archive(archive_id):
select_cnt = int(select_cnt[0]) if select_cnt else 0
insert_cnt = int(insert_cnt[0]) if insert_cnt else 0
delete_cnt = int(delete_cnt[0]) if delete_cnt else 0
- error_info = ''
+ error_info = ""
success = True
if stderr:
- error_info = f'命令执行报错:{stderr}'
+ error_info = f"命令执行报错:{stderr}"
success = False
- if mode == 'dest':
+ if mode == "dest":
# 删除源数据,判断删除数量和写入数量
if not no_delete and (insert_cnt != delete_cnt):
error_info = f"删除和写入数量不一致:{insert_cnt}!={delete_cnt}"
success = False
- elif mode == 'file':
+ elif mode == "file":
# 删除源数据,判断查询数量和删除数量
if not no_delete and (select_cnt != delete_cnt):
error_info = f"查询和删除数量不一致:{select_cnt}!={delete_cnt}"
success = False
- elif mode == 'purge':
+ elif mode == "purge":
# 直接删除。判断查询数量和删除数量
if select_cnt != delete_cnt:
error_info = f"查询和删除数量不一致:{select_cnt}!={delete_cnt}"
@@ -372,12 +438,15 @@ def archive(archive_id):
if connection.connection and not connection.is_usable():
close_old_connections()
# 更新最后归档时间
- ArchiveConfig(id=archive_id, last_archive_time=t.end).save(update_fields=['last_archive_time'])
+ ArchiveConfig(id=archive_id, last_archive_time=t.end).save(
+ update_fields=["last_archive_time"]
+ )
# 替换密码信息后保存
ArchiveLog.objects.create(
archive=archive_info,
- cmd=cmd_args.replace(s_ins.password, '***').replace(
- d_ins.password, '***') if mode == 'dest' else cmd_args.replace(s_ins.password, '***'),
+ cmd=cmd_args.replace(s_ins.password, "***").replace(d_ins.password, "***")
+ if mode == "dest"
+ else cmd_args.replace(s_ins.password, "***"),
condition=condition,
mode=mode,
no_delete=no_delete,
@@ -389,51 +458,69 @@ def archive(archive_id):
success=success,
error_info=error_info,
start_time=t.start,
- end_time=t.end
+ end_time=t.end,
)
if not success:
- raise Exception(f'{error_info}\n{statistics}')
+ raise Exception(f"{error_info}\n{statistics}")
-@permission_required('sql.menu_archive', raise_exception=True)
+@permission_required("sql.menu_archive", raise_exception=True)
def archive_log(request):
"""获取归档日志列表"""
- limit = int(request.GET.get('limit', 0))
- offset = int(request.GET.get('offset', 0))
+ limit = int(request.GET.get("limit", 0))
+ offset = int(request.GET.get("offset", 0))
limit = offset + limit
- archive_id = request.GET.get('archive_id')
+ archive_id = request.GET.get("archive_id")
archive_logs = ArchiveLog.objects.filter(archive=archive_id).annotate(
- info=Concat('cmd', V('\n'), 'statistics', output_field=TextField()))
+ info=Concat("cmd", V("\n"), "statistics", output_field=TextField())
+ )
count = archive_logs.count()
- lists = archive_logs.order_by('-id')[offset:limit].values(
- 'cmd', 'info', 'condition', 'mode', 'no_delete', 'select_cnt',
- 'insert_cnt', 'delete_cnt', 'success', 'error_info', 'start_time', 'end_time'
+ lists = archive_logs.order_by("-id")[offset:limit].values(
+ "cmd",
+ "info",
+ "condition",
+ "mode",
+ "no_delete",
+ "select_cnt",
+ "insert_cnt",
+ "delete_cnt",
+ "success",
+ "error_info",
+ "start_time",
+ "end_time",
)
# QuerySet 序列化
rows = [row for row in lists]
result = {"total": count, "rows": rows}
# 返回查询结果
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
-@permission_required('sql.archive_mgt', raise_exception=True)
+@permission_required("sql.archive_mgt", raise_exception=True)
def archive_switch(request):
"""开启关闭归档任务"""
- archive_id = request.POST.get('archive_id')
- state = True if request.POST.get('state') == 'true' else False
+ archive_id = request.POST.get("archive_id")
+ state = True if request.POST.get("state") == "true" else False
# 更新启用状态
try:
- ArchiveConfig(id=archive_id, state=state).save(update_fields=['state'])
- return JsonResponse({'status': 0, 'msg': 'ok', 'data': {}})
+ ArchiveConfig(id=archive_id, state=state).save(update_fields=["state"])
+ return JsonResponse({"status": 0, "msg": "ok", "data": {}})
except Exception as msg:
- return JsonResponse({'status': 1, 'msg': f'{msg}', 'data': {}})
+ return JsonResponse({"status": 1, "msg": f"{msg}", "data": {}})
-@permission_required('sql.archive_mgt', raise_exception=True)
+@permission_required("sql.archive_mgt", raise_exception=True)
def archive_once(request):
"""单次立即调用归档任务"""
- archive_id = request.GET.get('archive_id')
- async_task('sql.archiver.archive', archive_id, timeout=-1, task_name=f'archive-{archive_id}')
- return JsonResponse({'status': 0, 'msg': 'ok', 'data': {}})
+ archive_id = request.GET.get("archive_id")
+ async_task(
+ "sql.archiver.archive",
+ archive_id,
+ timeout=-1,
+ task_name=f"archive-{archive_id}",
+ )
+ return JsonResponse({"status": 0, "msg": "ok", "data": {}})
diff --git a/sql/audit_log.py b/sql/audit_log.py
index daaccf217d..f4e29ec326 100644
--- a/sql/audit_log.py
+++ b/sql/audit_log.py
@@ -7,77 +7,93 @@
from django.utils import timezone
from django.http import HttpResponse
from django.db.models import Q
-from django.contrib.auth.decorators import login_required,permission_required
-from django.contrib.auth.signals import user_logged_in, user_logged_out, user_login_failed
+from django.contrib.auth.decorators import login_required, permission_required
+from django.contrib.auth.signals import (
+ user_logged_in,
+ user_logged_out,
+ user_login_failed,
+)
from .models import AuditEntry, Users
from common.utils.permission import superuser_required
from common.utils.extend_json_encoder import ExtendJSONEncoder
-log = logging.getLogger('default')
+log = logging.getLogger("default")
@login_required
def audit_input(request):
"""用户提交的操作信息"""
result = {}
- action = request.POST.get('action')
- extra_info = request.POST.get('extra_info','')
+ action = request.POST.get("action")
+ extra_info = request.POST.get("extra_info", "")
- result['user_id'] = request.user.id
- result['user_name'] = request.user.username
- result['user_display'] = request.user.display
- result['action'] = action
- result['extra_info'] = extra_info
+ result["user_id"] = request.user.id
+ result["user_name"] = request.user.username
+ result["user_display"] = request.user.display
+ result["action"] = action
+ result["extra_info"] = extra_info
audit = AuditEntry(**result)
audit.save()
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
-@permission_required('sql.audit_user', raise_exception=True)
+@permission_required("sql.audit_user", raise_exception=True)
def audit_log(request):
"""获取审计日志列表"""
- limit = int(request.POST.get('limit',0))
- offset = int(request.POST.get('offset',0))
+ limit = int(request.POST.get("limit", 0))
+ offset = int(request.POST.get("offset", 0))
limit = offset + limit
limit = limit if limit else None
- search = request.POST.get('search', '')
- action = request.POST.get('action','')
- start_date = request.POST.get('start_date')
- end_date = request.POST.get('end_date')
+ search = request.POST.get("search", "")
+ action = request.POST.get("action", "")
+ start_date = request.POST.get("start_date")
+ end_date = request.POST.get("end_date")
filter_dict = dict()
if start_date and end_date:
- end_date = datetime.datetime.strptime(end_date, '%Y-%m-%d') + datetime.timedelta(days=1)
- filter_dict['action_time__range'] = (start_date, end_date)
+ end_date = datetime.datetime.strptime(
+ end_date, "%Y-%m-%d"
+ ) + datetime.timedelta(days=1)
+ filter_dict["action_time__range"] = (start_date, end_date)
if action:
- filter_dict['action'] = action
-
+ filter_dict["action"] = action
+
audit_log_obj = AuditEntry.objects.filter(**filter_dict)
if search:
- audit_log_obj = audit_log_obj.filter(Q(user_name__icontains=search) | Q(action__icontains=search)| Q(extra_info__icontains=search))
+ audit_log_obj = audit_log_obj.filter(
+ Q(user_name__icontains=search)
+ | Q(action__icontains=search)
+ | Q(extra_info__icontains=search)
+ )
audit_log_count = audit_log_obj.count()
- audit_log_list = audit_log_obj.order_by('-action_time')[offset:limit].values('user_id', 'user_name', 'user_display', 'action', 'extra_info', 'action_time')
+ audit_log_list = audit_log_obj.order_by("-action_time")[offset:limit].values(
+ "user_id", "user_name", "user_display", "action", "extra_info", "action_time"
+ )
# QuerySet 序列化
rows = [row for row in audit_log_list]
result = {"total": audit_log_count, "rows": rows}
# 返回查询结果
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
def get_client_ip(request):
- x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
+ x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR")
if x_forwarded_for:
- ip = x_forwarded_for.split(',')[0]
+ ip = x_forwarded_for.split(",")[0]
else:
- ip = request.META.get('REMOTE_ADDR')
+ ip = request.META.get("REMOTE_ADDR")
return ip
@@ -85,26 +101,45 @@ def get_client_ip(request):
def user_logged_in_callback(sender, request, user, **kwargs):
ip = get_client_ip(request)
now = timezone.now()
- AuditEntry.objects.create(action=u'登入', extra_info=ip, user_id=user.id, user_name=user.username, user_display=user.display, action_time=now)
+ AuditEntry.objects.create(
+ action="登入",
+ extra_info=ip,
+ user_id=user.id,
+ user_name=user.username,
+ user_display=user.display,
+ action_time=now,
+ )
@receiver(user_logged_out)
def user_logged_out_callback(sender, request, user, **kwargs):
ip = get_client_ip(request)
now = timezone.now()
- AuditEntry.objects.create(action=u'登出', extra_info=ip, user_id=user.id, user_name=user.username, user_display=user.display, action_time=now)
+ AuditEntry.objects.create(
+ action="登出",
+ extra_info=ip,
+ user_id=user.id,
+ user_name=user.username,
+ user_display=user.display,
+ action_time=now,
+ )
@receiver(user_login_failed)
def user_login_failed_callback(sender, credentials, **kwargs):
now = timezone.now()
- user_name = credentials.get('username', None)
+ user_name = credentials.get("username", None)
user_obj = Users.objects.filter(username=user_name)[0:1]
user_count = user_obj.count()
user_id = 0
- user_display = ''
+ user_display = ""
if user_count > 0:
user_id = user_obj[0].id
user_display = user_obj[0].display
- AuditEntry.objects.create(action=u'登入失败', user_id=user_id, user_name=user_name, user_display=user_display, action_time=now)
-
+ AuditEntry.objects.create(
+ action="登入失败",
+ user_id=user_id,
+ user_name=user_name,
+ user_display=user_display,
+ action_time=now,
+ )
diff --git a/sql/binlog.py b/sql/binlog.py
index c16ec02f87..2f1b70b33e 100644
--- a/sql/binlog.py
+++ b/sql/binlog.py
@@ -19,24 +19,24 @@
from sql.notify import notify_for_my2sql
from .models import Instance
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
-@permission_required('sql.menu_my2sql', raise_exception=True)
+@permission_required("sql.menu_my2sql", raise_exception=True)
def binlog_list(request):
"""
获取binlog列表
:param request:
:return:
"""
- instance_name = request.POST.get('instance_name')
+ instance_name = request.POST.get("instance_name")
try:
instance = Instance.objects.get(instance_name=instance_name)
except Instance.DoesNotExist:
- result = {'status': 1, 'msg': '实例不存在', 'data': []}
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result = {"status": 1, "msg": "实例不存在", "data": []}
+ return HttpResponse(json.dumps(result), content_type="application/json")
query_engine = get_engine(instance=instance)
- query_result = query_engine.query('information_schema', 'show binary logs;')
+ query_result = query_engine.query("information_schema", "show binary logs;")
if not query_result.error:
column_list = query_result.column_list
rows = []
@@ -45,103 +45,129 @@ def binlog_list(request):
for row_index, row_item in enumerate(row):
row_info[column_list[row_index]] = row_item
rows.append(row_info)
- result = {'status': 0, 'msg': 'ok', 'data': rows}
+ result = {"status": 0, "msg": "ok", "data": rows}
else:
- result = {'status': 1, 'msg': query_result.error}
+ result = {"status": 1, "msg": query_result.error}
# 返回查询结果
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
-@permission_required('sql.binlog_del', raise_exception=True)
+@permission_required("sql.binlog_del", raise_exception=True)
def del_binlog(request):
- instance_id = request.POST.get('instance_id')
- binlog = request.POST.get('binlog', '')
+ instance_id = request.POST.get("instance_id")
+ binlog = request.POST.get("binlog", "")
try:
instance = Instance.objects.get(id=instance_id)
except Instance.DoesNotExist:
- result = {'status': 1, 'msg': '实例不存在', 'data': []}
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result = {"status": 1, "msg": "实例不存在", "data": []}
+ return HttpResponse(json.dumps(result), content_type="application/json")
# escape
- binlog = MySQLdb.escape_string(binlog).decode('utf-8')
+ binlog = MySQLdb.escape_string(binlog).decode("utf-8")
if binlog:
query_engine = get_engine(instance=instance)
- query_result = query_engine.query(sql=fr"purge master logs to '{binlog}';")
+ query_result = query_engine.query(sql=rf"purge master logs to '{binlog}';")
if query_result.error is None:
- result = {'status': 0, 'msg': '清理成功', 'data': ''}
+ result = {"status": 0, "msg": "清理成功", "data": ""}
else:
- result = {'status': 2, 'msg': f'清理失败,Error:{query_result.error}', 'data': ''}
+ result = {
+ "status": 2,
+ "msg": f"清理失败,Error:{query_result.error}",
+ "data": "",
+ }
else:
- result = {'status': 1, 'msg': 'Error:未选择binlog!', 'data': ''}
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ result = {"status": 1, "msg": "Error:未选择binlog!", "data": ""}
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
-@permission_required('sql.menu_my2sql', raise_exception=True)
+@permission_required("sql.menu_my2sql", raise_exception=True)
def my2sql(request):
"""
通过解析binlog获取SQL--使用my2sql
:param request:
:return:
"""
- instance_name = request.POST.get('instance_name')
- save_sql = True if request.POST.get('save_sql') == 'true' else False
+ instance_name = request.POST.get("instance_name")
+ save_sql = True if request.POST.get("save_sql") == "true" else False
instance = Instance.objects.get(instance_name=instance_name)
- work_type = 'rollback' if request.POST.get('rollback') == 'true' else '2sql'
- num = 30 if request.POST.get('num') == '' else int(request.POST.get('num'))
- threads = 4 if request.POST.get('threads') == '' else int(request.POST.get('threads'))
- start_file = request.POST.get('start_file')
- start_pos = request.POST.get('start_pos') if request.POST.get('start_pos') == '' else int(
- request.POST.get('start_pos'))
- end_file = request.POST.get('end_file')
- end_pos = request.POST.get('end_pos') if request.POST.get('end_pos') == '' else int(request.POST.get('end_pos'))
- stop_time = request.POST.get('stop_time')
- start_time = request.POST.get('start_time')
- only_schemas = request.POST.getlist('only_schemas')
- only_tables = request.POST.getlist('only_tables[]')
- sql_type = [] if request.POST.getlist('sql_type[]') == [] else request.POST.getlist('sql_type[]')
- extra_info = True if request.POST.get('extra_info') == 'true' else False
- ignore_primary_key = True if request.POST.get('ignore_primary_key') == 'true' else False
- full_columns = True if request.POST.get('full_columns') == 'true' else False
- no_db_prefix = True if request.POST.get('no_db_prefix') == 'true' else False
- file_per_table = True if request.POST.get('file_per_table') == 'true' else False
-
- result = {'status': 0, 'msg': 'ok', 'data': []}
+ work_type = "rollback" if request.POST.get("rollback") == "true" else "2sql"
+ num = 30 if request.POST.get("num") == "" else int(request.POST.get("num"))
+ threads = (
+ 4 if request.POST.get("threads") == "" else int(request.POST.get("threads"))
+ )
+ start_file = request.POST.get("start_file")
+ start_pos = (
+ request.POST.get("start_pos")
+ if request.POST.get("start_pos") == ""
+ else int(request.POST.get("start_pos"))
+ )
+ end_file = request.POST.get("end_file")
+ end_pos = (
+ request.POST.get("end_pos")
+ if request.POST.get("end_pos") == ""
+ else int(request.POST.get("end_pos"))
+ )
+ stop_time = request.POST.get("stop_time")
+ start_time = request.POST.get("start_time")
+ only_schemas = request.POST.getlist("only_schemas")
+ only_tables = request.POST.getlist("only_tables[]")
+ sql_type = (
+ []
+ if request.POST.getlist("sql_type[]") == []
+ else request.POST.getlist("sql_type[]")
+ )
+ extra_info = True if request.POST.get("extra_info") == "true" else False
+ ignore_primary_key = (
+ True if request.POST.get("ignore_primary_key") == "true" else False
+ )
+ full_columns = True if request.POST.get("full_columns") == "true" else False
+ no_db_prefix = True if request.POST.get("no_db_prefix") == "true" else False
+ file_per_table = True if request.POST.get("file_per_table") == "true" else False
+
+ result = {"status": 0, "msg": "ok", "data": []}
# 提交给my2sql进行解析
my2sql = My2SQL()
# 准备参数
instance_password = shlex.quote(f'"{str(instance.password)}"')
- args = {"conn_options": fr"-host {shlex.quote(str(instance.host))} -user {shlex.quote(str(instance.user))} \
+ args = {
+ "conn_options": rf"-host {shlex.quote(str(instance.host))} -user {shlex.quote(str(instance.user))} \
-password '{instance_password}' -port {shlex.quote(str(instance.port))} ",
- "work-type": work_type,
- "start-file": start_file,
- "start-pos": start_pos,
- "stop-file": end_file,
- "stop-pos": end_pos,
- "start-datetime": '"'+start_time+'"',
- "stop-datetime": '"'+stop_time+'"',
- "databases": ' '.join(only_schemas),
- "tables": ','.join(only_tables),
- "sql": ','.join(sql_type),
- "instance": instance,
- "threads": threads,
- "add-extraInfo": extra_info,
- "ignore-primaryKey-forInsert": ignore_primary_key,
- "full-columns": full_columns,
- "do-not-add-prifixDb": no_db_prefix,
- "file-per-table": file_per_table,
- "output-toScreen": True
- }
+ "work-type": work_type,
+ "start-file": start_file,
+ "start-pos": start_pos,
+ "stop-file": end_file,
+ "stop-pos": end_pos,
+ "start-datetime": '"' + start_time + '"',
+ "stop-datetime": '"' + stop_time + '"',
+ "databases": " ".join(only_schemas),
+ "tables": ",".join(only_tables),
+ "sql": ",".join(sql_type),
+ "instance": instance,
+ "threads": threads,
+ "add-extraInfo": extra_info,
+ "ignore-primaryKey-forInsert": ignore_primary_key,
+ "full-columns": full_columns,
+ "do-not-add-prifixDb": no_db_prefix,
+ "file-per-table": file_per_table,
+ "output-toScreen": True,
+ }
# 参数检查
args_check_result = my2sql.check_args(args)
- if args_check_result['status'] == 1:
- return HttpResponse(json.dumps(args_check_result), content_type='application/json')
+ if args_check_result["status"] == 1:
+ return HttpResponse(
+ json.dumps(args_check_result), content_type="application/json"
+ )
# 参数转换
cmd_args = my2sql.generate_args2cmd(args, shell=True)
@@ -151,15 +177,15 @@ def my2sql(request):
# 读取前num行后结束
rows = []
n = 1
- for line in iter(p.stdout.readline, ''):
+ for line in iter(p.stdout.readline, ""):
if n <= num and isinstance(line, str):
- if line[0:6].upper() in ('INSERT', 'DELETE', 'UPDATE'):
+ if line[0:6].upper() in ("INSERT", "DELETE", "UPDATE"):
n = n + 1
row_info = {}
try:
- row_info['sql'] = line + ';'
+ row_info["sql"] = line + ";"
except IndexError:
- row_info['sql'] = line + ';'
+ row_info["sql"] = line + ";"
rows.append(row_info)
else:
break
@@ -168,27 +194,35 @@ def my2sql(request):
# 判断是否有异常
stderr = p.stderr.read()
if stderr and isinstance(stderr, str):
- result['status'] = 1
- result['msg'] = stderr
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["status"] = 1
+ result["msg"] = stderr
+ return HttpResponse(json.dumps(result), content_type="application/json")
# 终止子进程
p.kill()
- result['data'] = rows
+ result["data"] = rows
except Exception as e:
logger.error(traceback.format_exc())
- result['status'] = 1
- result['msg'] = str(e)
+ result["status"] = 1
+ result["msg"] = str(e)
# 异步保存到文件
if save_sql:
- args.pop('conn_options')
- args.pop('output-toScreen')
- async_task(my2sql_file, args=args, user=request.user, hook=notify_for_my2sql, timeout=-1,
- task_name=f'my2sql-{time.time()}')
+ args.pop("conn_options")
+ args.pop("output-toScreen")
+ async_task(
+ my2sql_file,
+ args=args,
+ user=request.user,
+ hook=notify_for_my2sql,
+ timeout=-1,
+ task_name=f"my2sql-{time.time()}",
+ )
# 返回查询结果
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
def my2sql_file(args, user):
@@ -199,12 +233,12 @@ def my2sql_file(args, user):
:return:
"""
my2sql = My2SQL()
- instance = args.get('instance')
+ instance = args.get("instance")
instance_password = shlex.quote(f'"{str(instance.password)}"')
- conn_options = fr"-host {shlex.quote(str(instance.host))} -user {shlex.quote(str(instance.user))} \
+ conn_options = rf"-host {shlex.quote(str(instance.host))} -user {shlex.quote(str(instance.user))} \
-password '{instance_password}' -port {shlex.quote(str(instance.port))} "
- args['conn_options'] = conn_options
- path = os.path.join(settings.BASE_DIR, 'downloads/my2sql/')
+ args["conn_options"] = conn_options
+ path = os.path.join(settings.BASE_DIR, "downloads/my2sql/")
os.makedirs(path, exist_ok=True)
# 参数转换
diff --git a/sql/data_dictionary.py b/sql/data_dictionary.py
index 13aa3816e0..2ee48b8d5e 100644
--- a/sql/data_dictionary.py
+++ b/sql/data_dictionary.py
@@ -16,73 +16,91 @@
from .models import Instance
-@permission_required('sql.menu_data_dictionary', raise_exception=True)
+@permission_required("sql.menu_data_dictionary", raise_exception=True)
def table_list(request):
"""数据字典获取表列表"""
- instance_name = request.GET.get('instance_name', '')
- db_name = request.GET.get('db_name', '')
- db_type = request.GET.get('db_type', '')
+ instance_name = request.GET.get("instance_name", "")
+ db_name = request.GET.get("db_name", "")
+ db_type = request.GET.get("db_type", "")
if instance_name and db_name:
try:
- instance = Instance.objects.get(instance_name=instance_name, db_type=db_type)
+ instance = Instance.objects.get(
+ instance_name=instance_name, db_type=db_type
+ )
query_engine = get_engine(instance=instance)
data = query_engine.get_group_tables_by_db(db_name=db_name)
- res = {'status': 0, 'data': data}
+ res = {"status": 0, "data": data}
except Instance.DoesNotExist:
- res = {'status': 1, 'msg': 'Instance.DoesNotExist'}
+ res = {"status": 1, "msg": "Instance.DoesNotExist"}
except Exception as e:
- res = {'status': 1, 'msg': str(e)}
+ res = {"status": 1, "msg": str(e)}
else:
- res = {'status': 1, 'msg': '非法调用!'}
- return HttpResponse(json.dumps(res, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ res = {"status": 1, "msg": "非法调用!"}
+ return HttpResponse(
+ json.dumps(res, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
-@permission_required('sql.menu_data_dictionary', raise_exception=True)
+@permission_required("sql.menu_data_dictionary", raise_exception=True)
def table_info(request):
"""数据字典获取表信息"""
- instance_name = request.GET.get('instance_name', '')
- db_name = request.GET.get('db_name', '')
- tb_name = request.GET.get('tb_name', '')
- db_type = request.GET.get('db_type', '')
+ instance_name = request.GET.get("instance_name", "")
+ db_name = request.GET.get("db_name", "")
+ tb_name = request.GET.get("tb_name", "")
+ db_type = request.GET.get("db_type", "")
if instance_name and db_name and tb_name:
data = {}
try:
- instance = Instance.objects.get(instance_name=instance_name, db_type=db_type)
+ instance = Instance.objects.get(
+ instance_name=instance_name, db_type=db_type
+ )
query_engine = get_engine(instance=instance)
- data['meta_data'] = query_engine.get_table_meta_data(db_name=db_name, tb_name=tb_name)
- data['desc'] = query_engine.get_table_desc_data(db_name=db_name, tb_name=tb_name)
- data['index'] = query_engine.get_table_index_data(db_name=db_name, tb_name=tb_name)
+ data["meta_data"] = query_engine.get_table_meta_data(
+ db_name=db_name, tb_name=tb_name
+ )
+ data["desc"] = query_engine.get_table_desc_data(
+ db_name=db_name, tb_name=tb_name
+ )
+ data["index"] = query_engine.get_table_index_data(
+ db_name=db_name, tb_name=tb_name
+ )
# mysql数据库可以获取创建表格的SQL语句,mssql暂无找到生成创建表格的SQL语句
- if instance.db_type == 'mysql':
- _create_sql = query_engine.query(db_name, "show create table `%s`;" % tb_name)
- data['create_sql'] = _create_sql.rows
- res = {'status': 0, 'data': data}
+ if instance.db_type == "mysql":
+ _create_sql = query_engine.query(
+ db_name, "show create table `%s`;" % tb_name
+ )
+ data["create_sql"] = _create_sql.rows
+ res = {"status": 0, "data": data}
except Instance.DoesNotExist:
- res = {'status': 1, 'msg': 'Instance.DoesNotExist'}
+ res = {"status": 1, "msg": "Instance.DoesNotExist"}
except Exception as e:
- res = {'status': 1, 'msg': str(e)}
+ res = {"status": 1, "msg": str(e)}
else:
- res = {'status': 1, 'msg': '非法调用!'}
- return HttpResponse(json.dumps(res, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ res = {"status": 1, "msg": "非法调用!"}
+ return HttpResponse(
+ json.dumps(res, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
-@permission_required('sql.data_dictionary_export', raise_exception=True)
+@permission_required("sql.data_dictionary_export", raise_exception=True)
def export(request):
"""导出数据字典"""
- instance_name = request.GET.get('instance_name', '')
- db_name = request.GET.get('db_name', '')
+ instance_name = request.GET.get("instance_name", "")
+ db_name = request.GET.get("db_name", "")
# escape
- db_name = MySQLdb.escape_string(db_name).decode('utf-8')
+ db_name = MySQLdb.escape_string(db_name).decode("utf-8")
try:
- instance = user_instances(request.user, db_type=['mysql', 'mssql', 'oracle']).get(instance_name=instance_name)
+ instance = user_instances(
+ request.user, db_type=["mysql", "mssql", "oracle"]
+ ).get(instance_name=instance_name)
query_engine = get_engine(instance=instance)
except Instance.DoesNotExist:
- return JsonResponse({'status': 1, 'msg': '你所在组未关联该实例!', 'data': []})
+ return JsonResponse({"status": 1, "msg": "你所在组未关联该实例!", "data": []})
# 普通用户仅可以获取指定数据库的字典信息
if db_name:
@@ -91,24 +109,38 @@ def export(request):
elif request.user.is_superuser:
dbs = query_engine.get_all_databases().rows
else:
- return JsonResponse({'status': 1, 'msg': f'仅管理员可以导出整个实例的字典信息!', 'data': []})
+ return JsonResponse({"status": 1, "msg": f"仅管理员可以导出整个实例的字典信息!", "data": []})
# 获取数据,存入目录
- path = os.path.join(settings.BASE_DIR, 'downloads/dictionary')
+ path = os.path.join(settings.BASE_DIR, "downloads/dictionary")
os.makedirs(path, exist_ok=True)
for db in dbs:
table_metas = query_engine.get_tables_metas_data(db_name=db)
- context = {"db_name": db_name, "tables": table_metas, "export_time": datetime.datetime.now()}
- data = loader.render_to_string(template_name="dictionaryexport.html", context=context, request=request)
- with open(f'{path}/{instance_name}_{db}.html', 'w') as f:
+ context = {
+ "db_name": db_name,
+ "tables": table_metas,
+ "export_time": datetime.datetime.now(),
+ }
+ data = loader.render_to_string(
+ template_name="dictionaryexport.html", context=context, request=request
+ )
+ with open(f"{path}/{instance_name}_{db}.html", "w") as f:
f.write(data)
# 关闭连接
query_engine.close()
if db_name:
- response = FileResponse(open(f'{path}/{instance_name}_{db_name}.html', 'rb'))
- response['Content-Type'] = 'application/octet-stream'
- response['Content-Disposition'] = f'attachment;filename="{quote(instance_name)}_{quote(db_name)}.html"'
+ response = FileResponse(open(f"{path}/{instance_name}_{db_name}.html", "rb"))
+ response["Content-Type"] = "application/octet-stream"
+ response[
+ "Content-Disposition"
+ ] = f'attachment;filename="{quote(instance_name)}_{quote(db_name)}.html"'
return response
else:
- return JsonResponse({'status': 0, 'msg': f'实例{instance_name}数据字典导出成功,请到downloads目录下载!', 'data': []})
+ return JsonResponse(
+ {
+ "status": 0,
+ "msg": f"实例{instance_name}数据字典导出成功,请到downloads目录下载!",
+ "data": [],
+ }
+ )
diff --git a/sql/db_diagnostic.py b/sql/db_diagnostic.py
index 8f57fde337..be8fdfc12b 100644
--- a/sql/db_diagnostic.py
+++ b/sql/db_diagnostic.py
@@ -1,7 +1,8 @@
import logging
import traceback
import MySQLdb
-#import simplejson as json
+
+# import simplejson as json
import json
from django.contrib.auth.decorators import permission_required
@@ -12,206 +13,245 @@
from sql.utils.resource_group import user_instances
from .models import AliyunRdsConfig, Instance
-from .aliyun_rds import process_status as aliyun_process_status, create_kill_session as aliyun_create_kill_session, \
- kill_session as aliyun_kill_session, sapce_status as aliyun_sapce_status
+from .aliyun_rds import (
+ process_status as aliyun_process_status,
+ create_kill_session as aliyun_create_kill_session,
+ kill_session as aliyun_kill_session,
+ sapce_status as aliyun_sapce_status,
+)
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
# 问题诊断--进程列表
-@permission_required('sql.process_view', raise_exception=True)
+@permission_required("sql.process_view", raise_exception=True)
def process(request):
- instance_name = request.POST.get('instance_name')
- command_type = request.POST.get('command_type')
+ instance_name = request.POST.get("instance_name")
+ command_type = request.POST.get("command_type")
try:
instance = user_instances(request.user).get(instance_name=instance_name)
except Instance.DoesNotExist:
- result = {'status': 1, 'msg': '你所在组未关联该实例', 'data': []}
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result = {"status": 1, "msg": "你所在组未关联该实例", "data": []}
+ return HttpResponse(json.dumps(result), content_type="application/json")
query_engine = get_engine(instance=instance)
query_result = None
- if instance.db_type == 'mysql':
+ if instance.db_type == "mysql":
# 判断是RDS还是其他实例
if AliyunRdsConfig.objects.filter(instance=instance, is_enable=True).exists():
result = aliyun_process_status(request)
else:
query_result = query_engine.processlist(command_type)
- elif instance.db_type == 'mongo':
- query_result = query_engine.current_op(command_type)
+ elif instance.db_type == "mongo":
+ query_result = query_engine.current_op(command_type)
else:
- result = {'status': 1, 'msg': '暂时不支持{}类型数据库的进程列表查询'.format(instance.db_type), 'data': []}
- return HttpResponse(json.dumps(result), content_type='application/json')
-
- if query_result:
+ result = {
+ "status": 1,
+ "msg": "暂时不支持{}类型数据库的进程列表查询".format(instance.db_type),
+ "data": [],
+ }
+ return HttpResponse(json.dumps(result), content_type="application/json")
+
+ if query_result:
if not query_result.error:
processlist = query_result.to_dict()
- result = {'status': 0, 'msg': 'ok', 'rows': processlist}
+ result = {"status": 0, "msg": "ok", "rows": processlist}
else:
- result = {'status': 1, 'msg': query_result.error}
-
+ result = {"status": 1, "msg": query_result.error}
+
# 返回查询结果
# ExtendJSONEncoderBytes 使用json模块,bigint_as_string只支持simplejson
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoderBytes),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoderBytes), content_type="application/json"
+ )
# 问题诊断--通过线程id构建请求 这里只是用于确定将要kill的线程id还在运行
-@permission_required('sql.process_kill', raise_exception=True)
+@permission_required("sql.process_kill", raise_exception=True)
def create_kill_session(request):
- instance_name = request.POST.get('instance_name')
- thread_ids = request.POST.get('ThreadIDs')
+ instance_name = request.POST.get("instance_name")
+ thread_ids = request.POST.get("ThreadIDs")
try:
instance = user_instances(request.user).get(instance_name=instance_name)
except Instance.DoesNotExist:
- result = {'status': 1, 'msg': '你所在组未关联该实例', 'data': []}
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result = {"status": 1, "msg": "你所在组未关联该实例", "data": []}
+ return HttpResponse(json.dumps(result), content_type="application/json")
- result = {'status': 0, 'msg': 'ok', 'data': []}
+ result = {"status": 0, "msg": "ok", "data": []}
query_engine = get_engine(instance=instance)
- if instance.db_type == 'mysql':
+ if instance.db_type == "mysql":
# 判断是RDS还是其他实例
if AliyunRdsConfig.objects.filter(instance=instance, is_enable=True).exists():
result = aliyun_create_kill_session(request)
else:
- result['data'] = query_engine.get_kill_command(json.loads(thread_ids))
- elif instance.db_type == 'mongo':
+ result["data"] = query_engine.get_kill_command(json.loads(thread_ids))
+ elif instance.db_type == "mongo":
kill_command = query_engine.get_kill_command(json.loads(thread_ids))
- result['data'] = kill_command
+ result["data"] = kill_command
else:
- result = {'status': 1, 'msg': '暂时不支持{}类型数据库通过进程id构建请求'.format(instance.db_type), 'data': []}
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result = {
+ "status": 1,
+ "msg": "暂时不支持{}类型数据库通过进程id构建请求".format(instance.db_type),
+ "data": [],
+ }
+ return HttpResponse(json.dumps(result), content_type="application/json")
# 返回查询结果
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
# 问题诊断--终止会话 这里是实际执行kill的操作
-@permission_required('sql.process_kill', raise_exception=True)
+@permission_required("sql.process_kill", raise_exception=True)
def kill_session(request):
- instance_name = request.POST.get('instance_name')
- thread_ids = request.POST.get('ThreadIDs')
- result = {'status': 0, 'msg': 'ok', 'data': []}
+ instance_name = request.POST.get("instance_name")
+ thread_ids = request.POST.get("ThreadIDs")
+ result = {"status": 0, "msg": "ok", "data": []}
try:
instance = user_instances(request.user).get(instance_name=instance_name)
except Instance.DoesNotExist:
- result = {'status': 1, 'msg': '你所在组未关联该实例', 'data': []}
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result = {"status": 1, "msg": "你所在组未关联该实例", "data": []}
+ return HttpResponse(json.dumps(result), content_type="application/json")
engine = get_engine(instance=instance)
r = None
- if instance.db_type == 'mysql':
+ if instance.db_type == "mysql":
# 判断是RDS还是其他实例
if AliyunRdsConfig.objects.filter(instance=instance, is_enable=True).exists():
result = aliyun_kill_session(request)
else:
r = engine.kill(json.loads(thread_ids))
- elif instance.db_type == 'mongo':
+ elif instance.db_type == "mongo":
r = engine.kill_op(json.loads(thread_ids))
else:
- result = {'status': 1, 'msg': '暂时不支持{}类型数据库终止会话'.format(instance.db_type), 'data': []}
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result = {
+ "status": 1,
+ "msg": "暂时不支持{}类型数据库终止会话".format(instance.db_type),
+ "data": [],
+ }
+ return HttpResponse(json.dumps(result), content_type="application/json")
if r and r.error:
- result = {'status': 1, 'msg': r.error, 'data': []}
+ result = {"status": 1, "msg": r.error, "data": []}
# 返回查询结果
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
# 问题诊断--表空间信息
-@permission_required('sql.tablespace_view', raise_exception=True)
+@permission_required("sql.tablespace_view", raise_exception=True)
def tablesapce(request):
- instance_name = request.POST.get('instance_name')
- offset = int(request.POST.get('offset',0))
- limit = int(request.POST.get('limit',14))
+ instance_name = request.POST.get("instance_name")
+ offset = int(request.POST.get("offset", 0))
+ limit = int(request.POST.get("limit", 14))
try:
instance = user_instances(request.user).get(instance_name=instance_name)
except Instance.DoesNotExist:
- result = {'status': 1, 'msg': '你所在组未关联该实例', 'data': []}
- return HttpResponse(json.dumps(result), content_type='application/json')
-
+ result = {"status": 1, "msg": "你所在组未关联该实例", "data": []}
+ return HttpResponse(json.dumps(result), content_type="application/json")
+
query_engine = get_engine(instance=instance)
- if instance.db_type == 'mysql':
+ if instance.db_type == "mysql":
# 判断是RDS还是其他实例
if AliyunRdsConfig.objects.filter(instance=instance, is_enable=True).exists():
result = aliyun_sapce_status(request)
else:
- query_result = query_engine.tablesapce(offset,limit)
+ query_result = query_engine.tablesapce(offset, limit)
r = query_engine.tablesapce_num()
total = r.rows[0][0]
else:
- result = {'status': 1, 'msg': '暂时不支持{}类型数据库的表空间信息查询'.format(instance.db_type), 'data': []}
- return HttpResponse(json.dumps(result), content_type='application/json')
-
- if query_result:
+ result = {
+ "status": 1,
+ "msg": "暂时不支持{}类型数据库的表空间信息查询".format(instance.db_type),
+ "data": [],
+ }
+ return HttpResponse(json.dumps(result), content_type="application/json")
+
+ if query_result:
if not query_result.error:
table_space = query_result.to_dict()
- result = {'status': 0, 'msg': 'ok', 'rows': table_space, 'total':total}
+ result = {"status": 0, "msg": "ok", "rows": table_space, "total": total}
else:
- result = {'status': 1, 'msg': query_result.error}
+ result = {"status": 1, "msg": query_result.error}
# 返回查询结果
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
# 问题诊断--锁等待
-@permission_required('sql.trxandlocks_view', raise_exception=True)
+@permission_required("sql.trxandlocks_view", raise_exception=True)
def trxandlocks(request):
- instance_name = request.POST.get('instance_name')
+ instance_name = request.POST.get("instance_name")
try:
instance = user_instances(request.user).get(instance_name=instance_name)
except Instance.DoesNotExist:
- result = {'status': 1, 'msg': '你所在组未关联该实例', 'data': []}
- return HttpResponse(json.dumps(result), content_type='application/json')
-
+ result = {"status": 1, "msg": "你所在组未关联该实例", "data": []}
+ return HttpResponse(json.dumps(result), content_type="application/json")
+
query_engine = get_engine(instance=instance)
- if instance.db_type == 'mysql':
+ if instance.db_type == "mysql":
query_result = query_engine.trxandlocks()
-
+
else:
- result = {'status': 1, 'msg': '暂时不支持{}类型数据库的锁等待查询'.format(instance.db_type), 'data': []}
- return HttpResponse(json.dumps(result), content_type='application/json')
-
+ result = {
+ "status": 1,
+ "msg": "暂时不支持{}类型数据库的锁等待查询".format(instance.db_type),
+ "data": [],
+ }
+ return HttpResponse(json.dumps(result), content_type="application/json")
+
if not query_result.error:
trxandlocks = query_result.to_dict()
- result = {'status': 0, 'msg': 'ok', 'rows': trxandlocks}
+ result = {"status": 0, "msg": "ok", "rows": trxandlocks}
else:
- result = {'status': 1, 'msg': query_result.error}
+ result = {"status": 1, "msg": query_result.error}
# 返回查询结果
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
# 问题诊断--长事务
-@permission_required('sql.trx_view', raise_exception=True)
+@permission_required("sql.trx_view", raise_exception=True)
def innodb_trx(request):
- instance_name = request.POST.get('instance_name')
+ instance_name = request.POST.get("instance_name")
try:
instance = user_instances(request.user).get(instance_name=instance_name)
except Instance.DoesNotExist:
- result = {'status': 1, 'msg': '你所在组未关联该实例', 'data': []}
- return HttpResponse(json.dumps(result), content_type='application/json')
-
+ result = {"status": 1, "msg": "你所在组未关联该实例", "data": []}
+ return HttpResponse(json.dumps(result), content_type="application/json")
+
query_engine = get_engine(instance=instance)
- if instance.db_type == 'mysql':
+ if instance.db_type == "mysql":
query_result = query_engine.get_long_transaction()
else:
- result = {'status': 1, 'msg': '暂时不支持{}类型数据库的长事务查询'.format(instance.db_type), 'data': []}
- return HttpResponse(json.dumps(result), content_type='application/json')
-
+ result = {
+ "status": 1,
+ "msg": "暂时不支持{}类型数据库的长事务查询".format(instance.db_type),
+ "data": [],
+ }
+ return HttpResponse(json.dumps(result), content_type="application/json")
+
if not query_result.error:
trx = query_result.to_dict()
- result = {'status': 0, 'msg': 'ok', 'rows': trx}
+ result = {"status": 0, "msg": "ok", "rows": trx}
else:
- result = {'status': 1, 'msg': query_result.error}
+ result = {"status": 1, "msg": query_result.error}
# 返回查询结果
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
diff --git a/sql/engines/__init__.py b/sql/engines/__init__.py
index 5b9e108915..a101abed13 100644
--- a/sql/engines/__init__.py
+++ b/sql/engines/__init__.py
@@ -5,6 +5,7 @@
class EngineBase:
"""enginebase 只定义了init函数和若干方法的名字, 具体实现用mysql.py pg.py等实现"""
+
test_query = None
def __init__(self, instance=None):
@@ -35,14 +36,14 @@ def __init__(self, instance=None):
self.host, self.port = self.ssh.get_ssh()
def __del__(self):
- if hasattr(self, 'ssh'):
+ if hasattr(self, "ssh"):
del self.ssh
- if hasattr(self, 'remotessh'):
+ if hasattr(self, "remotessh"):
del self.remotessh
def remote_instance_conn(self, instance=None):
# 判断如果配置了隧道则连接隧道
- if not hasattr(self, 'remotessh') and instance.tunnel:
+ if not hasattr(self, "remotessh") and instance.tunnel:
self.remotessh = SSHConnection(
instance.host,
instance.port,
@@ -61,7 +62,12 @@ def remote_instance_conn(self, instance=None):
self.remote_port = instance.port
self.remote_user = instance.user
self.remote_password = instance.password
- return self.remote_host, self.remote_port, self.remote_user, self.remote_password
+ return (
+ self.remote_host,
+ self.remote_port,
+ self.remote_user,
+ self.remote_password,
+ )
def get_connection(self, db_name=None):
"""返回一个conn实例"""
@@ -137,20 +143,20 @@ def describe_table(self, db_name, tb_name, **kwargs):
def query_check(self, db_name=None, sql=""):
"""查询语句的检查、注释去除、切分, 返回一个字典 {'bad_query': bool, 'filtered_sql': str}"""
- def filter_sql(self, sql='', limit_num=0):
+ def filter_sql(self, sql="", limit_num=0):
"""给查询语句增加结果级限制或者改写语句, 返回修改后的语句"""
return sql.strip()
- def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs):
+ def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
"""实际查询 返回一个ResultSet"""
return ResultSet()
- def query_masking(self, db_name=None, sql='', resultset=None):
+ def query_masking(self, db_name=None, sql="", resultset=None):
"""传入 sql语句, db名, 结果集,
返回一个脱敏后的结果集"""
return resultset
- def execute_check(self, db_name=None, sql=''):
+ def execute_check(self, db_name=None, sql=""):
"""执行语句的检查 返回一个ReviewSet"""
return ReviewSet()
@@ -209,12 +215,12 @@ def get_engine(instance=None): # pragma: no cover
return PhoenixEngine(instance=instance)
- elif instance.db_type == 'odps':
+ elif instance.db_type == "odps":
from .odps import ODPSEngine
return ODPSEngine(instance=instance)
- elif instance.db_type == 'clickhouse':
+ elif instance.db_type == "clickhouse":
from .clickhouse import ClickHouseEngine
return ClickHouseEngine(instance=instance)
diff --git a/sql/engines/clickhouse.py b/sql/engines/clickhouse.py
index 3d991cd9c2..8944d27b89 100644
--- a/sql/engines/clickhouse.py
+++ b/sql/engines/clickhouse.py
@@ -9,7 +9,7 @@
import logging
import re
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
class ClickHouseEngine(EngineBase):
@@ -23,20 +23,31 @@ def get_connection(self, db_name=None):
if self.conn:
return self.conn
if db_name:
- self.conn = connect(host=self.host, port=self.port, user=self.user, password=self.password,
- database=db_name, connect_timeout=10)
+ self.conn = connect(
+ host=self.host,
+ port=self.port,
+ user=self.user,
+ password=self.password,
+ database=db_name,
+ connect_timeout=10,
+ )
else:
- self.conn = connect(host=self.host, port=self.port, user=self.user, password=self.password,
- connect_timeout=10)
+ self.conn = connect(
+ host=self.host,
+ port=self.port,
+ user=self.user,
+ password=self.password,
+ connect_timeout=10,
+ )
return self.conn
@property
def name(self):
- return 'ClickHouse'
+ return "ClickHouse"
@property
def info(self):
- return 'ClickHouse engine'
+ return "ClickHouse engine"
@property
def auto_backup(self):
@@ -47,8 +58,8 @@ def auto_backup(self):
def server_version(self):
sql = "select value from system.build_options where name = 'VERSION_FULL';"
result = self.query(sql=sql)
- version = result.rows[0][0].split(' ')[1]
- return tuple([int(n) for n in version.split('.')[:3]])
+ version = result.rows[0][0].split(" ")[1]
+ return tuple([int(n) for n in version.split(".")[:3]])
def get_table_engine(self, tb_name):
"""获取某个table的engine type"""
@@ -58,17 +69,21 @@ def get_table_engine(self, tb_name):
and name='{tb_name.split('.')[1]}'"""
query_result = self.query(sql=sql)
if query_result.rows:
- result = {'status': 1, 'engine': query_result.rows[0][0]}
+ result = {"status": 1, "engine": query_result.rows[0][0]}
else:
- result = {'status': 0, 'engine': 'None'}
+ result = {"status": 0, "engine": "None"}
return result
def get_all_databases(self):
"""获取数据库列表, 返回一个ResultSet"""
sql = "show databases"
result = self.query(sql=sql)
- db_list = [row[0] for row in result.rows
- if row[0] not in ('system', 'INFORMATION_SCHEMA', 'information_schema', 'datasets')]
+ db_list = [
+ row[0]
+ for row in result.rows
+ if row[0]
+ not in ("system", "INFORMATION_SCHEMA", "information_schema", "datasets")
+ ]
result.rows = db_list
return result
@@ -101,11 +116,13 @@ def describe_table(self, db_name, tb_name, **kwargs):
sql = f"show create table `{tb_name}`;"
result = self.query(db_name=db_name, sql=sql)
- result.rows[0] = (tb_name,) + (result.rows[0][0].replace('(', '(\n ').replace(',', ',\n '),)
+ result.rows[0] = (tb_name,) + (
+ result.rows[0][0].replace("(", "(\n ").replace(",", ",\n "),
+ )
return result
- def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs):
- """返回 ResultSet """
+ def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
+ """返回 ResultSet"""
result_set = ResultSet(full_sql=sql)
try:
conn = self.get_connection(db_name=db_name)
@@ -122,91 +139,97 @@ def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs):
result_set.affected_rows = len(rows)
except Exception as e:
logger.warning(f"ClickHouse语句执行报错,语句:{sql},错误信息{e}")
- result_set.error = str(e).split('Stack trace')[0]
+ result_set.error = str(e).split("Stack trace")[0]
finally:
if close_conn:
self.close()
return result_set
- def query_check(self, db_name=None, sql=''):
+ def query_check(self, db_name=None, sql=""):
# 查询语句的检查、注释去除、切分
- result = {'msg': '', 'bad_query': False, 'filtered_sql': sql, 'has_star': False}
+ result = {"msg": "", "bad_query": False, "filtered_sql": sql, "has_star": False}
# 删除注释语句,进行语法判断,执行第一条有效sql
try:
sql = sqlparse.format(sql, strip_comments=True)
sql = sqlparse.split(sql)[0]
- result['filtered_sql'] = sql.strip()
+ result["filtered_sql"] = sql.strip()
except IndexError:
- result['bad_query'] = True
- result['msg'] = '没有有效的SQL语句'
+ result["bad_query"] = True
+ result["msg"] = "没有有效的SQL语句"
if re.match(r"^select|^show|^explain", sql, re.I) is None:
- result['bad_query'] = True
- result['msg'] = '不支持的查询语法类型!'
- if '*' in sql:
- result['has_star'] = True
- result['msg'] = 'SQL语句中含有 * '
+ result["bad_query"] = True
+ result["msg"] = "不支持的查询语法类型!"
+ if "*" in sql:
+ result["has_star"] = True
+ result["msg"] = "SQL语句中含有 * "
# clickhouse 20.6.3版本开始正式支持explain语法
if re.match(r"^explain", sql, re.I) and self.server_version < (20, 6, 3):
- result['bad_query'] = True
- result['msg'] = f"当前ClickHouse实例版本低于20.6.3,不支持explain!"
+ result["bad_query"] = True
+ result["msg"] = f"当前ClickHouse实例版本低于20.6.3,不支持explain!"
# select语句先使用Explain判断语法是否正确
if re.match(r"^select", sql, re.I) and self.server_version >= (20, 6, 3):
explain_result = self.query(db_name=db_name, sql=f"explain {sql}")
if explain_result.error:
- result['bad_query'] = True
- result['msg'] = explain_result.error
+ result["bad_query"] = True
+ result["msg"] = explain_result.error
return result
- def filter_sql(self, sql='', limit_num=0):
+ def filter_sql(self, sql="", limit_num=0):
# 对查询sql增加limit限制,limit n 或 limit n,n 或 limit n offset n统一改写成limit n
- sql = sql.rstrip(';').strip()
+ sql = sql.rstrip(";").strip()
if re.match(r"^select", sql, re.I):
# LIMIT N
- limit_n = re.compile(r'limit\s+(\d+)\s*$', re.I)
+ limit_n = re.compile(r"limit\s+(\d+)\s*$", re.I)
# LIMIT M OFFSET N
- limit_offset = re.compile(r'limit\s+(\d+)\s+offset\s+(\d+)\s*$', re.I)
+ limit_offset = re.compile(r"limit\s+(\d+)\s+offset\s+(\d+)\s*$", re.I)
# LIMIT M,N
- offset_comma_limit = re.compile(r'limit\s+(\d+)\s*,\s*(\d+)\s*$', re.I)
+ offset_comma_limit = re.compile(r"limit\s+(\d+)\s*,\s*(\d+)\s*$", re.I)
if limit_n.search(sql):
sql_limit = limit_n.search(sql).group(1)
limit_num = min(int(limit_num), int(sql_limit))
- sql = limit_n.sub(f'limit {limit_num};', sql)
+ sql = limit_n.sub(f"limit {limit_num};", sql)
elif limit_offset.search(sql):
sql_limit = limit_offset.search(sql).group(1)
sql_offset = limit_offset.search(sql).group(2)
limit_num = min(int(limit_num), int(sql_limit))
- sql = limit_offset.sub(f'limit {limit_num} offset {sql_offset};', sql)
+ sql = limit_offset.sub(f"limit {limit_num} offset {sql_offset};", sql)
elif offset_comma_limit.search(sql):
sql_offset = offset_comma_limit.search(sql).group(1)
sql_limit = offset_comma_limit.search(sql).group(2)
limit_num = min(int(limit_num), int(sql_limit))
- sql = offset_comma_limit.sub(f'limit {sql_offset},{limit_num};', sql)
+ sql = offset_comma_limit.sub(f"limit {sql_offset},{limit_num};", sql)
else:
- sql = f'{sql} limit {limit_num};'
+ sql = f"{sql} limit {limit_num};"
else:
- sql = f'{sql};'
+ sql = f"{sql};"
return sql
- def explain_check(self, check_result, db_name=None, line=0, statement=''):
+ def explain_check(self, check_result, db_name=None, line=0, statement=""):
"""使用explain ast检查sql语法, 返回Review set"""
- result = ReviewResult(id=line, errlevel=0,
- stagestatus='Audit completed',
- errormessage='None',
- sql=statement,
- affected_rows=0,
- execute_time=0, )
+ result = ReviewResult(
+ id=line,
+ errlevel=0,
+ stagestatus="Audit completed",
+ errormessage="None",
+ sql=statement,
+ affected_rows=0,
+ execute_time=0,
+ )
# clickhouse版本>=21.1.2 explain ast才支持非select语句检查
if self.server_version >= (21, 1, 2):
explain_result = self.query(db_name=db_name, sql=f"explain ast {statement}")
if explain_result.error:
- result = ReviewResult(id=line, errlevel=2,
- stagestatus='驳回未通过检查SQL',
- errormessage=f'explain语法检查错误:{explain_result.error}',
- sql=statement)
+ result = ReviewResult(
+ id=line,
+ errlevel=2,
+ stagestatus="驳回未通过检查SQL",
+ errormessage=f"explain语法检查错误:{explain_result.error}",
+ sql=statement,
+ )
return result
- def execute_check(self, db_name=None, sql=''):
+ def execute_check(self, db_name=None, sql=""):
"""上线单执行前的检查, 返回Review set"""
sql = sqlparse.format(sql, strip_comments=True)
sql_list = sqlparse.split(sql)
@@ -214,111 +237,162 @@ def execute_check(self, db_name=None, sql=''):
# 禁用/高危语句检查
check_result = ReviewSet(full_sql=sql)
line = 1
- critical_ddl_regex = self.config.get('critical_ddl_regex', '')
+ critical_ddl_regex = self.config.get("critical_ddl_regex", "")
p = re.compile(critical_ddl_regex)
check_result.syntax_type = 2 # TODO 工单类型 0、其他 1、DDL,2、DML
for statement in sql_list:
- statement = statement.rstrip(';')
+ statement = statement.rstrip(";")
# 禁用语句
if re.match(r"^select|^show", statement.lower()):
- result = ReviewResult(id=line, errlevel=2,
- stagestatus='驳回不支持语句',
- errormessage='仅支持DML和DDL语句,查询语句请使用SQL查询功能!',
- sql=statement)
+ result = ReviewResult(
+ id=line,
+ errlevel=2,
+ stagestatus="驳回不支持语句",
+ errormessage="仅支持DML和DDL语句,查询语句请使用SQL查询功能!",
+ sql=statement,
+ )
# 高危语句
elif critical_ddl_regex and p.match(statement.strip().lower()):
- result = ReviewResult(id=line, errlevel=2,
- stagestatus='驳回高危SQL',
- errormessage='禁止提交匹配' + critical_ddl_regex + '条件的语句!',
- sql=statement)
+ result = ReviewResult(
+ id=line,
+ errlevel=2,
+ stagestatus="驳回高危SQL",
+ errormessage="禁止提交匹配" + critical_ddl_regex + "条件的语句!",
+ sql=statement,
+ )
# alter语句
elif re.match(r"^alter", statement.lower()):
# alter table语句
if re.match(r"^alter\s+table\s+(.+?)\s+", statement.lower()):
- table_name = re.match(r"^alter\s+table\s+(.+?)\s+", statement.lower(), re.M).group(1)
- if '.' not in table_name:
+ table_name = re.match(
+ r"^alter\s+table\s+(.+?)\s+", statement.lower(), re.M
+ ).group(1)
+ if "." not in table_name:
table_name = f"{db_name}.{table_name}"
- table_engine = self.get_table_engine(table_name)['engine']
- table_exist = self.get_table_engine(table_name)['status']
+ table_engine = self.get_table_engine(table_name)["engine"]
+ table_exist = self.get_table_engine(table_name)["status"]
if table_exist == 1:
- if not table_engine.endswith('MergeTree') and table_engine not in ('Merge', 'Distributed'):
- result = ReviewResult(id=line, errlevel=2,
- stagestatus='驳回不支持SQL',
- errormessage='ALTER TABLE仅支持*MergeTree,Merge以及Distributed等引擎表!',
- sql=statement)
+ if not table_engine.endswith(
+ "MergeTree"
+ ) and table_engine not in ("Merge", "Distributed"):
+ result = ReviewResult(
+ id=line,
+ errlevel=2,
+ stagestatus="驳回不支持SQL",
+ errormessage="ALTER TABLE仅支持*MergeTree,Merge以及Distributed等引擎表!",
+ sql=statement,
+ )
else:
# delete与update语句,实际是alter语句的变种
- if re.match(r"^alter\s+table\s+(.+?)\s+(delete|update)\s+", statement.lower()):
- if not table_engine.endswith('MergeTree'):
- result = ReviewResult(id=line, errlevel=2,
- stagestatus='驳回不支持SQL',
- errormessage='DELETE与UPDATE仅支持*MergeTree引擎表!',
- sql=statement)
+ if re.match(
+ r"^alter\s+table\s+(.+?)\s+(delete|update)\s+",
+ statement.lower(),
+ ):
+ if not table_engine.endswith("MergeTree"):
+ result = ReviewResult(
+ id=line,
+ errlevel=2,
+ stagestatus="驳回不支持SQL",
+ errormessage="DELETE与UPDATE仅支持*MergeTree引擎表!",
+ sql=statement,
+ )
else:
- result = self.explain_check(check_result, db_name, line, statement)
+ result = self.explain_check(
+ check_result, db_name, line, statement
+ )
else:
- result = self.explain_check(check_result, db_name, line, statement)
+ result = self.explain_check(
+ check_result, db_name, line, statement
+ )
else:
- result = ReviewResult(id=line, errlevel=2,
- stagestatus='表不存在',
- errormessage=f'表 {table_name} 不存在!',
- sql=statement)
+ result = ReviewResult(
+ id=line,
+ errlevel=2,
+ stagestatus="表不存在",
+ errormessage=f"表 {table_name} 不存在!",
+ sql=statement,
+ )
# 其他alter语句
else:
result = self.explain_check(check_result, db_name, line, statement)
# truncate语句
elif re.match(r"^truncate\s+table\s+(.+?)(\s|$)", statement.lower()):
- table_name = re.match(r"^truncate\s+table\s+(.+?)(\s|$)", statement.lower(), re.M).group(1)
- if '.' not in table_name:
+ table_name = re.match(
+ r"^truncate\s+table\s+(.+?)(\s|$)", statement.lower(), re.M
+ ).group(1)
+ if "." not in table_name:
table_name = f"{db_name}.{table_name}"
- table_engine = self.get_table_engine(table_name)['engine']
- table_exist = self.get_table_engine(table_name)['status']
+ table_engine = self.get_table_engine(table_name)["engine"]
+ table_exist = self.get_table_engine(table_name)["status"]
if table_exist == 1:
- if table_engine in ('View', 'File,', 'URL', 'Buffer', 'Null'):
- result = ReviewResult(id=line, errlevel=2,
- stagestatus='驳回不支持SQL',
- errormessage='TRUNCATE不支持View,File,URL,Buffer和Null表引擎!',
- sql=statement)
+ if table_engine in ("View", "File,", "URL", "Buffer", "Null"):
+ result = ReviewResult(
+ id=line,
+ errlevel=2,
+ stagestatus="驳回不支持SQL",
+ errormessage="TRUNCATE不支持View,File,URL,Buffer和Null表引擎!",
+ sql=statement,
+ )
else:
- result = self.explain_check(check_result, db_name, line, statement)
+ result = self.explain_check(
+ check_result, db_name, line, statement
+ )
else:
- result = ReviewResult(id=line, errlevel=2,
- stagestatus='表不存在',
- errormessage=f'表 {table_name} 不存在!',
- sql=statement)
+ result = ReviewResult(
+ id=line,
+ errlevel=2,
+ stagestatus="表不存在",
+ errormessage=f"表 {table_name} 不存在!",
+ sql=statement,
+ )
# insert语句,explain无法正确判断,暂时只做表存在性检查与简单关键字匹配
elif re.match(r"^insert", statement.lower()):
- if re.match(r"^insert\s+into\s+(.+?)(\s+|\s*\(.+?)(values|format|select)(\s+|\()", statement.lower()):
- table_name = re.match(r"^insert\s+into\s+(.+?)(\s+|\s*\(.+?)(values|format|select)(\s+|\()",
- statement.lower(), re.M).group(1)
- if '.' not in table_name:
+ if re.match(
+ r"^insert\s+into\s+(.+?)(\s+|\s*\(.+?)(values|format|select)(\s+|\()",
+ statement.lower(),
+ ):
+ table_name = re.match(
+ r"^insert\s+into\s+(.+?)(\s+|\s*\(.+?)(values|format|select)(\s+|\()",
+ statement.lower(),
+ re.M,
+ ).group(1)
+ if "." not in table_name:
table_name = f"{db_name}.{table_name}"
- table_exist = self.get_table_engine(table_name)['status']
+ table_exist = self.get_table_engine(table_name)["status"]
if table_exist == 1:
- result = ReviewResult(id=line, errlevel=0,
- stagestatus='Audit completed',
- errormessage='None',
- sql=statement,
- affected_rows=0,
- execute_time=0, )
+ result = ReviewResult(
+ id=line,
+ errlevel=0,
+ stagestatus="Audit completed",
+ errormessage="None",
+ sql=statement,
+ affected_rows=0,
+ execute_time=0,
+ )
else:
- result = ReviewResult(id=line, errlevel=2,
- stagestatus='表不存在',
- errormessage=f'表 {table_name} 不存在!',
- sql=statement)
+ result = ReviewResult(
+ id=line,
+ errlevel=2,
+ stagestatus="表不存在",
+ errormessage=f"表 {table_name} 不存在!",
+ sql=statement,
+ )
else:
- result = ReviewResult(id=line, errlevel=2,
- stagestatus='驳回不支持SQL',
- errormessage='INSERT语法不正确!',
- sql=statement)
+ result = ReviewResult(
+ id=line,
+ errlevel=2,
+ stagestatus="驳回不支持SQL",
+ errormessage="INSERT语法不正确!",
+ sql=statement,
+ )
# 其他语句使用explain ast简单检查
else:
result = self.explain_check(check_result, db_name, line, statement)
# 没有找出DDL语句的才继续执行此判断
if check_result.syntax_type == 2:
- if get_syntax_type(statement, parser=False, db_type='mysql') == 'DDL':
+ if get_syntax_type(statement, parser=False, db_type="mysql") == "DDL":
check_result.syntax_type = 1
check_result.rows += [result]
line += 1
@@ -340,47 +414,55 @@ def execute_workflow(self, workflow):
line = 1
for statement in sql_list:
with FuncTimer() as t:
- result = self.execute(db_name=workflow.db_name, sql=statement, close_conn=True)
+ result = self.execute(
+ db_name=workflow.db_name, sql=statement, close_conn=True
+ )
if not result.error:
- execute_result.rows.append(ReviewResult(
- id=line,
- errlevel=0,
- stagestatus='Execute Successfully',
- errormessage='None',
- sql=statement,
- affected_rows=0,
- execute_time=t.cost,
- ))
+ execute_result.rows.append(
+ ReviewResult(
+ id=line,
+ errlevel=0,
+ stagestatus="Execute Successfully",
+ errormessage="None",
+ sql=statement,
+ affected_rows=0,
+ execute_time=t.cost,
+ )
+ )
line += 1
else:
# 追加当前报错语句信息到执行结果中
execute_result.error = result.error
- execute_result.rows.append(ReviewResult(
- id=line,
- errlevel=2,
- stagestatus='Execute Failed',
- errormessage=f'异常信息:{result.error}',
- sql=statement,
- affected_rows=0,
- execute_time=0,
- ))
- line += 1
- # 报错语句后面的语句标记为审核通过、未执行,追加到执行结果中
- for statement in sql_list[line - 1:]:
- execute_result.rows.append(ReviewResult(
+ execute_result.rows.append(
+ ReviewResult(
id=line,
- errlevel=0,
- stagestatus='Audit completed',
- errormessage=f'前序语句失败, 未执行',
+ errlevel=2,
+ stagestatus="Execute Failed",
+ errormessage=f"异常信息:{result.error}",
sql=statement,
affected_rows=0,
execute_time=0,
- ))
+ )
+ )
+ line += 1
+ # 报错语句后面的语句标记为审核通过、未执行,追加到执行结果中
+ for statement in sql_list[line - 1 :]:
+ execute_result.rows.append(
+ ReviewResult(
+ id=line,
+ errlevel=0,
+ stagestatus="Audit completed",
+ errormessage=f"前序语句失败, 未执行",
+ sql=statement,
+ affected_rows=0,
+ execute_time=0,
+ )
+ )
line += 1
break
return execute_result
- def execute(self, db_name=None, sql='', close_conn=True):
+ def execute(self, db_name=None, sql="", close_conn=True):
"""原生执行语句"""
result = ResultSet(full_sql=sql)
conn = self.get_connection(db_name=db_name)
@@ -391,7 +473,7 @@ def execute(self, db_name=None, sql='', close_conn=True):
cursor.close()
except Exception as e:
logger.warning(f"ClickHouse语句执行报错,语句:{sql},错误信息{e}")
- result.error = str(e).split('Stack trace')[0]
+ result.error = str(e).split("Stack trace")[0]
if close_conn:
self.close()
return result
diff --git a/sql/engines/goinception.py b/sql/engines/goinception.py
index eca7c525ed..2ed75ced42 100644
--- a/sql/engines/goinception.py
+++ b/sql/engines/goinception.py
@@ -11,7 +11,7 @@
from . import EngineBase
from .models import ResultSet, ReviewSet, ReviewResult
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
class GoInceptionEngine(EngineBase):
@@ -19,42 +19,51 @@ class GoInceptionEngine(EngineBase):
@property
def name(self):
- return 'GoInception'
+ return "GoInception"
@property
def info(self):
- return 'GoInception engine'
+ return "GoInception engine"
def get_connection(self, db_name=None):
if self.conn:
return self.conn
- if hasattr(self, 'instance'):
- self.conn = MySQLdb.connect(host=self.host, port=self.port, charset=self.instance.charset or 'utf8mb4',
- connect_timeout=10)
+ if hasattr(self, "instance"):
+ self.conn = MySQLdb.connect(
+ host=self.host,
+ port=self.port,
+ charset=self.instance.charset or "utf8mb4",
+ connect_timeout=10,
+ )
return self.conn
archer_config = SysConfig()
- go_inception_host = archer_config.get('go_inception_host')
- go_inception_port = int(archer_config.get('go_inception_port', 4000))
- self.conn = MySQLdb.connect(host=go_inception_host, port=go_inception_port, charset='utf8mb4',
- connect_timeout=10)
+ go_inception_host = archer_config.get("go_inception_host")
+ go_inception_port = int(archer_config.get("go_inception_port", 4000))
+ self.conn = MySQLdb.connect(
+ host=go_inception_host,
+ port=go_inception_port,
+ charset="utf8mb4",
+ connect_timeout=10,
+ )
return self.conn
@staticmethod
def get_backup_connection():
archer_config = SysConfig()
- backup_host = archer_config.get('inception_remote_backup_host')
- backup_port = int(archer_config.get('inception_remote_backup_port', 3306))
- backup_user = archer_config.get('inception_remote_backup_user')
- backup_password = archer_config.get('inception_remote_backup_password', '')
- return MySQLdb.connect(host=backup_host,
- port=backup_port,
- user=backup_user,
- passwd=backup_password,
- charset='utf8mb4',
- autocommit=True
- )
+ backup_host = archer_config.get("inception_remote_backup_host")
+ backup_port = int(archer_config.get("inception_remote_backup_port", 3306))
+ backup_user = archer_config.get("inception_remote_backup_user")
+ backup_password = archer_config.get("inception_remote_backup_password", "")
+ return MySQLdb.connect(
+ host=backup_host,
+ port=backup_port,
+ user=backup_user,
+ passwd=backup_password,
+ charset="utf8mb4",
+ autocommit=True,
+ )
- def execute_check(self, instance=None, db_name=None, sql=''):
+ def execute_check(self, instance=None, db_name=None, sql=""):
"""inception check"""
# 判断如果配置了隧道则连接隧道
host, port, user, password = self.remote_instance_conn(instance)
@@ -78,7 +87,7 @@ def execute_check(self, instance=None, db_name=None, sql=''):
check_result.error_count += 1
# 没有找出DDL语句的才继续执行此判断
if check_result.syntax_type == 2:
- if get_syntax_type(r[5], parser=False, db_type='mysql') == 'DDL':
+ if get_syntax_type(r[5], parser=False, db_type="mysql") == "DDL":
check_result.syntax_type = 1
check_result.column_list = inception_result.column_list
check_result.checked = True
@@ -109,12 +118,15 @@ def execute(self, workflow=None):
# 执行报错,inception crash或者执行中连接异常的场景
if inception_result.error and not execute_result.rows:
execute_result.error = inception_result.error
- execute_result.rows = [ReviewResult(
- stage='Execute failed',
- errlevel=2,
- stagestatus='异常终止',
- errormessage=f'goInception Error: {inception_result.error}',
- sql=workflow.sqlworkflowcontent.sql_content)]
+ execute_result.rows = [
+ ReviewResult(
+ stage="Execute failed",
+ errlevel=2,
+ stagestatus="异常终止",
+ errormessage=f"goInception Error: {inception_result.error}",
+ sql=workflow.sqlworkflowcontent.sql_content,
+ )
+ ]
return execute_result
# 把结果转换为ReviewSet
@@ -123,13 +135,17 @@ def execute(self, workflow=None):
# 如果发现任何一个行执行结果里有errLevel为1或2,并且状态列没有包含Execute Successfully,则最终执行结果为有异常.
for r in execute_result.rows:
- if r.errlevel in (1, 2) and not re.search(r"Execute Successfully", r.stagestatus):
- execute_result.error = "Line {0} has error/warning: {1}".format(r.id, r.errormessage)
+ if r.errlevel in (1, 2) and not re.search(
+ r"Execute Successfully", r.stagestatus
+ ):
+ execute_result.error = "Line {0} has error/warning: {1}".format(
+ r.id, r.errormessage
+ )
break
return execute_result
- def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs):
- """返回 ResultSet """
+ def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
+ """返回 ResultSet"""
result_set = ResultSet(full_sql=sql)
conn = self.get_connection()
try:
@@ -145,13 +161,13 @@ def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs):
result_set.rows = rows
result_set.affected_rows = effect_row
except Exception as e:
- logger.warning(f'goInception语句执行报错,错误信息{traceback.format_exc()}')
+ logger.warning(f"goInception语句执行报错,错误信息{traceback.format_exc()}")
result_set.error = str(e)
if close_conn:
self.close()
return result_set
- def query_print(self, instance, db_name=None, sql=''):
+ def query_print(self, instance, db_name=None, sql=""):
"""
打印语法树。
"""
@@ -163,11 +179,11 @@ def query_print(self, instance, db_name=None, sql=''):
{sql.rstrip(';')};
inception_magic_commit;"""
print_info = self.query(db_name=db_name, sql=sql).to_dict()[1]
- if print_info.get('errmsg'):
- raise RuntimeError(print_info.get('errmsg'))
+ if print_info.get("errmsg"):
+ raise RuntimeError(print_info.get("errmsg"))
return print_info
- def query_data_masking(self, instance, db_name=None, sql=''):
+ def query_data_masking(self, instance, db_name=None, sql=""):
"""
将sql交给goInception打印语法树,获取select list
使用 masking 参数,可参考 https://github.com/hanchuanchuan/goInception/pull/355
@@ -182,13 +198,13 @@ def query_data_masking(self, instance, db_name=None, sql=''):
query_result = self.query(db_name=db_name, sql=sql)
# 有异常时主动抛出
if query_result.error:
- raise RuntimeError(f'Inception Error: {query_result.error}')
+ raise RuntimeError(f"Inception Error: {query_result.error}")
if not query_result.rows:
- raise RuntimeError(f'Inception Error: 未获取到语法信息')
+ raise RuntimeError(f"Inception Error: 未获取到语法信息")
# 兼容语法错误时errlevel=0的场景
print_info = query_result.to_dict()[0]
- if print_info['errlevel'] == 0 and print_info['errmsg'] is None:
- return json.loads(print_info['query_tree'])
+ if print_info["errlevel"] == 0 and print_info["errmsg"] is None:
+ return json.loads(print_info["query_tree"])
else:
raise RuntimeError(f"Inception Error: {print_info['errmsg']}")
@@ -196,7 +212,9 @@ def get_rollback(self, workflow):
"""
获取回滚语句,并且按照执行顺序倒序展示,return ['源语句','回滚语句']
"""
- list_execute_result = json.loads(workflow.sqlworkflowcontent.execute_result or '[]')
+ list_execute_result = json.loads(
+ workflow.sqlworkflowcontent.execute_result or "[]"
+ )
# 回滚语句倒序展示
list_execute_result.reverse()
list_backup_sql = []
@@ -207,18 +225,18 @@ def get_rollback(self, workflow):
try:
# 获取backup_db_name, 兼容旧数据'[[]]'格式
if isinstance(row, list):
- if row[8] == 'None':
+ if row[8] == "None":
continue
backup_db_name = row[8]
sequence = row[7]
sql = row[5]
# 新数据
else:
- if row.get('backup_dbname') in ('None', ''):
+ if row.get("backup_dbname") in ("None", ""):
continue
- backup_db_name = row.get('backup_dbname')
- sequence = row.get('sequence')
- sql = row.get('sql')
+ backup_db_name = row.get("backup_dbname")
+ sequence = row.get("sequence")
+ sql = row.get("sql")
# 获取备份表名
opid_time = sequence.replace("'", "")
sql_table = f"""select tablename
@@ -236,7 +254,9 @@ def get_rollback(self, workflow):
cur.execute(sql_back)
list_backup = cur.fetchall()
# 拼接成回滚语句列表,['源语句','回滚语句']
- list_backup_sql.append([sql, '\n'.join([back_info[0] for back_info in list_backup])])
+ list_backup_sql.append(
+ [sql, "\n".join([back_info[0] for back_info in list_backup])]
+ )
except Exception as e:
logger.error(f"获取回滚语句报错,异常信息{traceback.format_exc()}")
raise Exception(e)
@@ -260,9 +280,9 @@ def set_variable(self, variable_name, variable_value):
def osc_control(self, **kwargs):
"""控制osc执行,获取进度、终止、暂停、恢复等"""
- sqlsha1 = kwargs.get('sqlsha1')
- command = kwargs.get('command')
- if command == 'get':
+ sqlsha1 = kwargs.get("sqlsha1")
+ command = kwargs.get("command")
+ if command == "get":
sql = f"inception get osc_percent '{sqlsha1}';"
else:
sql = f"inception {command} osc '{sqlsha1}';"
@@ -270,7 +290,7 @@ def osc_control(self, **kwargs):
@staticmethod
def get_table_ref(query_tree, db_name=None):
- __author__ = 'xxlrr'
+ __author__ = "xxlrr"
"""
* 从goInception解析后的语法树里解析出兼容Inception格式的引用表信息。
* 目前的逻辑是在SQL语法树中通过递归查找选中最小的 TableRefs 子树(可能有多个),
@@ -294,12 +314,15 @@ def get_table_ref(query_tree, db_name=None):
else:
snodes = tree.find_max_tree("Source")
if snodes:
- table_ref.extend([
- {
- "schema": snode['Source']['Schema']['O'] or db_name,
- "name": snode['Source']['Name']['O']
- } for snode in snodes
- ])
+ table_ref.extend(
+ [
+ {
+ "schema": snode["Source"]["Schema"]["O"] or db_name,
+ "name": snode["Source"]["Name"]["O"],
+ }
+ for snode in snodes
+ ]
+ )
# assert: source node must exists if table_refs node exists.
# else:
# raise Exception("GoInception Error: not found source node")
@@ -313,7 +336,7 @@ def close(self):
class DictTree(dict):
def find_max_tree(self, *keys):
- __author__ = 'xxlrr'
+ __author__ = "xxlrr"
"""通过广度优先搜索算法查找满足条件的最大子树(不找叶子节点)"""
fit = []
find_queue = [self]
@@ -331,14 +354,15 @@ def find_max_tree(self, *keys):
def get_session_variables(instance):
"""按照目标实例动态设置goInception的会话参数,可用于按照业务组自定义审核规则等场景"""
variables = {}
- set_session_sql = ''
+ set_session_sql = ""
if AliyunRdsConfig.objects.filter(instance=instance, is_enable=True).exists():
- variables.update({
- "ghost_aliyun_rds": "on",
- "ghost_allow_on_master": "true",
- "ghost_assume_rbr": "true",
-
- })
+ variables.update(
+ {
+ "ghost_aliyun_rds": "on",
+ "ghost_allow_on_master": "true",
+ "ghost_assume_rbr": "true",
+ }
+ )
# 转换成SQL语句
for k, v in variables.items():
set_session_sql += f"inception set session {k} = '{v}';\n"
diff --git a/sql/engines/models.py b/sql/engines/models.py
index fc377590eb..34e73d6652 100644
--- a/sql/engines/models.py
+++ b/sql/engines/models.py
@@ -4,16 +4,23 @@
class SqlItem:
-
- def __init__(self, id=0, statement='', stmt_type='SQL', object_owner='', object_type='', object_name=''):
- '''
+ def __init__(
+ self,
+ id=0,
+ statement="",
+ stmt_type="SQL",
+ object_owner="",
+ object_type="",
+ object_name="",
+ ):
+ """
:param id: SQL序号,从0开始
:param statement: SQL Statement
:param stmt_type: SQL类型(SQL, PLSQL), 默认为SQL
:param object_owner: PLSQL Object Owner
:param object_type: PLSQL Object Type
:param object_name: PLSQL Object Name
- '''
+ """
self.id = id
self.statement = statement
self.stmt_type = stmt_type
@@ -34,32 +41,34 @@ def __init__(self, inception_result=None, **kwargs):
"""
if inception_result:
self.id = inception_result[0] or 0
- self.stage = inception_result[1] or ''
+ self.stage = inception_result[1] or ""
self.errlevel = inception_result[2] or 0
- self.stagestatus = inception_result[3] or ''
- self.errormessage = inception_result[4] or ''
- self.sql = inception_result[5] or ''
+ self.stagestatus = inception_result[3] or ""
+ self.errormessage = inception_result[4] or ""
+ self.sql = inception_result[5] or ""
self.affected_rows = inception_result[6] or 0
- self.sequence = inception_result[7] or ''
- self.backup_dbname = inception_result[8] or ''
- self.execute_time = inception_result[9] or ''
- self.sqlsha1 = inception_result[10] or ''
- self.backup_time = inception_result[11] if len(inception_result) >= 12 else ''
- self.actual_affected_rows = ''
+ self.sequence = inception_result[7] or ""
+ self.backup_dbname = inception_result[8] or ""
+ self.execute_time = inception_result[9] or ""
+ self.sqlsha1 = inception_result[10] or ""
+ self.backup_time = (
+ inception_result[11] if len(inception_result) >= 12 else ""
+ )
+ self.actual_affected_rows = ""
else:
- self.id = kwargs.get('id', 0)
- self.stage = kwargs.get('stage', '')
- self.errlevel = kwargs.get('errlevel', 0)
- self.stagestatus = kwargs.get('stagestatus', '')
- self.errormessage = kwargs.get('errormessage', '')
- self.sql = kwargs.get('sql', '')
- self.affected_rows = kwargs.get('affected_rows', 0)
- self.sequence = kwargs.get('sequence', '')
- self.backup_dbname = kwargs.get('backup_dbname', '')
- self.execute_time = kwargs.get('execute_time', '')
- self.sqlsha1 = kwargs.get('sqlsha1', '')
- self.backup_time = kwargs.get('backup_time', '')
- self.actual_affected_rows = kwargs.get('actual_affected_rows', '')
+ self.id = kwargs.get("id", 0)
+ self.stage = kwargs.get("stage", "")
+ self.errlevel = kwargs.get("errlevel", 0)
+ self.stagestatus = kwargs.get("stagestatus", "")
+ self.errormessage = kwargs.get("errormessage", "")
+ self.sql = kwargs.get("sql", "")
+ self.affected_rows = kwargs.get("affected_rows", 0)
+ self.sequence = kwargs.get("sequence", "")
+ self.backup_dbname = kwargs.get("backup_dbname", "")
+ self.execute_time = kwargs.get("execute_time", "")
+ self.sqlsha1 = kwargs.get("sqlsha1", "")
+ self.backup_time = kwargs.get("backup_time", "")
+ self.actual_affected_rows = kwargs.get("actual_affected_rows", "")
# 自定义属性
for key, value in kwargs.items():
@@ -70,8 +79,15 @@ def __init__(self, inception_result=None, **kwargs):
class ReviewSet:
"""review和执行后的结果集, rows中是review result, 有设定好的字段"""
- def __init__(self, full_sql='', rows=None, status=None,
- affected_rows=0, column_list=None, **kwargs):
+ def __init__(
+ self,
+ full_sql="",
+ rows=None,
+ status=None,
+ affected_rows=0,
+ column_list=None,
+ **kwargs
+ ):
self.full_sql = full_sql
self.is_execute = False
self.checked = None
@@ -107,15 +123,22 @@ def to_dict(self):
class ResultSet:
"""查询的结果集, rows 内只有值, column_list 中的是key"""
- def __init__(self, full_sql='', rows=None, status=None,
- affected_rows=0, column_list=None, **kwargs):
+ def __init__(
+ self,
+ full_sql="",
+ rows=None,
+ status=None,
+ affected_rows=0,
+ column_list=None,
+ **kwargs
+ ):
self.full_sql = full_sql
self.is_execute = False
self.checked = None
self.is_masked = False
- self.query_time = ''
+ self.query_time = ""
self.mask_rule_hit = False
- self.mask_time = ''
+ self.mask_time = ""
self.warning = None
self.error = None
self.is_critical = False
@@ -134,11 +157,11 @@ def json(self):
def to_dict(self):
tmp_list = []
for r in self.rows:
- if isinstance(r,dict):
+ if isinstance(r, dict):
tmp_list += [r]
else:
tmp_list += [dict(zip(self.column_list, r))]
return tmp_list
def to_sep_dict(self):
- return {'column_list': self.column_list, 'rows': self.rows}
+ return {"column_list": self.column_list, "rows": self.rows}
diff --git a/sql/engines/mongo.py b/sql/engines/mongo.py
index dec1bbc349..e198624537 100644
--- a/sql/engines/mongo.py
+++ b/sql/engines/mongo.py
@@ -18,10 +18,10 @@
from . import EngineBase
from .models import ResultSet, ReviewSet, ReviewResult
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
# mongo客户端安装在本机的位置
-mongo = 'mongo'
+mongo = "mongo"
# 自定义异常
@@ -35,7 +35,7 @@ def __str__(self):
class JsonDecoder:
- '''处理传入mongodb语句中的条件,并转换成pymongo可识别的字典格式'''
+ """处理传入mongodb语句中的条件,并转换成pymongo可识别的字典格式"""
def __init__(self):
pass
@@ -43,35 +43,37 @@ def __init__(self):
def __json_object(self, tokener):
# obj = collections.OrderedDict()
obj = {}
- if tokener.cur_token() != '{':
+ if tokener.cur_token() != "{":
raise Exception('Json must start with "{"')
while True:
tokener.next()
tk_temp = tokener.cur_token()
- if tk_temp == '}':
+ if tk_temp == "}":
return {}
# 限制key的格式
- if not isinstance(tk_temp, str): # or (not tk_temp.isidentifier() and not tk_temp.startswith("$"))
- raise Exception('invalid key %s' % tk_temp)
+ if not isinstance(
+ tk_temp, str
+ ): # or (not tk_temp.isidentifier() and not tk_temp.startswith("$"))
+ raise Exception("invalid key %s" % tk_temp)
key = tk_temp.strip()
tokener.next()
- if tokener.cur_token() != ':':
+ if tokener.cur_token() != ":":
raise Exception('expect ":" after "%s"' % key)
tokener.next()
val = tokener.cur_token()
- if val == '[':
+ if val == "[":
val = self.__json_array(tokener)
- elif val == '{':
+ elif val == "{":
val = self.__json_object(tokener)
obj[key] = val
tokener.next()
tk_split = tokener.cur_token()
- if tk_split == ',':
+ if tk_split == ",":
continue
- elif tk_split == '}':
+ elif tk_split == "}":
break
else:
if tk_split is None:
@@ -80,20 +82,20 @@ def __json_object(self, tokener):
return obj
def __json_array(self, tokener):
- if tokener.cur_token() != '[':
+ if tokener.cur_token() != "[":
raise Exception('Json array must start with "["')
arr = []
while True:
tokener.next()
tk_temp = tokener.cur_token()
- if tk_temp == ']':
+ if tk_temp == "]":
return []
- if tk_temp == '{':
+ if tk_temp == "{":
val = self.__json_object(tokener)
- elif tk_temp == '[':
+ elif tk_temp == "[":
val = self.__json_array(tokener)
- elif tk_temp in (',', ':', '}'):
+ elif tk_temp in (",", ":", "}"):
raise Exception('unexpected token "%s"' % tk_temp)
else:
val = tk_temp
@@ -101,9 +103,9 @@ def __json_array(self, tokener):
tokener.next()
tk_end = tokener.cur_token()
- if tk_end == ',':
+ if tk_end == ",":
continue
- if tk_end == ']':
+ if tk_end == "]":
break
else:
if tk_end is None:
@@ -116,9 +118,9 @@ def decode(self, json_str):
return None
first_token = tokener.cur_token()
- if first_token == '{':
+ if first_token == "{":
decode_val = self.__json_object(tokener)
- elif first_token == '[':
+ elif first_token == "[":
decode_val = self.__json_array(tokener)
else:
raise Exception('Json must start with "{"')
@@ -135,7 +137,7 @@ def __init__(self, json_str):
def __cur_char(self):
if self.__i < len(self.__str):
return self.__str[self.__i]
- return ''
+ return ""
def __previous_char(self):
if self.__i < len(self.__str):
@@ -143,26 +145,29 @@ def __previous_char(self):
def __remain_str(self):
if self.__i < len(self.__str):
- return self.__str[self.__i:]
+ return self.__str[self.__i :]
def __move_i(self, step=1):
if self.__i < len(self.__str):
self.__i += step
def __next_string(self):
- '''当出现了"和'后就进入这个方法解析,直到出现与之对应的结束字符'''
- outstr = ''
+ """当出现了"和'后就进入这个方法解析,直到出现与之对应的结束字符"""
+ outstr = ""
trans_flag = False
start_ch = ""
self.__move_i()
- while self.__cur_char() != '':
+ while self.__cur_char() != "":
ch = self.__cur_char()
- if start_ch == "": start_ch = self.__previous_char()
+ if start_ch == "":
+ start_ch = self.__previous_char()
if ch == '\\"': # 判断是否是转义
trans_flag = True
else:
if not trans_flag:
- if (ch == '"' and start_ch == '"') or (ch == "'" and start_ch == "'"):
+ if (ch == '"' and start_ch == '"') or (
+ ch == "'" and start_ch == "'"
+ ):
break
else:
trans_flag = False
@@ -171,25 +176,29 @@ def __next_string(self):
return outstr
def __next_number(self):
- expr = ''
- while self.__cur_char().isdigit() or self.__cur_char() in ('.', '+', '-'):
+ expr = ""
+ while self.__cur_char().isdigit() or self.__cur_char() in (".", "+", "-"):
expr += self.__cur_char()
self.__move_i()
self.__move_i(-1)
- if '.' in expr:
+ if "." in expr:
return float(expr)
else:
return int(expr)
def __next_const(self):
- '''处理没有被''和""包含的字符,如true和ObjectId'''
+ """处理没有被''和""包含的字符,如true和ObjectId"""
outstr = ""
data_type = ""
while self.__cur_char().isalpha() or self.__cur_char() in ("$", "_", " "):
outstr += self.__cur_char()
self.__move_i()
if outstr.replace(" ", "") in (
- "ObjectId", "newDate", "ISODate", "newISODate"): # ======类似的类型比较多还需单独处理,如int()等
+ "ObjectId",
+ "newDate",
+ "ISODate",
+ "newISODate",
+ ): # ======类似的类型比较多还需单独处理,如int()等
data_type = outstr
for c in self.__remain_str():
outstr += c
@@ -199,8 +208,8 @@ def __next_const(self):
self.__move_i(-1)
- if outstr in ('true', 'false', 'null'):
- return {'true': True, 'false': False, 'null': None}[outstr]
+ if outstr in ("true", "false", "null"):
+ return {"true": True, "false": False, "null": None}[outstr]
elif data_type == "ObjectId":
ojStr = re.findall(r"ObjectId\(.*?\)", outstr) # 单独处理ObjectId
if len(ojStr) > 0:
@@ -208,10 +217,16 @@ def __next_const(self):
id_str = re.findall(r"\(.*?\)", ojStr[0])
oid = id_str[0].replace(" ", "")[2:-2]
return ObjectId(oid)
- elif data_type.replace(" ", "") in ("newDate", "ISODate", "newISODate"): # 处理时间格式
+ elif data_type.replace(" ", "") in (
+ "newDate",
+ "ISODate",
+ "newISODate",
+ ): # 处理时间格式
tmp_type = "%s()" % data_type
if outstr.replace(" ", "") == tmp_type.replace(" ", ""):
- return datetime.datetime.now() + datetime.timedelta(hours=-8) # mongodb默认时区为utc
+ return datetime.datetime.now() + datetime.timedelta(
+ hours=-8
+ ) # mongodb默认时区为utc
date_regex = re.compile(r'%s\("(.*)"\)' % data_type, re.IGNORECASE)
date_content = date_regex.findall(outstr)
if len(date_content) > 0:
@@ -221,21 +236,26 @@ def __next_const(self):
raise Exception('Invalid symbol "%s"' % outstr)
def next(self):
- is_white_space = lambda a_char: a_char in ('\x20', '\n', '\r', '\t') # 定义一个匿名函数
+ is_white_space = lambda a_char: a_char in (
+ "\x20",
+ "\n",
+ "\r",
+ "\t",
+ ) # 定义一个匿名函数
while is_white_space(self.__cur_char()):
self.__move_i()
ch = self.__cur_char()
- if ch == '':
+ if ch == "":
cur_token = None
- elif ch in ('{', '}', '[', ']', ',', ':'):
+ elif ch in ("{", "}", "[", "]", ",", ":"):
cur_token = ch
elif ch in ('"', "'"): # 当字符为" '
cur_token = self.__next_string()
elif ch.isalpha() or ch in ("$", "_"): # 字符串是否只由字母和"$","_"组成
cur_token = self.__next_const()
- elif ch.isdigit() or ch in ('.', '-', '+'): # 检测字符串是否只由数字组成
+ elif ch.isdigit() or ch in (".", "-", "+"): # 检测字符串是否只由数字组成
cur_token = self.__next_number()
else:
raise Exception('Invalid symbol "%s"' % ch)
@@ -256,52 +276,84 @@ class MongoEngine(EngineBase):
def test_connection(self):
return self.get_all_databases()
- def exec_cmd(self, sql, db_name=None, slave_ok=''):
+ def exec_cmd(self, sql, db_name=None, slave_ok=""):
"""审核时执行的语句"""
if self.user and self.password and self.port and self.host:
msg = ""
- auth_db = self.instance.db_name or 'admin'
+ auth_db = self.instance.db_name or "admin"
sql_len = len(sql)
is_load = False # 默认不使用load方法执行mongodb sql语句
try:
- if not sql.startswith('var host=') and sql_len > 4000: # 在master节点执行的情况,如果sql长度大于4000,就采取load js的方法
+ if (
+ not sql.startswith("var host=") and sql_len > 4000
+ ): # 在master节点执行的情况,如果sql长度大于4000,就采取load js的方法
# 因为用mongo load方法执行js脚本,所以需要重新改写一下sql,以便回显js执行结果
- sql = 'var result = ' + sql + '\nprintjson(result);'
+ sql = "var result = " + sql + "\nprintjson(result);"
# 因为要知道具体的临时文件位置,所以用了NamedTemporaryFile模块
- fp = tempfile.NamedTemporaryFile(suffix=".js", prefix="mongo_", dir='/tmp/', delete=True)
- fp.write(sql.encode('utf-8'))
+ fp = tempfile.NamedTemporaryFile(
+ suffix=".js", prefix="mongo_", dir="/tmp/", delete=True
+ )
+ fp.write(sql.encode("utf-8"))
fp.seek(0) # 把文件指针指向开始,这样写的sql内容才能落到磁盘文件上
cmd = "{mongo} --quiet -u {uname} -p '{password}' {host}:{port}/{auth_db} <<\\EOF\ndb=db.getSiblingDB(\"{db_name}\");{slave_ok}load('{tempfile_}')\nEOF".format(
- mongo=mongo, uname=self.user, password=self.password, host=self.host, port=self.port,
- db_name=db_name, sql=sql, auth_db=auth_db, slave_ok=slave_ok, tempfile_=fp.name)
+ mongo=mongo,
+ uname=self.user,
+ password=self.password,
+ host=self.host,
+ port=self.port,
+ db_name=db_name,
+ sql=sql,
+ auth_db=auth_db,
+ slave_ok=slave_ok,
+ tempfile_=fp.name,
+ )
is_load = True # 标记使用了load方法,用来在finally里面判断是否需要强制删除临时文件
- elif not sql.startswith(
- 'var host=') and sql_len < 4000: # 在master节点执行的情况, 如果sql长度小于4000,就直接用mongo shell执行,减少磁盘交换,节省性能
+ elif (
+ not sql.startswith("var host=") and sql_len < 4000
+ ): # 在master节点执行的情况, 如果sql长度小于4000,就直接用mongo shell执行,减少磁盘交换,节省性能
cmd = "{mongo} --quiet -u {uname} -p '{password}' {host}:{port}/{auth_db} <<\\EOF\ndb=db.getSiblingDB(\"{db_name}\");{slave_ok}printjson({sql})\nEOF".format(
- mongo=mongo, uname=self.user, password=self.password, host=self.host, port=self.port,
- db_name=db_name, sql=sql, auth_db=auth_db, slave_ok=slave_ok)
+ mongo=mongo,
+ uname=self.user,
+ password=self.password,
+ host=self.host,
+ port=self.port,
+ db_name=db_name,
+ sql=sql,
+ auth_db=auth_db,
+ slave_ok=slave_ok,
+ )
else:
cmd = "{mongo} --quiet -u {user} -p '{password}' {host}:{port}/{auth_db} <<\\EOF\nrs.slaveOk();{sql}\nEOF".format(
- mongo=mongo, user=self.user, password=self.password, host=self.host, port=self.port,
- db_name=db_name, sql=sql, auth_db=auth_db)
- p = subprocess.Popen(cmd, shell=True,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- universal_newlines=True)
+ mongo=mongo,
+ user=self.user,
+ password=self.password,
+ host=self.host,
+ port=self.port,
+ db_name=db_name,
+ sql=sql,
+ auth_db=auth_db,
+ )
+ p = subprocess.Popen(
+ cmd,
+ shell=True,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ universal_newlines=True,
+ )
re_msg = []
- for line in iter(p.stdout.read, ''):
+ for line in iter(p.stdout.read, ""):
re_msg.append(line)
# 因为返回的line中也有可能带有换行符,因此需要先全部转换成字符串
- __msg = '\n'.join(re_msg)
+ __msg = "\n".join(re_msg)
_re_msg = []
- for _line in __msg.split('\n'):
- if not _re_msg and re.match('WARNING.*', _line):
+ for _line in __msg.split("\n"):
+ if not _re_msg and re.match("WARNING.*", _line):
# 第一行可能是WARNING语句,因此跳过
continue
_re_msg.append(_line)
- msg = '\n'.join(_re_msg)
+ msg = "\n".join(_re_msg)
except Exception as e:
logger.warning(f"mongo语句执行报错,语句:{sql},{e}错误信息{traceback.format_exc()}")
finally:
@@ -314,8 +366,8 @@ def get_master(self):
sql = "rs.isMaster().primary"
master = self.exec_cmd(sql)
- if master != 'undefined' and master.find("TypeError") >= 0:
- sp_host = master.replace("\"", "").split(":")
+ if master != "undefined" and master.find("TypeError") >= 0:
+ sp_host = master.replace('"', "").split(":")
self.host = sp_host[0]
self.port = int(sp_host[1])
# return master
@@ -323,11 +375,11 @@ def get_master(self):
def get_slave(self):
"""获得从节点的port和host"""
- sql = '''var host=""; rs.status().members.forEach(function(item) {i=1; if (item.stateStr =="SECONDARY") \
- {host=item.name } }); print(host);'''
+ sql = """var host=""; rs.status().members.forEach(function(item) {i=1; if (item.stateStr =="SECONDARY") \
+ {host=item.name } }); print(host);"""
slave_msg = self.exec_cmd(sql)
- if slave_msg.lower().find('undefined') < 0:
- sp_host = slave_msg.replace("\"", "").split(":")
+ if slave_msg.lower().find("undefined") < 0:
+ sp_host = slave_msg.replace('"', "").split(":")
self.host = sp_host[0]
self.port = int(sp_host[1])
return True
@@ -339,7 +391,7 @@ def get_table_conut(self, table_name, db_name):
count_sql = f"db.{table_name}.count()"
status = self.get_slave() # 查询总数据要求在slave节点执行
if self.host and self.port and status:
- count = int(self.exec_cmd(count_sql, db_name, slave_ok='rs.slaveOk();'))
+ count = int(self.exec_cmd(count_sql, db_name, slave_ok="rs.slaveOk();"))
else:
count = int(self.exec_cmd(count_sql, db_name))
return count
@@ -349,9 +401,11 @@ def get_table_conut(self, table_name, db_name):
def execute_workflow(self, workflow):
"""执行上线单,返回Review set"""
- return self.execute(db_name=workflow.db_name, sql=workflow.sqlworkflowcontent.sql_content)
+ return self.execute(
+ db_name=workflow.db_name, sql=workflow.sqlworkflowcontent.sql_content
+ )
- def execute(self, db_name=None, sql=''):
+ def execute(self, db_name=None, sql=""):
"""mongo命令执行语句"""
self.get_master()
execute_result = ReviewSet(full_sql=sql)
@@ -360,7 +414,7 @@ def execute(self, db_name=None, sql=''):
sp_sql = sql.split(";")
line = 0
for exec_sql in sp_sql:
- if not exec_sql == '':
+ if not exec_sql == "":
exec_sql = exec_sql.strip()
try:
# DeprecationWarning: time.clock has been deprecated in Python 3.3 and will be removed from Python 3.8: use time.perf_counter or time.process_time instead
@@ -370,162 +424,252 @@ def execute(self, db_name=None, sql=''):
line += 1
logger.debug("执行结果:" + r)
# 如果执行中有错误
- rz = r.replace(' ', '').replace('"', '').lower()
+ rz = r.replace(" ", "").replace('"', "").lower()
tr = 1
- if r.lower().find("syntaxerror") >= 0 or rz.find('ok:0') >= 0 or rz.find(
- "error:invalid") >= 0 or rz.find("ReferenceError") >= 0 \
- or rz.find("getErrorWithCode") >= 0 or rz.find("failedtoconnect") >= 0 or rz.find(
- "Error: field") >= 0:
+ if (
+ r.lower().find("syntaxerror") >= 0
+ or rz.find("ok:0") >= 0
+ or rz.find("error:invalid") >= 0
+ or rz.find("ReferenceError") >= 0
+ or rz.find("getErrorWithCode") >= 0
+ or rz.find("failedtoconnect") >= 0
+ or rz.find("Error: field") >= 0
+ ):
tr = 0
- if (rz.find("errmsg") >= 0 or tr == 0) and (r.lower().find("already exist") < 0):
+ if (rz.find("errmsg") >= 0 or tr == 0) and (
+ r.lower().find("already exist") < 0
+ ):
execute_result.error = r
result = ReviewResult(
id=line,
- stage='Execute failed',
+ stage="Execute failed",
errlevel=2,
- stagestatus='异常终止',
- errormessage=f'mongo语句执行报错: {r}',
- sql=exec_sql)
+ stagestatus="异常终止",
+ errormessage=f"mongo语句执行报错: {r}",
+ sql=exec_sql,
+ )
else:
# 把结果转换为ReviewSet
result = ReviewResult(
- id=line, errlevel=0,
- stagestatus='执行结束',
+ id=line,
+ errlevel=0,
+ stagestatus="执行结束",
errormessage=r,
execute_time=round(end - start, 6),
actual_affected_rows=0, # todo============这个值需要优化
- sql=exec_sql)
+ sql=exec_sql,
+ )
execute_result.rows += [result]
except Exception as e:
- logger.warning(f"mongo语句执行报错,语句:{exec_sql},错误信息{traceback.format_exc()}")
+ logger.warning(
+ f"mongo语句执行报错,语句:{exec_sql},错误信息{traceback.format_exc()}"
+ )
execute_result.error = str(e)
# result_set.column_list = [i[0] for i in fields] if fields else []
return execute_result
- def execute_check(self, db_name=None, sql=''):
+ def execute_check(self, db_name=None, sql=""):
"""上线单执行前的检查, 返回Review set"""
line = 1
count = 0
check_result = ReviewSet(full_sql=sql)
sql = sql.strip()
- if (sql.find(";") < 0):
+ if sql.find(";") < 0:
raise Exception("提交的语句请以分号结尾")
# 以;切分语句,逐句执行
sp_sql = sql.split(";")
# 执行语句
for check_sql in sp_sql:
- alert = '' # 警告信息
- if not check_sql == '' and check_sql != '\n':
+ alert = "" # 警告信息
+ if not check_sql == "" and check_sql != "\n":
check_sql = check_sql.strip()
# check_sql = f'''{check_sql}'''
# check_sql = check_sql.replace('\n', '') #处理成一行
# 支持的命令列表
- supportMethodList = ["explain", "bulkWrite", "convertToCapped", "createIndex", "createIndexes",
- "deleteOne",
- "deleteMany", "drop", "dropIndex", "dropIndexes", "ensureIndex", "insert",
- "insertOne",
- "insertMany", "remove", "replaceOne", "renameCollection", "update", "updateOne",
- "updateMany", "createCollection", "renameCollection"]
+ supportMethodList = [
+ "explain",
+ "bulkWrite",
+ "convertToCapped",
+ "createIndex",
+ "createIndexes",
+ "deleteOne",
+ "deleteMany",
+ "drop",
+ "dropIndex",
+ "dropIndexes",
+ "ensureIndex",
+ "insert",
+ "insertOne",
+ "insertMany",
+ "remove",
+ "replaceOne",
+ "renameCollection",
+ "update",
+ "updateOne",
+ "updateMany",
+ "createCollection",
+ "renameCollection",
+ ]
# 需要有表存在为前提的操作
- is_exist_premise_method = ["convertToCapped", "deleteOne", "deleteMany", "drop", "dropIndex",
- "dropIndexes",
- "remove", "replaceOne", "renameCollection", "update", "updateOne",
- "updateMany", "renameCollection"]
+ is_exist_premise_method = [
+ "convertToCapped",
+ "deleteOne",
+ "deleteMany",
+ "drop",
+ "dropIndex",
+ "dropIndexes",
+ "remove",
+ "replaceOne",
+ "renameCollection",
+ "update",
+ "updateOne",
+ "updateMany",
+ "renameCollection",
+ ]
pattern = re.compile(
- r'''^db\.createCollection\(([\s\S]*)\)$|^db\.([\w\.-]+)\.(?:[A-Za-z]+)(?:\([\s\S]*\)$)|^db\.getCollection\((?:\s*)(?:'|")([\w-]*)('|")(\s*)\)\.([A-Za-z]+)(\([\s\S]*\)$)''')
+ r"""^db\.createCollection\(([\s\S]*)\)$|^db\.([\w\.-]+)\.(?:[A-Za-z]+)(?:\([\s\S]*\)$)|^db\.getCollection\((?:\s*)(?:'|")([\w-]*)('|")(\s*)\)\.([A-Za-z]+)(\([\s\S]*\)$)"""
+ )
m = pattern.match(check_sql)
- if m is not None and (re.search(re.compile(r'}(?:\s*){'), check_sql) is None) and check_sql.count(
- '{') == check_sql.count('}') and check_sql.count('(') == check_sql.count(')'):
+ if (
+ m is not None
+ and (re.search(re.compile(r"}(?:\s*){"), check_sql) is None)
+ and check_sql.count("{") == check_sql.count("}")
+ and check_sql.count("(") == check_sql.count(")")
+ ):
sql_str = m.group()
- table_name = (m.group(1) or m.group(2) or m.group(3)).strip() # 通过正则的组拿到表名
- table_name = table_name.replace('"', '').replace("'", "")
+ table_name = (
+ m.group(1) or m.group(2) or m.group(3)
+ ).strip() # 通过正则的组拿到表名
+ table_name = table_name.replace('"', "").replace("'", "")
table_names = self.get_all_tables(db_name).rows
is_in = table_name in table_names # 检查表是否存在
if not is_in:
alert = f"\n提示:{table_name}文档不存在!"
if sql_str:
count = 0
- if sql_str.find('createCollection') > 0: # 如果是db.createCollection()
+ if (
+ sql_str.find("createCollection") > 0
+ ): # 如果是db.createCollection()
methodStr = "createCollection"
alert = ""
if is_in:
check_result.error = "文档已经存在"
- result = ReviewResult(id=line, errlevel=2,
- stagestatus='文档已经存在',
- errormessage='文档已经存在!',
- affected_rows=count,
- sql=check_sql)
+ result = ReviewResult(
+ id=line,
+ errlevel=2,
+ stagestatus="文档已经存在",
+ errormessage="文档已经存在!",
+ affected_rows=count,
+ sql=check_sql,
+ )
check_result.rows += [result]
continue
else:
# method = sql_str.split('.')[2]
# methodStr = method.split('(')[0].strip()
- methodStr = sql_str.split('(')[0].split('.')[-1].strip() # 最后一个.和括号(之间的字符串作为方法
+ methodStr = (
+ sql_str.split("(")[0].split(".")[-1].strip()
+ ) # 最后一个.和括号(之间的字符串作为方法
if methodStr in is_exist_premise_method and not is_in:
check_result.error = "文档不存在"
- result = ReviewResult(id=line, errlevel=2,
- stagestatus='文档不存在',
- errormessage=f'文档不存在,不能进行{methodStr}操作!',
- sql=check_sql)
+ result = ReviewResult(
+ id=line,
+ errlevel=2,
+ stagestatus="文档不存在",
+ errormessage=f"文档不存在,不能进行{methodStr}操作!",
+ sql=check_sql,
+ )
check_result.rows += [result]
continue
if methodStr in supportMethodList: # 检查方法是否支持
- if methodStr == "createIndex" or methodStr == "createIndexes" or methodStr == "ensureIndex": # 判断是否创建索引,如果大于500万,提醒不能在高峰期创建
+ if (
+ methodStr == "createIndex"
+ or methodStr == "createIndexes"
+ or methodStr == "ensureIndex"
+ ): # 判断是否创建索引,如果大于500万,提醒不能在高峰期创建
p_back = re.compile(
- r'''(['"])(?:(?!\1)background)\1(?:\s*):(?:\s*)true|background\s*:\s*true|(['"])(?:(?!\1)background)\1(?:\s*):(?:\s*)(['"])(?:(?!\2)true)\2''',
- re.M)
+ r"""(['"])(?:(?!\1)background)\1(?:\s*):(?:\s*)true|background\s*:\s*true|(['"])(?:(?!\1)background)\1(?:\s*):(?:\s*)(['"])(?:(?!\2)true)\2""",
+ re.M,
+ )
m_back = re.search(p_back, check_sql)
if m_back is None:
count = 5555555
- check_result.warning = '创建索引请加background:true'
+ check_result.warning = "创建索引请加background:true"
check_result.warning_count += 1
- result = ReviewResult(id=line, errlevel=2,
- stagestatus='后台创建索引',
- errormessage='创建索引没有加 background:true' + alert,
- sql=check_sql)
+ result = ReviewResult(
+ id=line,
+ errlevel=2,
+ stagestatus="后台创建索引",
+ errormessage="创建索引没有加 background:true" + alert,
+ sql=check_sql,
+ )
elif not is_in:
count = 0
else:
- count = self.get_table_conut(table_name, db_name) # 获得表的总条数
+ count = self.get_table_conut(
+ table_name, db_name
+ ) # 获得表的总条数
if count >= 5000000:
- check_result.warning = alert + '大于500万条,请在业务低谷期创建索引'
+ check_result.warning = (
+ alert + "大于500万条,请在业务低谷期创建索引"
+ )
check_result.warning_count += 1
- result = ReviewResult(id=line, errlevel=1,
- stagestatus='大表创建索引',
- errormessage='大于500万条,请在业务低谷期创建索引!',
- affected_rows=count,
- sql=check_sql)
+ result = ReviewResult(
+ id=line,
+ errlevel=1,
+ stagestatus="大表创建索引",
+ errormessage="大于500万条,请在业务低谷期创建索引!",
+ affected_rows=count,
+ sql=check_sql,
+ )
if count < 5000000:
# 检测通过
- affected_all_row_method = ["drop", "dropIndex", "dropIndexes", "createIndex",
- "createIndexes", "ensureIndex"]
+ affected_all_row_method = [
+ "drop",
+ "dropIndex",
+ "dropIndexes",
+ "createIndex",
+ "createIndexes",
+ "ensureIndex",
+ ]
if methodStr not in affected_all_row_method:
count = 0
else:
- count = self.get_table_conut(table_name, db_name) # 获得表的总条数
- result = ReviewResult(id=line, errlevel=0,
- stagestatus='Audit completed',
- errormessage='检测通过',
- affected_rows=count,
- sql=check_sql,
- execute_time=0)
+ count = self.get_table_conut(
+ table_name, db_name
+ ) # 获得表的总条数
+ result = ReviewResult(
+ id=line,
+ errlevel=0,
+ stagestatus="Audit completed",
+ errormessage="检测通过",
+ affected_rows=count,
+ sql=check_sql,
+ execute_time=0,
+ )
else:
- result = ReviewResult(id=line, errlevel=2,
- stagestatus='驳回不支持语句',
- errormessage='仅支持DML和DDL语句,如需查询请使用数据库查询功能!',
- sql=check_sql)
+ result = ReviewResult(
+ id=line,
+ errlevel=2,
+ stagestatus="驳回不支持语句",
+ errormessage="仅支持DML和DDL语句,如需查询请使用数据库查询功能!",
+ sql=check_sql,
+ )
else:
check_result.error = "语法错误"
- result = ReviewResult(id=line, errlevel=2,
- stagestatus='语法错误',
- errormessage='请检查语句的正确性或(){} },{是否正确匹配!',
- sql=check_sql)
+ result = ReviewResult(
+ id=line,
+ errlevel=2,
+ stagestatus="语法错误",
+ errormessage="请检查语句的正确性或(){} },{是否正确匹配!",
+ sql=check_sql,
+ )
check_result.rows += [result]
line += 1
count = 0
- check_result.column_list = ['Result'] # 审核结果的列名
+ check_result.column_list = ["Result"] # 审核结果的列名
check_result.checked = True
check_result.warning = self.warning
# 统计警告和错误数量
@@ -537,10 +681,15 @@ def execute_check(self, db_name=None, sql=''):
return check_result
def get_connection(self, db_name=None):
- self.db_name = db_name or self.instance.db_name or 'admin'
- auth_db = self.instance.db_name or 'admin'
- self.conn = pymongo.MongoClient(self.host, self.port, authSource=auth_db, connect=True,
- connectTimeoutMS=10000)
+ self.db_name = db_name or self.instance.db_name or "admin"
+ auth_db = self.instance.db_name or "admin"
+ self.conn = pymongo.MongoClient(
+ self.host,
+ self.port,
+ authSource=auth_db,
+ connect=True,
+ connectTimeoutMS=10000,
+ )
if self.user and self.password:
self.conn[self.db_name].authenticate(self.user, self.password, auth_db)
return self.conn
@@ -552,11 +701,11 @@ def close(self):
@property
def name(self): # pragma: no cover
- return 'Mongo'
+ return "Mongo"
@property
def info(self): # pragma: no cover
- return 'Mongo engine'
+ return "Mongo engine"
def get_roles(self):
sql_get_roles = "db.system.roles.find({},{_id:1})"
@@ -605,14 +754,19 @@ def get_all_columns_by_tb(self, db_name, tb_name, **kwargs):
for prop in document:
if prop not in columns:
columns.append(prop)
- result.column_list = ['COLUMN_NAME']
+ result.column_list = ["COLUMN_NAME"]
result.rows = columns
return result
def describe_table(self, db_name, tb_name, **kwargs):
"""return ResultSet 类似查询"""
result = self.get_all_columns_by_tb(db_name=db_name, tb_name=tb_name)
- result.rows = [[[r], ] for r in result.rows]
+ result.rows = [
+ [
+ [r],
+ ]
+ for r in result.rows
+ ]
return result
@staticmethod
@@ -625,7 +779,7 @@ def dispose_str(parse_sql, start_flag, index):
return index
index += 1
stop_flag = start_flag
- raise Exception('near column %s,\' or \" has no close' % index)
+ raise Exception("near column %s,' or \" has no close" % index)
def dispose_pair(self, parse_sql, index, begin, end):
"""解析处理需要配对的字符{}[]() 检索一个左括号计数器加1,右括号计数器减1"""
@@ -648,7 +802,9 @@ def dispose_pair(self, parse_sql, index, begin, end):
index = self.dispose_str(parse_sql, char, index)
index += 1
if count > 0:
- raise Exception("near column %s, The symbol %s has no closed" % (index, begin))
+ raise Exception(
+ "near column %s, The symbol %s has no closed" % (index, begin)
+ )
re_char = parse_sql[start_pos:stop_pos] # 截取
return index, re_char
@@ -686,7 +842,9 @@ def parse_query_sentence(self, parse_sql):
pipeline = []
agg_index = 0
while agg_index < len(re_char):
- p_index, condition = self.dispose_pair(re_char, agg_index, "{", "}")
+ p_index, condition = self.dispose_pair(
+ re_char, agg_index, "{", "}"
+ )
agg_index = p_index + 1
if condition:
de = JsonDecoder()
@@ -700,7 +858,7 @@ def parse_query_sentence(self, parse_sql):
query_dict["condition"] = pipeline
query_dict["method"] = method
elif method.lower() == "getcollection": # 获得表名
- collection = re_char.strip().replace("'", "").replace('"', '')
+ collection = re_char.strip().replace("'", "").replace('"', "")
query_dict["collection"] = collection
elif method.lower() == "getindexes":
query_dict["method"] = "index_information"
@@ -712,7 +870,7 @@ def parse_query_sentence(self, parse_sql):
if query_dict:
return query_dict
- def filter_sql(self, sql='', limit_num=0):
+ def filter_sql(self, sql="", limit_num=0):
"""给查询语句改写语句, 返回修改后的语句"""
sql = sql.split(";")[0].strip()
# 执行计划
@@ -720,37 +878,38 @@ def filter_sql(self, sql='', limit_num=0):
sql = sql.replace("explain", "") + ".explain()"
return sql.strip()
- def query_check(self, db_name=None, sql=''):
+ def query_check(self, db_name=None, sql=""):
"""提交查询前的检查"""
sql = sql.strip()
if sql.startswith("explain"):
sql = sql[7:] + ".explain()"
sql = re.sub("[;\s]*.explain\(\)$", ".explain()", sql).strip()
- result = {'msg': '', 'bad_query': False, 'filtered_sql': sql, 'has_star': False}
+ result = {"msg": "", "bad_query": False, "filtered_sql": sql, "has_star": False}
pattern = re.compile(
- r'''^db\.(\w+\.?)+(?:\([\s\S]*\)(\s*;*)$)|^db\.getCollection\((?:\s*)(?:'|")(\w+\.?)+('|")(\s*)\)\.([A-Za-z]+)(\([\s\S]*\)(\s*;*)$)''')
+ r"""^db\.(\w+\.?)+(?:\([\s\S]*\)(\s*;*)$)|^db\.getCollection\((?:\s*)(?:'|")(\w+\.?)+('|")(\s*)\)\.([A-Za-z]+)(\([\s\S]*\)(\s*;*)$)"""
+ )
m = pattern.match(sql)
if m is not None:
logger.debug(sql)
query_dict = self.parse_query_sentence(sql)
if "method" not in query_dict:
- result['msg'] += "错误:对不起,只支持查询相关方法"
- result['bad_query'] = True
+ result["msg"] += "错误:对不起,只支持查询相关方法"
+ result["bad_query"] = True
return result
collection_name = query_dict["collection"]
collection_names = self.get_all_tables(db_name).rows
is_in = collection_name in collection_names # 检查表是否存在
if not is_in:
- result['msg'] += f"\n错误: {collection_name} 文档不存在!"
- result['bad_query'] = True
+ result["msg"] += f"\n错误: {collection_name} 文档不存在!"
+ result["bad_query"] = True
return result
else:
- result['msg'] += '请检查语句的正确性! 请使用原生查询语句'
- result['bad_query'] = True
+ result["msg"] += "请检查语句的正确性! 请使用原生查询语句"
+ result["bad_query"] = True
return result
- def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs):
+ def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
"""执行查询"""
result_set = ResultSet(full_sql=sql)
@@ -781,7 +940,11 @@ def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs):
for k, v in de.decode(query_dict["sort"]).items():
sorting.append((k, v))
find_cmd += ".sort(sorting)"
- if method == "find" and "limit" not in query_dict and "explain" not in query_dict:
+ if (
+ method == "find"
+ and "limit" not in query_dict
+ and "explain" not in query_dict
+ ):
find_cmd += ".limit(limit_num)"
if "limit" in query_dict and query_dict["limit"]:
query_limit = int(query_dict["limit"])
@@ -820,7 +983,9 @@ def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs):
row = []
columns.insert(0, "mongodballdata")
for ro in cursor:
- json_col = json.dumps(ro, ensure_ascii=False, indent=2, separators=(",", ":"))
+ json_col = json.dumps(
+ ro, ensure_ascii=False, indent=2, separators=(",", ":")
+ )
row.insert(0, json_col)
for k, v in ro.items():
if k not in columns:
@@ -832,7 +997,7 @@ def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs):
result_set.rows = rows
else:
cursor = json.loads(json_util.dumps(cursor))
- cols = projection if 'projection' in dir() else None
+ cols = projection if "projection" in dir() else None
rows, columns = self.parse_tuple(cursor, db_name, collection_name, cols)
result_set.rows = rows
result_set.column_list = columns
@@ -840,7 +1005,9 @@ def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs):
if isinstance(rows, list):
logger.debug(rows)
result_set.rows = tuple(
- [json.dumps(x, ensure_ascii=False, indent=2, separators=(",", ":"))] for x in rows)
+ [json.dumps(x, ensure_ascii=False, indent=2, separators=(",", ":"))]
+ for x in rows
+ )
except Exception as e:
logger.warning(f"Mongo命令执行报错,语句:{sql}, 错误信息:{traceback.format_exc()}")
@@ -865,7 +1032,9 @@ def parse_tuple(self, cursor, db_name, tb_name, projection=None):
columns = self.fill_query_columns(cursor, columns)
for ro in cursor:
- json_col = json.dumps(ro, ensure_ascii=False, indent=2, separators=(",", ":"))
+ json_col = json.dumps(
+ ro, ensure_ascii=False, indent=2, separators=(",", ":")
+ )
row.insert(0, json_col)
for key in columns[1:]:
if key in ro:
@@ -877,13 +1046,17 @@ def parse_tuple(self, cursor, db_name, tb_name, projection=None):
# 转换$oid
ff = re.findall(re_oid, str(value))
for ii in ff:
- value = str(value).replace(ii, "ObjectId(" + ii.split(":")[1].strip()[:-1] + ")")
+ value = str(value).replace(
+ ii, "ObjectId(" + ii.split(":")[1].strip()[:-1] + ")"
+ )
# 转换时间戳$date
dd = re.findall(re_date, str(value))
for d in dd:
t = int(d.split(":")[1].strip()[:-1])
e = datetime.fromtimestamp(t / 1000)
- value = str(value).replace(d, e.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3])
+ value = str(value).replace(
+ d, e.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
+ )
row.append(str(value))
else:
row.append("(N/A)")
@@ -904,47 +1077,60 @@ def fill_query_columns(cursor, columns):
def current_op(self, command_type):
"""
获取当前连接信息
-
+
command_type:
- Full 包含活跃与不活跃的连接,包含内部的连接,即全部的连接状态
+ Full 包含活跃与不活跃的连接,包含内部的连接,即全部的连接状态
All 包含活跃与不活跃的连接,不包含内部的连接
Active 包含活跃
Inner 内部连接
"""
- result_set = ResultSet(full_sql='db.aggregate([{"$currentOp": {"allUsers":true, "idleConnections":true}}])')
+ result_set = ResultSet(
+ full_sql='db.aggregate([{"$currentOp": {"allUsers":true, "idleConnections":true}}])'
+ )
try:
conn = self.get_connection()
processlists = []
if not command_type:
- command_type = 'Active'
- if command_type in ['Full', 'All', 'Inner']:
+ command_type = "Active"
+ if command_type in ["Full", "All", "Inner"]:
idle_connections = True
else:
idle_connections = False
# conn.admin.current_op() 这个方法已经被pymongo废除,但mongodb3.6+才支持aggregate
with conn.admin.aggregate(
- [{'$currentOp': {'allUsers': True, 'idleConnections': idle_connections}}]) as cursor:
+ [
+ {
+ "$currentOp": {
+ "allUsers": True,
+ "idleConnections": idle_connections,
+ }
+ }
+ ]
+ ) as cursor:
for operation in cursor:
# 对sharding集群的特殊处理
- if 'client' not in operation and \
- operation.get('clientMetadata', {}).get('mongos', {}).get('client', {}):
- operation['client'] = operation['clientMetadata']['mongos']['client']
+ if "client" not in operation and operation.get(
+ "clientMetadata", {}
+ ).get("mongos", {}).get("client", {}):
+ operation["client"] = operation["clientMetadata"]["mongos"][
+ "client"
+ ]
# client_s 只是处理的mongos,并不是实际客户端
# client 在sharding获取不到?
- if command_type in ['Full']:
+ if command_type in ["Full"]:
processlists.append(operation)
- elif command_type in ['All', 'Active']:
- if 'clientMetadata' in operation:
+ elif command_type in ["All", "Active"]:
+ if "clientMetadata" in operation:
processlists.append(operation)
- elif command_type in ['Inner']:
- if not 'clientMetadata' in operation:
+ elif command_type in ["Inner"]:
+ if not "clientMetadata" in operation:
processlists.append(operation)
result_set.rows = processlists
except Exception as e:
- logger.warning(f'mongodb获取连接信息错误,错误信息{traceback.format_exc()}')
+ logger.warning(f"mongodb获取连接信息错误,错误信息{traceback.format_exc()}")
result_set.error = str(e)
return result_set
@@ -953,15 +1139,17 @@ def get_kill_command(self, opids):
"""由传入的opid列表生成kill字符串"""
conn = self.get_connection()
active_opid = []
- with conn.admin.aggregate([{'$currentOp': {'allUsers': True, 'idleConnections': False}}]) as cursor:
+ with conn.admin.aggregate(
+ [{"$currentOp": {"allUsers": True, "idleConnections": False}}]
+ ) as cursor:
for operation in cursor:
- if 'opid' in operation and operation['opid'] in opids:
- active_opid.append(operation['opid'])
+ if "opid" in operation and operation["opid"] in opids:
+ active_opid.append(operation["opid"])
- kill_command = ''
+ kill_command = ""
for opid in active_opid:
if isinstance(opid, int):
- kill_command = kill_command + 'db.killOp({});'.format(opid)
+ kill_command = kill_command + "db.killOp({});".format(opid)
else:
kill_command = kill_command + 'db.killOp("{}");'.format(opid)
@@ -974,11 +1162,13 @@ def kill_op(self, opids):
conn = self.get_connection()
db = conn.admin
for opid in opids:
- conn.admin.command({'killOp': 1, 'op': opid})
+ conn.admin.command({"killOp": 1, "op": opid})
except Exception as e:
try:
- sql = {'killOp': 1, 'op': _opid}
+ sql = {"killOp": 1, "op": _opid}
except:
- sql = {'killOp': 1, 'op': ''}
- logger.warning(f"mongodb语句执行killOp报错,语句:db.runCommand({sql}) ,错误信息{traceback.format_exc()}")
+ sql = {"killOp": 1, "op": ""}
+ logger.warning(
+ f"mongodb语句执行killOp报错,语句:db.runCommand({sql}) ,错误信息{traceback.format_exc()}"
+ )
result.error = str(e)
diff --git a/sql/engines/mssql.py b/sql/engines/mssql.py
index af7af070ee..712ba07224 100644
--- a/sql/engines/mssql.py
+++ b/sql/engines/mssql.py
@@ -9,7 +9,7 @@
from .models import ResultSet, ReviewSet, ReviewResult
from sql.utils.data_masking import brute_mask
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
class MssqlEngine(EngineBase):
@@ -17,8 +17,13 @@ class MssqlEngine(EngineBase):
def get_connection(self, db_name=None):
connstr = """DRIVER=ODBC Driver 17 for SQL Server;SERVER={0},{1};UID={2};PWD={3};
-client charset = UTF-8;connect timeout=10;CHARSET={4};""".format(self.host, self.port, self.user, self.password,
- self.instance.charset or 'UTF8')
+client charset = UTF-8;connect timeout=10;CHARSET={4};""".format(
+ self.host,
+ self.port,
+ self.user,
+ self.password,
+ self.instance.charset or "UTF8",
+ )
if self.conn:
return self.conn
self.conn = pyodbc.connect(connstr)
@@ -28,8 +33,11 @@ def get_all_databases(self):
"""获取数据库列表, 返回一个ResultSet"""
sql = "SELECT name FROM master.sys.databases order by name"
result = self.query(sql=sql)
- db_list = [row[0] for row in result.rows
- if row[0] not in ('master', 'msdb', 'tempdb', 'model')]
+ db_list = [
+ row[0]
+ for row in result.rows
+ if row[0] not in ("master", "msdb", "tempdb", "model")
+ ]
result.rows = db_list
return result
@@ -37,9 +45,11 @@ def get_all_tables(self, db_name, **kwargs):
"""获取table 列表, 返回一个ResultSet"""
sql = """SELECT TABLE_NAME
FROM {0}.INFORMATION_SCHEMA.TABLES
- WHERE TABLE_TYPE = 'BASE TABLE' order by TABLE_NAME;""".format(db_name)
+ WHERE TABLE_TYPE = 'BASE TABLE' order by TABLE_NAME;""".format(
+ db_name
+ )
result = self.query(db_name=db_name, sql=sql)
- tb_list = [row[0] for row in result.rows if row[0] not in ['test']]
+ tb_list = [row[0] for row in result.rows if row[0] not in ["test"]]
result.rows = tb_list
return result
@@ -120,7 +130,7 @@ def get_table_meta_data(self, db_name, tb_name, **kwargs):
ON index_size.table_name = space.table_name;
"""
_meta_data = self.query(db_name, sql)
- return {'column_list': _meta_data.column_list, 'rows': _meta_data.rows[0]}
+ return {"column_list": _meta_data.column_list, "rows": _meta_data.rows[0]}
def get_table_desc_data(self, db_name, tb_name, **kwargs):
"""获取表格字段信息"""
@@ -132,7 +142,7 @@ def get_table_desc_data(self, db_name, tb_name, **kwargs):
COLUMN_DEFAULT 默认值
from INFORMATION_SCHEMA.columns where TABLE_CATALOG='{db_name}' and TABLE_NAME = '{tb_name}';"""
_desc_data = self.query(db_name, sql)
- return {'column_list': _desc_data.column_list, 'rows': _desc_data.rows}
+ return {"column_list": _desc_data.column_list, "rows": _desc_data.rows}
def get_table_index_data(self, db_name, tb_name, **kwargs):
"""获取表格索引信息"""
@@ -145,7 +155,7 @@ def get_table_index_data(self, db_name, tb_name, **kwargs):
WHERE i.object_id = OBJECT_ID('{tb_name}')
group by i.name,i.object_id,i.index_id,is_unique,is_primary_key;"""
_index_data = self.query(db_name, sql)
- return {'column_list': _index_data.column_list, 'rows': _index_data.rows}
+ return {"column_list": _index_data.column_list, "rows": _index_data.rows}
def get_tables_metas_data(self, db_name, **kwargs):
"""获取数据库所有表格信息,用作数据字典导出接口"""
@@ -165,11 +175,15 @@ def get_tables_metas_data(self, db_name, **kwargs):
table_metas = []
for tb in tbs:
_meta = dict()
- engine_keys = [{"key": "COLUMN_NAME", "value": "字段名"}, {"key": "COLUMN_TYPE", "value": "数据类型"},
- {"key": "COLLATION_NAME", "value": "列字符集"}, {"key": "IS_NULLABLE", "value": "允许非空"},
- {"key": "COLUMN_DEFAULT", "value": "默认值"}]
+ engine_keys = [
+ {"key": "COLUMN_NAME", "value": "字段名"},
+ {"key": "COLUMN_TYPE", "value": "数据类型"},
+ {"key": "COLLATION_NAME", "value": "列字符集"},
+ {"key": "IS_NULLABLE", "value": "允许非空"},
+ {"key": "COLUMN_DEFAULT", "value": "默认值"},
+ ]
_meta["ENGINE_KEYS"] = engine_keys
- _meta['TABLE_INFO'] = tb
+ _meta["TABLE_INFO"] = tb
sql_cols = f"""select COLUMN_NAME, case when ISNUMERIC(CHARACTER_MAXIMUM_LENGTH)=1
then DATA_TYPE + '(' + convert(varchar(max), CHARACTER_MAXIMUM_LENGTH) + ')' else DATA_TYPE end COLUMN_TYPE,
COLLATION_NAME,
@@ -182,7 +196,7 @@ def get_tables_metas_data(self, db_name, **kwargs):
# 转换查询结果为dict
for row in query_result.rows:
columns.append(dict(zip(query_result.column_list, row)))
- _meta['COLUMNS'] = tuple(columns)
+ _meta["COLUMNS"] = tuple(columns)
table_metas.append(_meta)
return table_metas
@@ -211,64 +225,89 @@ def describe_table(self, db_name, tb_name, **kwargs):
left join {0}..sysindexkeys i on i.id=o.id and i.colid=c.colid and i.indid=ie.indid
WHERE O.name NOT LIKE 'MS%' AND O.name NOT LIKE 'SY%'
and O.name='{1}'
- order by o.name,c.colid""".format(db_name, tb_name)
+ order by o.name,c.colid""".format(
+ db_name, tb_name
+ )
result = self.query(sql=sql)
return result
- def query_check(self, db_name=None, sql=''):
+ def query_check(self, db_name=None, sql=""):
# 查询语句的检查、注释去除、切分
- result = {'msg': '', 'bad_query': False, 'filtered_sql': sql, 'has_star': False}
- banned_keywords = ["ascii", "char", "charindex", "concat", "concat_ws", "difference", "format",
- "len", "nchar", "patindex", "quotename", "replace", "replicate",
- "reverse", "right", "soundex", "space", "str", "string_agg",
- "string_escape", "string_split", "stuff", "substring", "trim", "unicode"]
- keyword_warning = ''
+ result = {"msg": "", "bad_query": False, "filtered_sql": sql, "has_star": False}
+ banned_keywords = [
+ "ascii",
+ "char",
+ "charindex",
+ "concat",
+ "concat_ws",
+ "difference",
+ "format",
+ "len",
+ "nchar",
+ "patindex",
+ "quotename",
+ "replace",
+ "replicate",
+ "reverse",
+ "right",
+ "soundex",
+ "space",
+ "str",
+ "string_agg",
+ "string_escape",
+ "string_split",
+ "stuff",
+ "substring",
+ "trim",
+ "unicode",
+ ]
+ keyword_warning = ""
star_patter = r"(^|,|\s)\*(\s|\(|$)"
- sql_whitelist = ['select', 'sp_helptext']
+ sql_whitelist = ["select", "sp_helptext"]
# 根据白名单list拼接pattern语句
whitelist_pattern = "^" + "|^".join(sql_whitelist)
# 删除注释语句,进行语法判断,执行第一条有效sql
try:
sql = sql.format(sql, strip_comments=True)
sql = sqlparse.split(sql)[0]
- result['filtered_sql'] = sql.strip()
+ result["filtered_sql"] = sql.strip()
sql_lower = sql.lower()
except IndexError:
- result['bad_query'] = True
- result['msg'] = '没有有效的SQL语句'
+ result["bad_query"] = True
+ result["msg"] = "没有有效的SQL语句"
return result
if re.match(whitelist_pattern, sql_lower) is None:
- result['bad_query'] = True
- result['msg'] = '仅支持{}语法!'.format(','.join(sql_whitelist))
+ result["bad_query"] = True
+ result["msg"] = "仅支持{}语法!".format(",".join(sql_whitelist))
return result
if re.search(star_patter, sql_lower) is not None:
- keyword_warning += '禁止使用 * 关键词\n'
- result['has_star'] = True
+ keyword_warning += "禁止使用 * 关键词\n"
+ result["has_star"] = True
for keyword in banned_keywords:
pattern = r"(^|,| |=){}( |\(|$)".format(keyword)
if re.search(pattern, sql_lower) is not None:
- keyword_warning += '禁止使用 {} 关键词\n'.format(keyword)
- result['bad_query'] = True
- if result.get('bad_query') or result.get('has_star'):
- result['msg'] = keyword_warning
+ keyword_warning += "禁止使用 {} 关键词\n".format(keyword)
+ result["bad_query"] = True
+ if result.get("bad_query") or result.get("has_star"):
+ result["msg"] = keyword_warning
return result
- def filter_sql(self, sql='', limit_num=0):
+ def filter_sql(self, sql="", limit_num=0):
sql_lower = sql.lower()
# 对查询sql增加limit限制
if re.match(r"^select", sql_lower):
- if sql_lower.find(' top ') == -1:
- return sql_lower.replace('select', 'select top {}'.format(limit_num))
+ if sql_lower.find(" top ") == -1:
+ return sql_lower.replace("select", "select top {}".format(limit_num))
return sql.strip()
- def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs):
- """返回 ResultSet """
+ def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
+ """返回 ResultSet"""
result_set = ResultSet(full_sql=sql)
try:
conn = self.get_connection()
cursor = conn.cursor()
if db_name:
- cursor.execute('use [{}];'.format(db_name))
+ cursor.execute("use [{}];".format(db_name))
cursor.execute(sql)
if int(limit_num) > 0:
rows = cursor.fetchmany(int(limit_num))
@@ -287,7 +326,7 @@ def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs):
self.close()
return result_set
- def query_masking(self, db_name=None, sql='', resultset=None):
+ def query_masking(self, db_name=None, sql="", resultset=None):
"""传入 sql语句, db名, 结果集,
返回一个脱敏后的结果集"""
# 仅对select语句脱敏
@@ -298,11 +337,11 @@ def query_masking(self, db_name=None, sql='', resultset=None):
filtered_result = resultset
return filtered_result
- def execute_check(self, db_name=None, sql=''):
+ def execute_check(self, db_name=None, sql=""):
"""上线单执行前的检查, 返回Review set"""
check_result = ReviewSet(full_sql=sql)
# 切分语句,追加到检测结果中,默认全部检测通过
- split_reg = re.compile('^GO$', re.I | re.M)
+ split_reg = re.compile("^GO$", re.I | re.M)
sql = re.split(split_reg, sql, 0)
sql = filter(None, sql)
split_sql = [f"""use [{db_name}]"""]
@@ -310,14 +349,17 @@ def execute_check(self, db_name=None, sql=''):
split_sql = split_sql + [i]
rowid = 1
for statement in split_sql:
- check_result.rows.append(ReviewResult(
- id=rowid,
- errlevel=0,
- stagestatus='Audit completed',
- errormessage='None',
- sql=statement,
- affected_rows=0,
- execute_time=0, ))
+ check_result.rows.append(
+ ReviewResult(
+ id=rowid,
+ errlevel=0,
+ stagestatus="Audit completed",
+ errormessage="None",
+ sql=statement,
+ affected_rows=0,
+ execute_time=0,
+ )
+ )
rowid += 1
return check_result
@@ -325,14 +367,16 @@ def execute_workflow(self, workflow):
if workflow.is_backup:
# TODO mssql 备份未实现
pass
- return self.execute(db_name=workflow.db_name, sql=workflow.sqlworkflowcontent.sql_content)
+ return self.execute(
+ db_name=workflow.db_name, sql=workflow.sqlworkflowcontent.sql_content
+ )
- def execute(self, db_name=None, sql='', close_conn=True):
+ def execute(self, db_name=None, sql="", close_conn=True):
"""执行sql语句 返回 Review set"""
execute_result = ReviewSet(full_sql=sql)
conn = self.get_connection(db_name=db_name)
cursor = conn.cursor()
- split_reg = re.compile('^GO$', re.I | re.M)
+ split_reg = re.compile("^GO$", re.I | re.M)
sql = re.split(split_reg, sql, 0)
sql = filter(None, sql)
split_sql = [f"""use [{db_name}]"""]
@@ -345,44 +389,50 @@ def execute(self, db_name=None, sql='', close_conn=True):
except Exception as e:
logger.warning(f"Mssql命令执行报错,语句:{sql}, 错误信息:{traceback.format_exc()}")
execute_result.error = str(e)
- execute_result.rows.append(ReviewResult(
- id=rowid,
- errlevel=2,
- stagestatus='Execute Failed',
- errormessage=f'异常信息:{e}',
- sql=statement,
- affected_rows=0,
- execute_time=0,
- ))
+ execute_result.rows.append(
+ ReviewResult(
+ id=rowid,
+ errlevel=2,
+ stagestatus="Execute Failed",
+ errormessage=f"异常信息:{e}",
+ sql=statement,
+ affected_rows=0,
+ execute_time=0,
+ )
+ )
break
else:
- execute_result.rows.append(ReviewResult(
- id=rowid,
- errlevel=0,
- stagestatus='Execute Successfully',
- errormessage='None',
- sql=statement,
- affected_rows=cursor.rowcount,
- execute_time=0,
- ))
+ execute_result.rows.append(
+ ReviewResult(
+ id=rowid,
+ errlevel=0,
+ stagestatus="Execute Successfully",
+ errormessage="None",
+ sql=statement,
+ affected_rows=cursor.rowcount,
+ execute_time=0,
+ )
+ )
rowid += 1
if execute_result.error:
# 如果失败, 将剩下的部分加入结果集, 并将语句回滚
for statement in split_sql[rowid:]:
- execute_result.rows.append(ReviewResult(
- id=rowid,
- errlevel=2,
- stagestatus='Execute Failed',
- errormessage=f'前序语句失败, 未执行',
- sql=statement,
- affected_rows=0,
- execute_time=0,
- ))
+ execute_result.rows.append(
+ ReviewResult(
+ id=rowid,
+ errlevel=2,
+ stagestatus="Execute Failed",
+ errormessage=f"前序语句失败, 未执行",
+ sql=statement,
+ affected_rows=0,
+ execute_time=0,
+ )
+ )
rowid += 1
cursor.rollback()
for row in execute_result.rows:
- if row.stagestatus == 'Execute Successfully':
- row.stagestatus += '\nRollback Successfully'
+ if row.stagestatus == "Execute Successfully":
+ row.stagestatus += "\nRollback Successfully"
else:
cursor.commit()
if close_conn:
diff --git a/sql/engines/mysql.py b/sql/engines/mysql.py
index 00fa4de15c..6845e78798 100644
--- a/sql/engines/mysql.py
+++ b/sql/engines/mysql.py
@@ -16,7 +16,7 @@
from sql.utils.data_masking import data_masking
from common.config import SysConfig
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
class MysqlEngine(EngineBase):
@@ -30,30 +30,41 @@ def __init__(self, instance=None):
def get_connection(self, db_name=None):
# https://stackoverflow.com/questions/19256155/python-mysqldb-returning-x01-for-bit-values
conversions = MySQLdb.converters.conversions
- conversions[FIELD_TYPE.BIT] = lambda data: data == b'\x01'
+ conversions[FIELD_TYPE.BIT] = lambda data: data == b"\x01"
if self.conn:
self.thread_id = self.conn.thread_id()
return self.conn
if db_name:
- self.conn = MySQLdb.connect(host=self.host, port=self.port, user=self.user, passwd=self.password,
- db=db_name, charset=self.instance.charset or 'utf8mb4',
- conv=conversions,
- connect_timeout=10)
+ self.conn = MySQLdb.connect(
+ host=self.host,
+ port=self.port,
+ user=self.user,
+ passwd=self.password,
+ db=db_name,
+ charset=self.instance.charset or "utf8mb4",
+ conv=conversions,
+ connect_timeout=10,
+ )
else:
- self.conn = MySQLdb.connect(host=self.host, port=self.port, user=self.user, passwd=self.password,
- charset=self.instance.charset or 'utf8mb4',
- conv=conversions,
- connect_timeout=10)
+ self.conn = MySQLdb.connect(
+ host=self.host,
+ port=self.port,
+ user=self.user,
+ passwd=self.password,
+ charset=self.instance.charset or "utf8mb4",
+ conv=conversions,
+ connect_timeout=10,
+ )
self.thread_id = self.conn.thread_id()
return self.conn
@property
def name(self):
- return 'MySQL'
+ return "MySQL"
@property
def info(self):
- return 'MySQL engine'
+ return "MySQL engine"
@property
def auto_backup(self):
@@ -62,14 +73,21 @@ def auto_backup(self):
@property
def seconds_behind_master(self):
- slave_status = self.query(sql='show slave status', close_conn=False, cursorclass=MySQLdb.cursors.DictCursor)
- return slave_status.rows[0].get('Seconds_Behind_Master') if slave_status.rows else None
+ slave_status = self.query(
+ sql="show slave status",
+ close_conn=False,
+ cursorclass=MySQLdb.cursors.DictCursor,
+ )
+ return (
+ slave_status.rows[0].get("Seconds_Behind_Master")
+ if slave_status.rows
+ else None
+ )
@property
def server_version(self):
def numeric_part(s):
- """Returns the leading numeric part of a string.
- """
+ """Returns the leading numeric part of a string."""
re_numeric_part = re.compile(r"^(\d+)")
m = re_numeric_part.match(s)
if m:
@@ -78,27 +96,32 @@ def numeric_part(s):
self.get_connection()
version = self.conn.get_server_info()
- return tuple([numeric_part(n) for n in version.split('.')[:3]])
+ return tuple([numeric_part(n) for n in version.split(".")[:3]])
@property
def schema_object(self):
"""获取实例对象信息"""
- url = build_database_url(host=self.host,
- username=self.user,
- password=self.password,
- port=self.port)
- return schemaobject.SchemaObject(url, charset=self.instance.charset or 'utf8mb4')
+ url = build_database_url(
+ host=self.host, username=self.user, password=self.password, port=self.port
+ )
+ return schemaobject.SchemaObject(
+ url, charset=self.instance.charset or "utf8mb4"
+ )
def kill_connection(self, thread_id):
"""终止数据库连接"""
- self.query(sql=f'kill {thread_id}')
+ self.query(sql=f"kill {thread_id}")
def get_all_databases(self):
"""获取数据库列表, 返回一个ResultSet"""
sql = "show databases"
result = self.query(sql=sql)
- db_list = [row[0] for row in result.rows
- if row[0] not in ('information_schema', 'performance_schema', 'mysql', 'test', 'sys')]
+ db_list = [
+ row[0]
+ for row in result.rows
+ if row[0]
+ not in ("information_schema", "performance_schema", "mysql", "test", "sys")
+ ]
result.rows = db_list
return result
@@ -106,13 +129,13 @@ def get_all_tables(self, db_name, **kwargs):
"""获取table 列表, 返回一个ResultSet"""
sql = "show tables"
result = self.query(db_name=db_name, sql=sql)
- tb_list = [row[0] for row in result.rows if row[0] not in ['test']]
+ tb_list = [row[0] for row in result.rows if row[0] not in ["test"]]
result.rows = tb_list
return result
def get_group_tables_by_db(self, db_name):
# escape
- db_name = MySQLdb.escape_string(db_name).decode('utf-8')
+ db_name = MySQLdb.escape_string(db_name).decode("utf-8")
data = {}
sql = f"""SELECT TABLE_NAME,
TABLE_COMMENT
@@ -131,8 +154,8 @@ def get_group_tables_by_db(self, db_name):
def get_table_meta_data(self, db_name, tb_name, **kwargs):
"""数据字典页面使用:获取表格的元信息,返回一个dict{column_list: [], rows: []}"""
# escape
- db_name = MySQLdb.escape_string(db_name).decode('utf-8')
- tb_name = MySQLdb.escape_string(tb_name).decode('utf-8')
+ db_name = MySQLdb.escape_string(db_name).decode("utf-8")
+ tb_name = MySQLdb.escape_string(tb_name).decode("utf-8")
sql = f"""SELECT
TABLE_NAME as table_name,
ENGINE as engine,
@@ -156,7 +179,7 @@ def get_table_meta_data(self, db_name, tb_name, **kwargs):
TABLE_SCHEMA='{db_name}'
AND TABLE_NAME='{tb_name}'"""
_meta_data = self.query(db_name, sql)
- return {'column_list': _meta_data.column_list, 'rows': _meta_data.rows[0]}
+ return {"column_list": _meta_data.column_list, "rows": _meta_data.rows[0]}
def get_table_desc_data(self, db_name, tb_name, **kwargs):
"""获取表格字段信息"""
@@ -176,7 +199,7 @@ def get_table_desc_data(self, db_name, tb_name, **kwargs):
AND TABLE_NAME = '{tb_name}'
ORDER BY ORDINAL_POSITION;"""
_desc_data = self.query(db_name, sql)
- return {'column_list': _desc_data.column_list, 'rows': _desc_data.rows}
+ return {"column_list": _desc_data.column_list, "rows": _desc_data.rows}
def get_table_index_data(self, db_name, tb_name, **kwargs):
"""获取表格索引信息"""
@@ -195,24 +218,35 @@ def get_table_index_data(self, db_name, tb_name, **kwargs):
TABLE_SCHEMA = '{db_name}'
AND TABLE_NAME = '{tb_name}';"""
_index_data = self.query(db_name, sql)
- return {'column_list': _index_data.column_list, 'rows': _index_data.rows}
+ return {"column_list": _index_data.column_list, "rows": _index_data.rows}
def get_tables_metas_data(self, db_name, **kwargs):
"""获取数据库所有表格信息,用作数据字典导出接口"""
- sql_tbs = f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA='{db_name}';"
- tbs = self.query(sql=sql_tbs, cursorclass=MySQLdb.cursors.DictCursor, close_conn=False).rows
+ sql_tbs = (
+ f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA='{db_name}';"
+ )
+ tbs = self.query(
+ sql=sql_tbs, cursorclass=MySQLdb.cursors.DictCursor, close_conn=False
+ ).rows
table_metas = []
for tb in tbs:
_meta = dict()
- engine_keys = [{"key": "COLUMN_NAME", "value": "字段名"}, {"key": "COLUMN_TYPE", "value": "数据类型"},
- {"key": "COLUMN_DEFAULT", "value": "默认值"}, {"key": "IS_NULLABLE", "value": "允许非空"},
- {"key": "EXTRA", "value": "自动递增"}, {"key": "COLUMN_KEY", "value": "是否主键"},
- {"key": "COLUMN_COMMENT", "value": "备注"}]
+ engine_keys = [
+ {"key": "COLUMN_NAME", "value": "字段名"},
+ {"key": "COLUMN_TYPE", "value": "数据类型"},
+ {"key": "COLUMN_DEFAULT", "value": "默认值"},
+ {"key": "IS_NULLABLE", "value": "允许非空"},
+ {"key": "EXTRA", "value": "自动递增"},
+ {"key": "COLUMN_KEY", "value": "是否主键"},
+ {"key": "COLUMN_COMMENT", "value": "备注"},
+ ]
_meta["ENGINE_KEYS"] = engine_keys
- _meta['TABLE_INFO'] = tb
+ _meta["TABLE_INFO"] = tb
sql_cols = f"""SELECT * FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA='{tb['TABLE_SCHEMA']}' AND TABLE_NAME='{tb['TABLE_NAME']}';"""
- _meta['COLUMNS'] = self.query(sql=sql_cols, cursorclass=MySQLdb.cursors.DictCursor, close_conn=False).rows
+ _meta["COLUMNS"] = self.query(
+ sql=sql_cols, cursorclass=MySQLdb.cursors.DictCursor, close_conn=False
+ ).rows
table_metas.append(_meta)
return table_metas
@@ -243,11 +277,11 @@ def describe_table(self, db_name, tb_name, **kwargs):
result = self.query(db_name=db_name, sql=sql)
return result
- def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs):
- """返回 ResultSet """
+ def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
+ """返回 ResultSet"""
result_set = ResultSet(full_sql=sql)
- max_execution_time = kwargs.get('max_execution_time', 0)
- cursorclass = kwargs.get('cursorclass') or MySQLdb.cursors.Cursor
+ max_execution_time = kwargs.get("max_execution_time", 0)
+ cursorclass = kwargs.get("cursorclass") or MySQLdb.cursors.Cursor
try:
conn = self.get_connection(db_name=db_name)
conn.autocommit(True)
@@ -274,68 +308,75 @@ def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs):
self.close()
return result_set
- def query_check(self, db_name=None, sql=''):
+ def query_check(self, db_name=None, sql=""):
# 查询语句的检查、注释去除、切分
- result = {'msg': '', 'bad_query': False, 'filtered_sql': sql, 'has_star': False}
+ result = {"msg": "", "bad_query": False, "filtered_sql": sql, "has_star": False}
# 删除注释语句,进行语法判断,执行第一条有效sql
try:
sql = sqlparse.format(sql, strip_comments=True)
sql = sqlparse.split(sql)[0]
- result['filtered_sql'] = sql.strip()
+ result["filtered_sql"] = sql.strip()
except IndexError:
- result['bad_query'] = True
- result['msg'] = '没有有效的SQL语句'
+ result["bad_query"] = True
+ result["msg"] = "没有有效的SQL语句"
if re.match(r"^select|^show|^explain", sql, re.I) is None:
- result['bad_query'] = True
- result['msg'] = '不支持的查询语法类型!'
- if '*' in sql:
- result['has_star'] = True
- result['msg'] = 'SQL语句中含有 * '
+ result["bad_query"] = True
+ result["msg"] = "不支持的查询语法类型!"
+ if "*" in sql:
+ result["has_star"] = True
+ result["msg"] = "SQL语句中含有 * "
# select语句先使用Explain判断语法是否正确
if re.match(r"^select", sql, re.I):
explain_result = self.query(db_name=db_name, sql=f"explain {sql}")
if explain_result.error:
- result['bad_query'] = True
- result['msg'] = explain_result.error
+ result["bad_query"] = True
+ result["msg"] = explain_result.error
# 不应该查看mysql.user表
- if re.match('.*(\\s)+(mysql|`mysql`)(\\s)*\\.(\\s)*(user|`user`)((\\s)*|;).*', sql.lower().replace('\n', '')) or \
- (db_name == "mysql" and re.match('.*(\\s)+(user|`user`)((\\s)*|;).*', sql.lower().replace('\n', ''))):
- result['bad_query'] = True
- result['msg'] = '您无权查看该表'
+ if re.match(
+ ".*(\\s)+(mysql|`mysql`)(\\s)*\\.(\\s)*(user|`user`)((\\s)*|;).*",
+ sql.lower().replace("\n", ""),
+ ) or (
+ db_name == "mysql"
+ and re.match(
+ ".*(\\s)+(user|`user`)((\\s)*|;).*", sql.lower().replace("\n", "")
+ )
+ ):
+ result["bad_query"] = True
+ result["msg"] = "您无权查看该表"
return result
- def filter_sql(self, sql='', limit_num=0):
+ def filter_sql(self, sql="", limit_num=0):
# 对查询sql增加limit限制,limit n 或 limit n,n 或 limit n offset n统一改写成limit n
- sql = sql.rstrip(';').strip()
+ sql = sql.rstrip(";").strip()
if re.match(r"^select", sql, re.I):
# LIMIT N
- limit_n = re.compile(r'limit\s+(\d+)\s*$', re.I)
+ limit_n = re.compile(r"limit\s+(\d+)\s*$", re.I)
# LIMIT M OFFSET N
- limit_offset = re.compile(r'limit\s+(\d+)\s+offset\s+(\d+)\s*$', re.I)
+ limit_offset = re.compile(r"limit\s+(\d+)\s+offset\s+(\d+)\s*$", re.I)
# LIMIT M,N
- offset_comma_limit = re.compile(r'limit\s+(\d+)\s*,\s*(\d+)\s*$', re.I)
+ offset_comma_limit = re.compile(r"limit\s+(\d+)\s*,\s*(\d+)\s*$", re.I)
if limit_n.search(sql):
sql_limit = limit_n.search(sql).group(1)
limit_num = min(int(limit_num), int(sql_limit))
- sql = limit_n.sub(f'limit {limit_num};', sql)
+ sql = limit_n.sub(f"limit {limit_num};", sql)
elif limit_offset.search(sql):
sql_limit = limit_offset.search(sql).group(1)
sql_offset = limit_offset.search(sql).group(2)
limit_num = min(int(limit_num), int(sql_limit))
- sql = limit_offset.sub(f'limit {limit_num} offset {sql_offset};', sql)
+ sql = limit_offset.sub(f"limit {limit_num} offset {sql_offset};", sql)
elif offset_comma_limit.search(sql):
sql_offset = offset_comma_limit.search(sql).group(1)
sql_limit = offset_comma_limit.search(sql).group(2)
limit_num = min(int(limit_num), int(sql_limit))
- sql = offset_comma_limit.sub(f'limit {sql_offset},{limit_num};', sql)
+ sql = offset_comma_limit.sub(f"limit {sql_offset},{limit_num};", sql)
else:
- sql = f'{sql} limit {limit_num};'
+ sql = f"{sql} limit {limit_num};"
else:
- sql = f'{sql};'
+ sql = f"{sql};"
return sql
- def query_masking(self, db_name=None, sql='', resultset=None):
+ def query_masking(self, db_name=None, sql="", resultset=None):
"""传入 sql语句, db名, 结果集,
返回一个脱敏后的结果集"""
# 仅对select语句脱敏
@@ -345,53 +386,65 @@ def query_masking(self, db_name=None, sql='', resultset=None):
mask_result = resultset
return mask_result
- def execute_check(self, db_name=None, sql=''):
+ def execute_check(self, db_name=None, sql=""):
"""上线单执行前的检查, 返回Review set"""
# 进行Inception检查,获取检测结果
try:
- check_result = self.inc_engine.execute_check(instance=self.instance, db_name=db_name, sql=sql)
+ check_result = self.inc_engine.execute_check(
+ instance=self.instance, db_name=db_name, sql=sql
+ )
except Exception as e:
logger.debug(f"{self.inc_engine.name}检测语句报错:错误信息{traceback.format_exc()}")
- raise RuntimeError(f"{self.inc_engine.name}检测语句报错,请注意检查系统配置中{self.inc_engine.name}配置,错误信息:\n{e}")
+ raise RuntimeError(
+ f"{self.inc_engine.name}检测语句报错,请注意检查系统配置中{self.inc_engine.name}配置,错误信息:\n{e}"
+ )
# 判断Inception检测结果
if check_result.error:
logger.debug(f"{self.inc_engine.name}检测语句报错:错误信息{check_result.error}")
- raise RuntimeError(f"{self.inc_engine.name}检测语句报错,错误信息:\n{check_result.error}")
+ raise RuntimeError(
+ f"{self.inc_engine.name}检测语句报错,错误信息:\n{check_result.error}"
+ )
# 禁用/高危语句检查
- critical_ddl_regex = self.config.get('critical_ddl_regex', '')
+ critical_ddl_regex = self.config.get("critical_ddl_regex", "")
p = re.compile(critical_ddl_regex)
for row in check_result.rows:
statement = row.sql
# 去除注释
- statement = remove_comments(statement, db_type='mysql')
+ statement = remove_comments(statement, db_type="mysql")
# 禁用语句
if re.match(r"^select", statement.lower()):
check_result.error_count += 1
- row.stagestatus = '驳回不支持语句'
+ row.stagestatus = "驳回不支持语句"
row.errlevel = 2
- row.errormessage = '仅支持DML和DDL语句,查询语句请使用SQL查询功能!'
+ row.errormessage = "仅支持DML和DDL语句,查询语句请使用SQL查询功能!"
# 高危语句
elif critical_ddl_regex and p.match(statement.strip().lower()):
check_result.error_count += 1
- row.stagestatus = '驳回高危SQL'
+ row.stagestatus = "驳回高危SQL"
row.errlevel = 2
- row.errormessage = '禁止提交匹配' + critical_ddl_regex + '条件的语句!'
+ row.errormessage = "禁止提交匹配" + critical_ddl_regex + "条件的语句!"
return check_result
def execute_workflow(self, workflow):
"""执行上线单,返回Review set"""
# 判断实例是否只读
- read_only = self.query(sql='SELECT @@global.read_only;').rows[0][0]
- if read_only in (1, 'ON'):
+ read_only = self.query(sql="SELECT @@global.read_only;").rows[0][0]
+ if read_only in (1, "ON"):
result = ReviewSet(
full_sql=workflow.sqlworkflowcontent.sql_content,
- rows=[ReviewResult(id=1, errlevel=2,
- stagestatus='Execute Failed',
- errormessage='实例read_only=1,禁止执行变更语句!',
- sql=workflow.sqlworkflowcontent.sql_content)])
- result.error = '实例read_only=1,禁止执行变更语句!',
+ rows=[
+ ReviewResult(
+ id=1,
+ errlevel=2,
+ stagestatus="Execute Failed",
+ errormessage="实例read_only=1,禁止执行变更语句!",
+ sql=workflow.sqlworkflowcontent.sql_content,
+ )
+ ],
+ )
+ result.error = ("实例read_only=1,禁止执行变更语句!",)
return result
# TODO 原生执行
# if workflow.is_manual == 1:
@@ -399,7 +452,7 @@ def execute_workflow(self, workflow):
# inception执行
return self.inc_engine.execute(workflow)
- def execute(self, db_name=None, sql='', close_conn=True):
+ def execute(self, db_name=None, sql="", close_conn=True):
"""原生执行语句"""
result = ResultSet(full_sql=sql)
conn = self.get_connection(db_name=db_name)
@@ -424,8 +477,16 @@ def get_rollback(self, workflow):
def get_variables(self, variables=None):
"""获取实例参数"""
if variables:
- variables = "','".join(variables) if isinstance(variables, list) else "','".join(list(variables))
- db = 'performance_schema' if self.server_version > (5, 7) else 'information_schema'
+ variables = (
+ "','".join(variables)
+ if isinstance(variables, list)
+ else "','".join(list(variables))
+ )
+ db = (
+ "performance_schema"
+ if self.server_version > (5, 7)
+ else "information_schema"
+ )
sql = f"""select * from {db}.global_variables where variable_name in ('{variables}');"""
else:
sql = "show global variables;"
@@ -438,7 +499,7 @@ def set_variable(self, variable_name, variable_value):
def osc_control(self, **kwargs):
"""控制osc执行,获取进度、终止、暂停、恢复等
- get、kill、pause、resume
+ get、kill、pause、resume
"""
return self.inc_engine.osc_control(**kwargs)
@@ -446,42 +507,44 @@ def processlist(self, command_type):
"""获取连接信息"""
base_sql = "select id, user, host, db, command, time, state, ifnull(info,'') as info from information_schema.processlist"
# escape
- command_type = MySQLdb.escape_string(command_type).decode('utf-8')
+ command_type = MySQLdb.escape_string(command_type).decode("utf-8")
if not command_type:
- command_type = 'Query'
- if command_type == 'All':
- sql = base_sql + ';'
- elif command_type == 'Not Sleep':
+ command_type = "Query"
+ if command_type == "All":
+ sql = base_sql + ";"
+ elif command_type == "Not Sleep":
sql = "{} where command<>'Sleep';".format(base_sql)
else:
sql = "{} where command= '{}';".format(base_sql, command_type)
-
- return self.query('information_schema', sql)
-
+
+ return self.query("information_schema", sql)
+
def get_kill_command(self, thread_ids):
"""由传入的线程列表生成kill命令"""
- sql = "select concat('kill ', id, ';') from information_schema.processlist where id in ({});"\
- .format(','.join(str(tid) for tid in thread_ids))
- all_kill_sql = self.query('information_schema', sql)
- kill_sql = ''
+ sql = "select concat('kill ', id, ';') from information_schema.processlist where id in ({});".format(
+ ",".join(str(tid) for tid in thread_ids)
+ )
+ all_kill_sql = self.query("information_schema", sql)
+ kill_sql = ""
for row in all_kill_sql.rows:
- kill_sql = kill_sql + row[0]
-
+ kill_sql = kill_sql + row[0]
+
return kill_sql
-
- def kill(self, thread_ids):
+
+ def kill(self, thread_ids):
"""kill线程"""
- sql = "select concat('kill ', id, ';') from information_schema.processlist where id in ({});"\
- .format(','.join(str(tid) for tid in thread_ids))
- all_kill_sql = self.query('information_schema', sql)
- kill_sql = ''
+ sql = "select concat('kill ', id, ';') from information_schema.processlist where id in ({});".format(
+ ",".join(str(tid) for tid in thread_ids)
+ )
+ all_kill_sql = self.query("information_schema", sql)
+ kill_sql = ""
for row in all_kill_sql.rows:
kill_sql = kill_sql + row[0]
- return self.execute('information_schema', kill_sql)
-
- def tablesapce(self, offset=0, row_count=14):
+ return self.execute("information_schema", kill_sql)
+
+ def tablesapce(self, offset=0, row_count=14):
"""获取表空间信息"""
- sql = '''
+ sql = """
SELECT
table_schema AS table_schema,
table_name AS table_name,
@@ -495,22 +558,24 @@ def tablesapce(self, offset=0, row_count=14):
FROM information_schema.tables
WHERE table_schema NOT IN ('information_schema', 'performance_schema', 'mysql', 'test', 'sys')
ORDER BY total_size DESC
- LIMIT {},{};'''.format(offset, row_count)
- return self.query('information_schema', sql)
-
- def tablesapce_num(self):
+ LIMIT {},{};""".format(
+ offset, row_count
+ )
+ return self.query("information_schema", sql)
+
+ def tablesapce_num(self):
"""获取表空间数量"""
- sql = '''
+ sql = """
SELECT count(*)
FROM information_schema.tables
- WHERE table_schema NOT IN ('information_schema', 'performance_schema', 'mysql', 'test', 'sys')'''
- return self.query('information_schema', sql)
-
+ WHERE table_schema NOT IN ('information_schema', 'performance_schema', 'mysql', 'test', 'sys')"""
+ return self.query("information_schema", sql)
+
def trxandlocks(self):
"""获取锁等待信息"""
server_version = self.server_version
if server_version < (8, 0, 1):
- sql = '''
+ sql = """
SELECT
rtrx.`trx_state` AS "等待的状态",
rtrx.`trx_started` AS "等待事务开始时间",
@@ -536,10 +601,10 @@ def trxandlocks(self):
WHERE rl.`lock_id` = lw.`requested_lock_id`
AND l.`lock_id` = lw.`blocking_lock_id`
AND lw.requesting_trx_id = rtrx.trx_id
- AND lw.blocking_trx_id = trx.trx_id;'''
-
+ AND lw.blocking_trx_id = trx.trx_id;"""
+
else:
- sql = '''
+ sql = """
SELECT
rtrx.`trx_state` AS "等待的状态",
rtrx.`trx_started` AS "等待事务开始时间",
@@ -565,13 +630,13 @@ def trxandlocks(self):
WHERE rl.`ENGINE_LOCK_ID` = lw.`REQUESTING_ENGINE_LOCK_ID`
AND l.`ENGINE_LOCK_ID` = lw.`BLOCKING_ENGINE_LOCK_ID`
AND lw.REQUESTING_ENGINE_TRANSACTION_ID = rtrx.trx_id
- AND lw.BLOCKING_ENGINE_TRANSACTION_ID = trx.trx_id;'''
-
- return self.query('information_schema', sql)
-
+ AND lw.BLOCKING_ENGINE_TRANSACTION_ID = trx.trx_id;"""
+
+ return self.query("information_schema", sql)
+
def get_long_transaction(self, thread_time=3):
"""获取长事务"""
- sql = '''select trx.trx_started,
+ sql = """select trx.trx_started,
trx.trx_state,
trx.trx_operation_state,
trx.trx_mysql_thread_id,
@@ -598,10 +663,12 @@ def get_long_transaction(self, thread_time=3):
WHERE trx.trx_state = 'RUNNING'
AND p.COMMAND = 'Sleep'
AND p.time > {}
- ORDER BY trx.trx_started ASC;'''.format(thread_time)
-
- return self.query('information_schema', sql)
-
+ ORDER BY trx.trx_started ASC;""".format(
+ thread_time
+ )
+
+ return self.query("information_schema", sql)
+
def close(self):
if self.conn:
self.conn.close()
diff --git a/sql/engines/odps.py b/sql/engines/odps.py
index 79b0bb23de..b4e986b5e1 100644
--- a/sql/engines/odps.py
+++ b/sql/engines/odps.py
@@ -10,7 +10,7 @@
from odps import ODPS
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
class ODPSEngine(EngineBase):
@@ -31,16 +31,16 @@ def get_connection(self, db_name=None):
@property
def name(self):
- return 'ODPS'
+ return "ODPS"
@property
def info(self):
- return 'ODPS engine'
+ return "ODPS engine"
def get_all_databases(self):
"""获取数据库列表, 返回一个ResultSet
- ODPS只有project概念, 直接返回project名称
- TODO: 目前ODPS获取所有项目接口比较慢, 暂时支持返回一个project,后续再优化
+ ODPS只有project概念, 直接返回project名称
+ TODO: 目前ODPS获取所有项目接口比较慢, 暂时支持返回一个project,后续再优化
"""
result = ResultSet()
@@ -80,7 +80,7 @@ def get_all_tables(self, db_name, **kwargs):
def get_all_columns_by_tb(self, db_name, tb_name, **kwargs):
"""获取所有字段, 返回一个ResultSet"""
- column_list = ['COLUMN_NAME', 'COLUMN_TYPE', 'COLUMN_COMMENT']
+ column_list = ["COLUMN_NAME", "COLUMN_TYPE", "COLUMN_COMMENT"]
conn = self.get_connection(db_name)
@@ -105,27 +105,27 @@ def describe_table(self, db_name, tb_name, **kwargs):
return result
- def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs):
- """返回 ResultSet """
+ def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
+ """返回 ResultSet"""
result_set = ResultSet(full_sql=sql)
if not re.match(r"^select", sql, re.I):
result_set.error = str("仅支持ODPS查询语句")
# 存在limit,替换limit; 不存在,添加limit
- if re.search('limit', sql):
- sql = re.sub('limit.+(\d+)', 'limit ' + str(limit_num), sql)
+ if re.search("limit", sql):
+ sql = re.sub("limit.+(\d+)", "limit " + str(limit_num), sql)
else:
- if sql.strip()[-1] == ';':
+ if sql.strip()[-1] == ";":
sql = sql[:-1]
- sql = sql + ' limit ' + str(limit_num) + ';'
+ sql = sql + " limit " + str(limit_num) + ";"
try:
conn = self.get_connection(db_name)
effect_row = conn.execute_sql(sql)
reader = effect_row.open_reader()
rows = [row.values for row in reader]
- column_list = getattr(reader, '_schema').names
+ column_list = getattr(reader, "_schema").names
result_set.column_list = column_list
result_set.rows = rows
@@ -136,27 +136,27 @@ def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs):
result_set.error = str(e)
return result_set
- def query_check(self, db_name=None, sql=''):
+ def query_check(self, db_name=None, sql=""):
# 查询语句的检查、注释去除、切分
- result = {'msg': '', 'bad_query': False, 'filtered_sql': sql, 'has_star': False}
- keyword_warning = ''
- sql_whitelist = ['select']
+ result = {"msg": "", "bad_query": False, "filtered_sql": sql, "has_star": False}
+ keyword_warning = ""
+ sql_whitelist = ["select"]
# 根据白名单list拼接pattern语句
whitelist_pattern = re.compile("^" + "|^".join(sql_whitelist), re.IGNORECASE)
# 删除注释语句,进行语法判断,执行第一条有效sql
try:
sql = sqlparse.format(sql, strip_comments=True)
sql = sqlparse.split(sql)[0]
- result['filtered_sql'] = sql.strip()
+ result["filtered_sql"] = sql.strip()
# sql_lower = sql.lower()
except IndexError:
- result['bad_query'] = True
- result['msg'] = '没有有效的SQL语句'
+ result["bad_query"] = True
+ result["msg"] = "没有有效的SQL语句"
return result
if whitelist_pattern.match(sql) is None:
- result['bad_query'] = True
- result['msg'] = '仅支持{}语法!'.format(','.join(sql_whitelist))
+ result["bad_query"] = True
+ result["msg"] = "仅支持{}语法!".format(",".join(sql_whitelist))
return result
- if result.get('bad_query'):
- result['msg'] = keyword_warning
+ if result.get("bad_query"):
+ result["msg"] = keyword_warning
return result
diff --git a/sql/engines/oracle.py b/sql/engines/oracle.py
index 7849d69f2e..4d6195c8b2 100644
--- a/sql/engines/oracle.py
+++ b/sql/engines/oracle.py
@@ -10,13 +10,17 @@
import pandas as pd
from common.config import SysConfig
from common.utils.timer import FuncTimer
-from sql.utils.sql_utils import get_syntax_type, get_full_sqlitem_list, get_exec_sqlitem_list
+from sql.utils.sql_utils import (
+ get_syntax_type,
+ get_full_sqlitem_list,
+ get_exec_sqlitem_list,
+)
from . import EngineBase
import cx_Oracle
from .models import ResultSet, ReviewSet, ReviewResult
from sql.utils.data_masking import simple_column_mask
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
class OracleEngine(EngineBase):
@@ -32,21 +36,27 @@ def get_connection(self, db_name=None):
return self.conn
if self.sid:
dsn = cx_Oracle.makedsn(self.host, self.port, self.sid)
- self.conn = cx_Oracle.connect(self.user, self.password, dsn=dsn, encoding="UTF-8", nencoding="UTF-8")
+ self.conn = cx_Oracle.connect(
+ self.user, self.password, dsn=dsn, encoding="UTF-8", nencoding="UTF-8"
+ )
elif self.service_name:
- dsn = cx_Oracle.makedsn(self.host, self.port, service_name=self.service_name)
- self.conn = cx_Oracle.connect(self.user, self.password, dsn=dsn, encoding="UTF-8", nencoding="UTF-8")
+ dsn = cx_Oracle.makedsn(
+ self.host, self.port, service_name=self.service_name
+ )
+ self.conn = cx_Oracle.connect(
+ self.user, self.password, dsn=dsn, encoding="UTF-8", nencoding="UTF-8"
+ )
else:
- raise ValueError('sid 和 dsn 均未填写, 请联系管理页补充该实例配置.')
+ raise ValueError("sid 和 dsn 均未填写, 请联系管理页补充该实例配置.")
return self.conn
@property
def name(self):
- return 'Oracle'
+ return "Oracle"
@property
def info(self):
- return 'Oracle engine'
+ return "Oracle engine"
@property
def auto_backup(self):
@@ -57,23 +67,24 @@ def auto_backup(self):
def get_backup_connection():
"""备份库连接"""
archer_config = SysConfig()
- backup_host = archer_config.get('inception_remote_backup_host')
- backup_port = int(archer_config.get('inception_remote_backup_port', 3306))
- backup_user = archer_config.get('inception_remote_backup_user')
- backup_password = archer_config.get('inception_remote_backup_password')
- return MySQLdb.connect(host=backup_host,
- port=backup_port,
- user=backup_user,
- passwd=backup_password,
- charset='utf8mb4',
- autocommit=True
- )
+ backup_host = archer_config.get("inception_remote_backup_host")
+ backup_port = int(archer_config.get("inception_remote_backup_port", 3306))
+ backup_user = archer_config.get("inception_remote_backup_user")
+ backup_password = archer_config.get("inception_remote_backup_password")
+ return MySQLdb.connect(
+ host=backup_host,
+ port=backup_port,
+ user=backup_user,
+ passwd=backup_password,
+ charset="utf8mb4",
+ autocommit=True,
+ )
@property
def server_version(self):
conn = self.get_connection()
version = conn.version
- return tuple([n for n in version.split('.')[:3]])
+ return tuple([n for n in version.split(".")[:3]])
def get_all_databases(self):
"""获取数据库列表, 返回resultSet 供上层调用, 底层实际上是获取oracle的schema列表"""
@@ -102,11 +113,47 @@ def _get_all_schemas(self):
"""
result = self.query(sql="SELECT username FROM all_users order by username")
sysschema = (
- 'AUD_SYS', 'ANONYMOUS', 'APEX_030200', 'APEX_PUBLIC_USER', 'APPQOSSYS', 'BI USERS', 'CTXSYS', 'DBSNMP',
- 'DIP USERS', 'EXFSYS', 'FLOWS_FILES', 'HR USERS', 'IX USERS', 'MDDATA', 'MDSYS', 'MGMT_VIEW', 'OE USERS',
- 'OLAPSYS', 'ORACLE_OCM', 'ORDDATA', 'ORDPLUGINS', 'ORDSYS', 'OUTLN', 'OWBSYS', 'OWBSYS_AUDIT', 'PM USERS',
- 'SCOTT', 'SH USERS', 'SI_INFORMTN_SCHEMA', 'SPATIAL_CSW_ADMIN_USR', 'SPATIAL_WFS_ADMIN_USR', 'SYS',
- 'SYSMAN', 'SYSTEM', 'WMSYS', 'XDB', 'XS$NULL', 'DIP', 'OJVMSYS', 'LBACSYS')
+ "AUD_SYS",
+ "ANONYMOUS",
+ "APEX_030200",
+ "APEX_PUBLIC_USER",
+ "APPQOSSYS",
+ "BI USERS",
+ "CTXSYS",
+ "DBSNMP",
+ "DIP USERS",
+ "EXFSYS",
+ "FLOWS_FILES",
+ "HR USERS",
+ "IX USERS",
+ "MDDATA",
+ "MDSYS",
+ "MGMT_VIEW",
+ "OE USERS",
+ "OLAPSYS",
+ "ORACLE_OCM",
+ "ORDDATA",
+ "ORDPLUGINS",
+ "ORDSYS",
+ "OUTLN",
+ "OWBSYS",
+ "OWBSYS_AUDIT",
+ "PM USERS",
+ "SCOTT",
+ "SH USERS",
+ "SI_INFORMTN_SCHEMA",
+ "SPATIAL_CSW_ADMIN_USR",
+ "SPATIAL_WFS_ADMIN_USR",
+ "SYS",
+ "SYSMAN",
+ "SYSTEM",
+ "WMSYS",
+ "XDB",
+ "XS$NULL",
+ "DIP",
+ "OJVMSYS",
+ "LBACSYS",
+ )
schema_list = [row[0] for row in result.rows if row[0] not in sysschema]
result.rows = schema_list
return result
@@ -116,7 +163,7 @@ def get_all_tables(self, db_name, **kwargs):
sql = f"""SELECT table_name FROM all_tables WHERE nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM', 'SYSAUX') AND OWNER = '{db_name}' AND IOT_NAME IS NULL AND DURATION IS NULL order by table_name
"""
result = self.query(db_name=db_name, sql=sql)
- tb_list = [row[0] for row in result.rows if row[0] not in ['test']]
+ tb_list = [row[0] for row in result.rows if row[0] not in ["test"]]
result.rows = tb_list
return result
@@ -161,7 +208,7 @@ def get_table_meta_data(self, db_name, tb_name, **kwargs):
tcs.OWNER='{db_name}'
AND tcs.TABLE_NAME='{tb_name}'"""
_meta_data = self.query(db_name=db_name, sql=meta_data_sql)
- return {'column_list': _meta_data.column_list, 'rows': _meta_data.rows[0]}
+ return {"column_list": _meta_data.column_list, "rows": _meta_data.rows[0]}
def get_table_desc_data(self, db_name, tb_name, **kwargs):
"""获取表格字段信息"""
@@ -206,7 +253,7 @@ def get_table_desc_data(self, db_name, tb_name, **kwargs):
AND bcs.TABLE_NAME='{tb_name}'
ORDER BY bcs.COLUMN_NAME"""
_desc_data = self.query(db_name=db_name, sql=desc_sql)
- return {'column_list': _desc_data.column_list, 'rows': _desc_data.rows}
+ return {"column_list": _desc_data.column_list, "rows": _desc_data.rows}
def get_table_index_data(self, db_name, tb_name, **kwargs):
"""获取表格索引信息"""
@@ -228,7 +275,7 @@ def get_table_index_data(self, db_name, tb_name, **kwargs):
ais.owner = '{db_name}'
AND ais.table_name = '{tb_name}'"""
_index_data = self.query(db_name, index_sql)
- return {'column_list': _index_data.column_list, 'rows': _index_data.rows}
+ return {"column_list": _index_data.column_list, "rows": _index_data.rows}
def get_tables_metas_data(self, db_name, **kwargs):
"""获取数据库所有表格信息,用作数据字典导出接口"""
@@ -282,22 +329,42 @@ def get_tables_metas_data(self, db_name, **kwargs):
cols_req = self.query(sql=sql_cols, close_conn=False).rows
# 给查询结果定义列名,query_engine.query的游标是0 1 2
- cols_df = pd.DataFrame(cols_req,
- columns=['TABLE_NAME', 'TABLE_COMMENTS', 'COLUMN_NAME', 'COLUMN_TYPE', 'COLUMN_DEFAULT',
- 'IS_NULLABLE', 'COLUMN_KEY', 'COLUMN_COMMENT'])
+ cols_df = pd.DataFrame(
+ cols_req,
+ columns=[
+ "TABLE_NAME",
+ "TABLE_COMMENTS",
+ "COLUMN_NAME",
+ "COLUMN_TYPE",
+ "COLUMN_DEFAULT",
+ "IS_NULLABLE",
+ "COLUMN_KEY",
+ "COLUMN_COMMENT",
+ ],
+ )
# 获得表名称去重
- col_list = cols_df.drop_duplicates('TABLE_NAME').to_dict('records')
+ col_list = cols_df.drop_duplicates("TABLE_NAME").to_dict("records")
for cl in col_list:
_meta = dict()
- engine_keys = [{"key": "COLUMN_NAME", "value": "字段名"}, {"key": "COLUMN_TYPE", "value": "数据类型"},
- {"key": "COLUMN_DEFAULT", "value": "默认值"}, {"key": "IS_NULLABLE", "value": "允许非空"},
- {"key": "COLUMN_KEY", "value": "是否主键"}, {"key": "COLUMN_COMMENT", "value": "备注"}]
+ engine_keys = [
+ {"key": "COLUMN_NAME", "value": "字段名"},
+ {"key": "COLUMN_TYPE", "value": "数据类型"},
+ {"key": "COLUMN_DEFAULT", "value": "默认值"},
+ {"key": "IS_NULLABLE", "value": "允许非空"},
+ {"key": "COLUMN_KEY", "value": "是否主键"},
+ {"key": "COLUMN_COMMENT", "value": "备注"},
+ ]
_meta["ENGINE_KEYS"] = engine_keys
- _meta['TABLE_INFO'] = {'TABLE_NAME': cl['TABLE_NAME'], 'TABLE_COMMENTS': cl['TABLE_COMMENTS']}
- table_name = cl['TABLE_NAME']
+ _meta["TABLE_INFO"] = {
+ "TABLE_NAME": cl["TABLE_NAME"],
+ "TABLE_COMMENTS": cl["TABLE_COMMENTS"],
+ }
+ table_name = cl["TABLE_NAME"]
# 查询DataFrame中满足表名的记录,并转为list
- _meta['COLUMNS'] = cols_df.query("TABLE_NAME == @table_name").to_dict('records')
+ _meta["COLUMNS"] = cols_df.query("TABLE_NAME == @table_name").to_dict(
+ "records"
+ )
table_metas.append(_meta)
return table_metas
@@ -306,7 +373,7 @@ def get_all_objects(self, db_name, **kwargs):
"""获取object_name 列表, 返回一个ResultSet"""
sql = f"""SELECT object_name FROM all_objects WHERE OWNER = '{db_name}' """
result = self.query(db_name=db_name, sql=sql)
- tb_list = [row[0] for row in result.rows if row[0] not in ['test']]
+ tb_list = [row[0] for row in result.rows if row[0] not in ["test"]]
result.rows = tb_list
return result
@@ -336,27 +403,27 @@ def describe_table(self, db_name, tb_name, **kwargs):
result = self.query(db_name=db_name, sql=sql)
return result
- def object_name_check(self, db_name=None, object_name=''):
+ def object_name_check(self, db_name=None, object_name=""):
"""获取table 列表, 返回一个ResultSet"""
- if '.' in object_name:
- schema_name = object_name.split('.')[0]
- object_name = object_name.split('.')[1]
+ if "." in object_name:
+ schema_name = object_name.split(".")[0]
+ object_name = object_name.split(".")[1]
if '"' in schema_name:
- schema_name = schema_name.replace('"', '')
+ schema_name = schema_name.replace('"', "")
if '"' in object_name:
- object_name = object_name.replace('"', '')
+ object_name = object_name.replace('"', "")
else:
object_name = object_name.upper()
else:
schema_name = schema_name.upper()
if '"' in object_name:
- object_name = object_name.replace('"', '')
+ object_name = object_name.replace('"', "")
else:
object_name = object_name.upper()
else:
schema_name = db_name
if '"' in object_name:
- object_name = object_name.replace('"', '')
+ object_name = object_name.replace('"', "")
else:
object_name = object_name.upper()
sql = f""" SELECT object_name FROM all_objects WHERE OWNER = '{schema_name}' and OBJECT_NAME = '{object_name}' """
@@ -367,47 +434,71 @@ def object_name_check(self, db_name=None, object_name=''):
return False
@staticmethod
- def get_sql_first_object_name(sql=''):
+ def get_sql_first_object_name(sql=""):
"""获取sql文本中的object_name"""
- object_name = ''
+ object_name = ""
if re.match(r"^create\s+table\s", sql, re.M | re.IGNORECASE):
- object_name = re.match(r"^create\s+table\s(.+?)(\s|\()", sql, re.M | re.IGNORECASE).group(1)
+ object_name = re.match(
+ r"^create\s+table\s(.+?)(\s|\()", sql, re.M | re.IGNORECASE
+ ).group(1)
elif re.match(r"^create\s+index\s", sql, re.M | re.IGNORECASE):
- object_name = re.match(r"^create\s+index\s(.+?)\s", sql, re.M | re.IGNORECASE).group(1)
+ object_name = re.match(
+ r"^create\s+index\s(.+?)\s", sql, re.M | re.IGNORECASE
+ ).group(1)
elif re.match(r"^create\s+unique\s+index\s", sql, re.M | re.IGNORECASE):
- object_name = re.match(r"^create\s+unique\s+index\s(.+?)\s", sql, re.M | re.IGNORECASE).group(1)
+ object_name = re.match(
+ r"^create\s+unique\s+index\s(.+?)\s", sql, re.M | re.IGNORECASE
+ ).group(1)
elif re.match(r"^create\s+sequence\s", sql, re.M | re.IGNORECASE):
- object_name = re.match(r"^create\s+sequence\s(.+?)(\s|$)", sql, re.M | re.IGNORECASE).group(1)
+ object_name = re.match(
+ r"^create\s+sequence\s(.+?)(\s|$)", sql, re.M | re.IGNORECASE
+ ).group(1)
elif re.match(r"^alter\s+table\s", sql, re.M | re.IGNORECASE):
- object_name = re.match(r"^alter\s+table\s(.+?)\s", sql, re.M | re.IGNORECASE).group(1)
+ object_name = re.match(
+ r"^alter\s+table\s(.+?)\s", sql, re.M | re.IGNORECASE
+ ).group(1)
elif re.match(r"^create\s+function\s", sql, re.M | re.IGNORECASE):
- object_name = re.match(r"^create\s+function\s(.+?)(\s|\()", sql, re.M | re.IGNORECASE).group(1)
+ object_name = re.match(
+ r"^create\s+function\s(.+?)(\s|\()", sql, re.M | re.IGNORECASE
+ ).group(1)
elif re.match(r"^create\s+view\s", sql, re.M | re.IGNORECASE):
- object_name = re.match(r"^create\s+view\s(.+?)\s", sql, re.M | re.IGNORECASE).group(1)
+ object_name = re.match(
+ r"^create\s+view\s(.+?)\s", sql, re.M | re.IGNORECASE
+ ).group(1)
elif re.match(r"^create\s+procedure\s", sql, re.M | re.IGNORECASE):
- object_name = re.match(r"^create\s+procedure\s(.+?)\s", sql, re.M | re.IGNORECASE).group(1)
+ object_name = re.match(
+ r"^create\s+procedure\s(.+?)\s", sql, re.M | re.IGNORECASE
+ ).group(1)
elif re.match(r"^create\s+package\s+body", sql, re.M | re.IGNORECASE):
- object_name = re.match(r"^create\s+package\s+body\s(.+?)\s", sql, re.M | re.IGNORECASE).group(1)
+ object_name = re.match(
+ r"^create\s+package\s+body\s(.+?)\s", sql, re.M | re.IGNORECASE
+ ).group(1)
elif re.match(r"^create\s+package\s", sql, re.M | re.IGNORECASE):
- object_name = re.match(r"^create\s+package\s(.+?)\s", sql, re.M | re.IGNORECASE).group(1)
+ object_name = re.match(
+ r"^create\s+package\s(.+?)\s", sql, re.M | re.IGNORECASE
+ ).group(1)
else:
return object_name.strip()
return object_name.strip()
@staticmethod
- def check_create_index_table(sql='', object_name_list=None, db_name=''):
+ def check_create_index_table(sql="", object_name_list=None, db_name=""):
object_name_list = object_name_list or set()
if re.match(r"^create\s+index\s", sql):
- table_name = re.match(r"^create\s+index\s+.+\s+on\s(.+?)(\(|\s\()", sql, re.M).group(1)
- if '.' not in table_name:
+ table_name = re.match(
+ r"^create\s+index\s+.+\s+on\s(.+?)(\(|\s\()", sql, re.M
+ ).group(1)
+ if "." not in table_name:
table_name = f"{db_name}.{table_name}"
if table_name in object_name_list:
return True
else:
return False
elif re.match(r"^create\s+unique\s+index\s", sql):
- table_name = re.match(r"^create\s+unique\s+index\s+.+\s+on\s(.+?)(\(|\s\()", sql, re.M).group(1)
- if '.' not in table_name:
+ table_name = re.match(
+ r"^create\s+unique\s+index\s+.+\s+on\s(.+?)(\(|\s\()", sql, re.M
+ ).group(1)
+ if "." not in table_name:
table_name = f"{db_name}.{table_name}"
if table_name in object_name_list:
return True
@@ -417,11 +508,11 @@ def check_create_index_table(sql='', object_name_list=None, db_name=''):
return False
@staticmethod
- def get_dml_table(sql='', object_name_list=None, db_name=''):
+ def get_dml_table(sql="", object_name_list=None, db_name=""):
object_name_list = object_name_list or set()
if re.match(r"^update", sql):
table_name = re.match(r"^update\s(.+?)\s", sql, re.M).group(1)
- if '.' not in table_name:
+ if "." not in table_name:
table_name = f"{db_name}.{table_name}"
if table_name in object_name_list:
return True
@@ -429,16 +520,19 @@ def get_dml_table(sql='', object_name_list=None, db_name=''):
return False
elif re.match(r"^delete", sql):
table_name = re.match(r"^delete\s+from\s+([\w-]+)\s*", sql, re.M).group(1)
- if '.' not in table_name:
+ if "." not in table_name:
table_name = f"{db_name}.{table_name}"
if table_name in object_name_list:
return True
else:
return False
elif re.match(r"^insert\s", sql):
- table_name = re.match(r"^insert\s+((into)|(all\s+into)|(all\s+when\s(.+?)into))\s+(.+?)(\(|\s)", sql,
- re.M).group(6)
- if '.' not in table_name:
+ table_name = re.match(
+ r"^insert\s+((into)|(all\s+into)|(all\s+when\s(.+?)into))\s+(.+?)(\(|\s)",
+ sql,
+ re.M,
+ ).group(6)
+ if "." not in table_name:
table_name = f"{db_name}.{table_name}"
if table_name in object_name_list:
return True
@@ -448,93 +542,103 @@ def get_dml_table(sql='', object_name_list=None, db_name=''):
return False
@staticmethod
- def where_check(sql=''):
+ def where_check(sql=""):
if re.match(r"^update((?!where).)*$|^delete((?!where).)*$", sql):
return True
else:
parsed = sqlparse.parse(sql)[0]
flattened = list(parsed.flatten())
n_skip = 0
- flattened = flattened[:len(flattened) - n_skip]
- logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN', 'ORDER BY', 'GROUP BY', 'HAVING')
+ flattened = flattened[: len(flattened) - n_skip]
+ logical_operators = (
+ "AND",
+ "OR",
+ "NOT",
+ "BETWEEN",
+ "ORDER BY",
+ "GROUP BY",
+ "HAVING",
+ )
for t in reversed(flattened):
if t.is_keyword:
return True
return False
- def explain_check(self, db_name=None, sql='', close_conn=False):
+ def explain_check(self, db_name=None, sql="", close_conn=False):
# 使用explain进行支持的SQL语法审核,连接需不中断,防止数据库不断fork进程的大批量消耗
- result = {'msg': '', 'rows': 0}
+ result = {"msg": "", "rows": 0}
try:
conn = self.get_connection()
cursor = conn.cursor()
if db_name:
- cursor.execute(f" ALTER SESSION SET CURRENT_SCHEMA = \"{db_name}\" ")
+ cursor.execute(f' ALTER SESSION SET CURRENT_SCHEMA = "{db_name}" ')
if re.match(r"^explain", sql, re.I):
sql = sql
else:
sql = f"explain plan for {sql}"
- sql = sql.rstrip(';')
+ sql = sql.rstrip(";")
cursor.execute(sql)
# 获取影响行数
cursor.execute(f"select CARDINALITY from SYS.PLAN_TABLE$ where id = 0")
rows = cursor.fetchone()
conn.rollback()
if not rows:
- result['rows'] = 0
+ result["rows"] = 0
else:
- result['rows'] = rows[0]
+ result["rows"] = rows[0]
except Exception as e:
logger.warning(f"Oracle 语句执行报错,语句:{sql},错误信息{traceback.format_exc()}")
- result['msg'] = str(e)
+ result["msg"] = str(e)
finally:
if close_conn:
self.close()
return result
- def query_check(self, db_name=None, sql=''):
+ def query_check(self, db_name=None, sql=""):
# 查询语句的检查、注释去除、切分
- result = {'msg': '', 'bad_query': False, 'filtered_sql': sql, 'has_star': False}
- keyword_warning = ''
+ result = {"msg": "", "bad_query": False, "filtered_sql": sql, "has_star": False}
+ keyword_warning = ""
star_patter = r"(^|,|\s)\*(\s|\(|$)"
# 删除注释语句,进行语法判断,执行第一条有效sql
try:
sql = sqlparse.format(sql, strip_comments=True)
sql = sqlparse.split(sql)[0]
- result['filtered_sql'] = re.sub(r';$', '', sql.strip())
+ result["filtered_sql"] = re.sub(r";$", "", sql.strip())
sql_lower = sql.lower()
except IndexError:
- result['bad_query'] = True
- result['msg'] = '没有有效的SQL语句'
+ result["bad_query"] = True
+ result["msg"] = "没有有效的SQL语句"
return result
if re.match(r"^select|^with|^explain", sql_lower) is None:
- result['bad_query'] = True
- result['msg'] = '不支持语法!'
+ result["bad_query"] = True
+ result["msg"] = "不支持语法!"
return result
if re.search(star_patter, sql_lower) is not None:
- keyword_warning += '禁止使用 * 关键词\n'
- result['has_star'] = True
- if result.get('bad_query') or result.get('has_star'):
- result['msg'] = keyword_warning
+ keyword_warning += "禁止使用 * 关键词\n"
+ result["has_star"] = True
+ if result.get("bad_query") or result.get("has_star"):
+ result["msg"] = keyword_warning
return result
- def filter_sql(self, sql='', limit_num=0):
+ def filter_sql(self, sql="", limit_num=0):
sql_lower = sql.lower()
# 对查询sql增加limit限制
if re.match(r"^select|^with", sql_lower) and not (
- re.match(r"^select\s+sql_audit.", sql_lower) and sql_lower.find(" sql_audit where rownum <= ") != -1):
+ re.match(r"^select\s+sql_audit.", sql_lower)
+ and sql_lower.find(" sql_audit where rownum <= ") != -1
+ ):
sql = f"select sql_audit.* from ({sql.rstrip(';')}) sql_audit where rownum <= {limit_num}"
return sql.strip()
- def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs):
- """返回 ResultSet """
+ def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
+ """返回 ResultSet"""
result_set = ResultSet(full_sql=sql)
try:
conn = self.get_connection()
cursor = conn.cursor()
if db_name:
- cursor.execute(f" ALTER SESSION SET CURRENT_SCHEMA = \"{db_name}\" ")
- sql = sql.rstrip(';')
+ cursor.execute(f' ALTER SESSION SET CURRENT_SCHEMA = "{db_name}" ')
+ sql = sql.rstrip(";")
# 支持oralce查询SQL执行计划语句
if re.match(r"^explain", sql, re.I):
cursor.execute(sql)
@@ -543,9 +647,12 @@ def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs):
cursor.execute(sql)
fields = cursor.description
if any(x[1] == cx_Oracle.CLOB for x in fields):
- rows = [tuple([(c.read() if type(c) == cx_Oracle.LOB else c) for c in r]) for r in cursor]
+ rows = [
+ tuple([(c.read() if type(c) == cx_Oracle.LOB else c) for c in r])
+ for r in cursor
+ ]
if int(limit_num) > 0:
- rows = rows[0:int(limit_num)]
+ rows = rows[0 : int(limit_num)]
else:
if int(limit_num) > 0:
rows = cursor.fetchmany(int(limit_num))
@@ -562,7 +669,7 @@ def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs):
self.close()
return result_set
- def query_masking(self, db_name=None, sql='', resultset=None):
+ def query_masking(self, db_name=None, sql="", resultset=None):
"""简单字段脱敏规则, 仅对select有效"""
if re.match(r"^select", sql, re.I):
filtered_result = simple_column_mask(self.instance, resultset)
@@ -571,7 +678,7 @@ def query_masking(self, db_name=None, sql='', resultset=None):
filtered_result = resultset
return filtered_result
- def execute_check(self, db_name=None, sql='', close_conn=True):
+ def execute_check(self, db_name=None, sql="", close_conn=True):
"""
上线单执行前的检查, 返回Review set
update by Jan.song 20200302
@@ -585,81 +692,118 @@ def execute_check(self, db_name=None, sql='', close_conn=True):
line = 1
# 保存SQL中的新建对象
object_name_list = set()
- critical_ddl_regex = config.get('critical_ddl_regex', '')
+ critical_ddl_regex = config.get("critical_ddl_regex", "")
p = re.compile(critical_ddl_regex)
check_result.syntax_type = 2 # TODO 工单类型 0、其他 1、DDL,2、DML
try:
sqlitemList = get_full_sqlitem_list(sql, db_name)
for sqlitem in sqlitemList:
- sql_lower = sqlitem.statement.lower().rstrip(';')
- sql_nolower = sqlitem.statement.rstrip(';')
+ sql_lower = sqlitem.statement.lower().rstrip(";")
+ sql_nolower = sqlitem.statement.rstrip(";")
# 禁用语句
if re.match(r"^select|^with|^explain", sql_lower):
- result = ReviewResult(id=line, errlevel=2,
- stagestatus='驳回不支持语句',
- errormessage='仅支持DML和DDL语句,查询语句请使用SQL查询功能!',
- sql=sqlitem.statement)
+ result = ReviewResult(
+ id=line,
+ errlevel=2,
+ stagestatus="驳回不支持语句",
+ errormessage="仅支持DML和DDL语句,查询语句请使用SQL查询功能!",
+ sql=sqlitem.statement,
+ )
# 高危语句
elif critical_ddl_regex and p.match(sql_lower.strip()):
- result = ReviewResult(id=line, errlevel=2,
- stagestatus='驳回高危SQL',
- errormessage='禁止提交匹配' + critical_ddl_regex + '条件的语句!',
- sql=sqlitem.statement)
+ result = ReviewResult(
+ id=line,
+ errlevel=2,
+ stagestatus="驳回高危SQL",
+ errormessage="禁止提交匹配" + critical_ddl_regex + "条件的语句!",
+ sql=sqlitem.statement,
+ )
# 驳回未带where数据修改语句,如确实需做全部删除或更新,显示的带上where 1=1
- elif re.match(r"^update((?!where).)*$|^delete((?!where).)*$", sql_lower):
- result = ReviewResult(id=line, errlevel=2,
- stagestatus='驳回未带where数据修改',
- errormessage='数据修改需带where条件!',
- sql=sqlitem.statement)
+ elif re.match(
+ r"^update((?!where).)*$|^delete((?!where).)*$", sql_lower
+ ):
+ result = ReviewResult(
+ id=line,
+ errlevel=2,
+ stagestatus="驳回未带where数据修改",
+ errormessage="数据修改需带where条件!",
+ sql=sqlitem.statement,
+ )
# 驳回事务控制,会话控制SQL
elif re.match(r"^set|^rollback|^exit", sql_lower):
- result = ReviewResult(id=line, errlevel=2,
- stagestatus='SQL中不能包含^set|^rollback|^exit',
- errormessage='SQL中不能包含^set|^rollback|^exit',
- sql=sqlitem.statement)
+ result = ReviewResult(
+ id=line,
+ errlevel=2,
+ stagestatus="SQL中不能包含^set|^rollback|^exit",
+ errormessage="SQL中不能包含^set|^rollback|^exit",
+ sql=sqlitem.statement,
+ )
# 通过explain对SQL做语法语义检查
- elif re.match(explain_re, sql_lower) and sqlitem.stmt_type == 'SQL':
- if self.check_create_index_table(db_name=db_name, sql=sql_lower, object_name_list=object_name_list):
+ elif re.match(explain_re, sql_lower) and sqlitem.stmt_type == "SQL":
+ if self.check_create_index_table(
+ db_name=db_name,
+ sql=sql_lower,
+ object_name_list=object_name_list,
+ ):
object_name = self.get_sql_first_object_name(sql=sql_lower)
- if '.' in object_name:
+ if "." in object_name:
object_name = object_name
else:
object_name = f"""{db_name}.{object_name}"""
object_name_list.add(object_name)
- result = ReviewResult(id=line, errlevel=1,
- stagestatus='WARNING:新建表的新建索引语句暂无法检测!',
- errormessage='WARNING:新建表的新建索引语句暂无法检测!',
- stmt_type=sqlitem.stmt_type,
- object_owner=sqlitem.object_owner,
- object_type=sqlitem.object_type,
- object_name=sqlitem.object_name,
- sql=sqlitem.statement)
- elif len(object_name_list) > 0 and self.get_dml_table(db_name=db_name, sql=sql_lower,
- object_name_list=object_name_list):
- result = ReviewResult(id=line, errlevel=1,
- stagestatus='WARNING:新建表的数据修改暂无法检测!',
- errormessage='WARNING:新建表的数据修改暂无法检测!',
- stmt_type=sqlitem.stmt_type,
- object_owner=sqlitem.object_owner,
- object_type=sqlitem.object_type,
- object_name=sqlitem.object_name,
- sql=sqlitem.statement)
+ result = ReviewResult(
+ id=line,
+ errlevel=1,
+ stagestatus="WARNING:新建表的新建索引语句暂无法检测!",
+ errormessage="WARNING:新建表的新建索引语句暂无法检测!",
+ stmt_type=sqlitem.stmt_type,
+ object_owner=sqlitem.object_owner,
+ object_type=sqlitem.object_type,
+ object_name=sqlitem.object_name,
+ sql=sqlitem.statement,
+ )
+ elif len(object_name_list) > 0 and self.get_dml_table(
+ db_name=db_name,
+ sql=sql_lower,
+ object_name_list=object_name_list,
+ ):
+ result = ReviewResult(
+ id=line,
+ errlevel=1,
+ stagestatus="WARNING:新建表的数据修改暂无法检测!",
+ errormessage="WARNING:新建表的数据修改暂无法检测!",
+ stmt_type=sqlitem.stmt_type,
+ object_owner=sqlitem.object_owner,
+ object_type=sqlitem.object_type,
+ object_name=sqlitem.object_name,
+ sql=sqlitem.statement,
+ )
else:
- result_set = self.explain_check(db_name=db_name, sql=sqlitem.statement, close_conn=False)
- if result_set['msg']:
- result = ReviewResult(id=line, errlevel=2,
- stagestatus='explain语法检查未通过!',
- errormessage=result_set['msg'],
- sql=sqlitem.statement)
+ result_set = self.explain_check(
+ db_name=db_name, sql=sqlitem.statement, close_conn=False
+ )
+ if result_set["msg"]:
+ result = ReviewResult(
+ id=line,
+ errlevel=2,
+ stagestatus="explain语法检查未通过!",
+ errormessage=result_set["msg"],
+ sql=sqlitem.statement,
+ )
else:
# 对create table\create index\create unique index语法做对象存在性检测
- if re.match(r"^create\s+table|^create\s+index|^create\s+unique\s+index", sql_lower):
- object_name = self.get_sql_first_object_name(sql=sql_nolower)
+ if re.match(
+ r"^create\s+table|^create\s+index|^create\s+unique\s+index",
+ sql_lower,
+ ):
+ object_name = self.get_sql_first_object_name(
+ sql=sql_nolower
+ )
# 保存create对象对后续SQL做存在性判断
- if '.' in object_name:
- schema_name = object_name.split('.')[0]
- object_name = object_name.split('.')[1]
+ if "." in object_name:
+ schema_name = object_name.split(".")[0]
+ object_name = object_name.split(".")[1]
if '"' in schema_name:
schema_name = schema_name
if '"' not in object_name:
@@ -669,72 +813,91 @@ def execute_check(self, db_name=None, sql='', close_conn=True):
if '"' not in object_name:
object_name = object_name.upper()
else:
- schema_name = ('"' + db_name + '"')
+ schema_name = '"' + db_name + '"'
if '"' not in object_name:
object_name = object_name.upper()
object_name = f"""{schema_name}.{object_name}"""
- if self.object_name_check(db_name=db_name,
- object_name=object_name) or object_name in object_name_list:
- result = ReviewResult(id=line, errlevel=2,
- stagestatus=f"""{object_name}对象已经存在!""",
- errormessage=f"""{object_name}对象已经存在!""",
- sql=sqlitem.statement)
+ if (
+ self.object_name_check(
+ db_name=db_name, object_name=object_name
+ )
+ or object_name in object_name_list
+ ):
+ result = ReviewResult(
+ id=line,
+ errlevel=2,
+ stagestatus=f"""{object_name}对象已经存在!""",
+ errormessage=f"""{object_name}对象已经存在!""",
+ sql=sqlitem.statement,
+ )
else:
object_name_list.add(object_name)
- if result_set['rows'] > 1000:
- result = ReviewResult(id=line, errlevel=1,
- stagestatus='影响行数大于1000,请关注',
- errormessage='影响行数大于1000,请关注',
- sql=sqlitem.statement,
- stmt_type=sqlitem.stmt_type,
- object_owner=sqlitem.object_owner,
- object_type=sqlitem.object_type,
- object_name=sqlitem.object_name,
- affected_rows=result_set['rows'],
- execute_time=0, )
+ if result_set["rows"] > 1000:
+ result = ReviewResult(
+ id=line,
+ errlevel=1,
+ stagestatus="影响行数大于1000,请关注",
+ errormessage="影响行数大于1000,请关注",
+ sql=sqlitem.statement,
+ stmt_type=sqlitem.stmt_type,
+ object_owner=sqlitem.object_owner,
+ object_type=sqlitem.object_type,
+ object_name=sqlitem.object_name,
+ affected_rows=result_set["rows"],
+ execute_time=0,
+ )
else:
- result = ReviewResult(id=line, errlevel=0,
- stagestatus='Audit completed',
- errormessage='None',
- sql=sqlitem.statement,
- stmt_type=sqlitem.stmt_type,
- object_owner=sqlitem.object_owner,
- object_type=sqlitem.object_type,
- object_name=sqlitem.object_name,
- affected_rows=result_set['rows'],
- execute_time=0, )
+ result = ReviewResult(
+ id=line,
+ errlevel=0,
+ stagestatus="Audit completed",
+ errormessage="None",
+ sql=sqlitem.statement,
+ stmt_type=sqlitem.stmt_type,
+ object_owner=sqlitem.object_owner,
+ object_type=sqlitem.object_type,
+ object_name=sqlitem.object_name,
+ affected_rows=result_set["rows"],
+ execute_time=0,
+ )
else:
- if result_set['rows'] > 1000:
- result = ReviewResult(id=line, errlevel=1,
- stagestatus='影响行数大于1000,请关注',
- errormessage='影响行数大于1000,请关注',
- sql=sqlitem.statement,
- stmt_type=sqlitem.stmt_type,
- object_owner=sqlitem.object_owner,
- object_type=sqlitem.object_type,
- object_name=sqlitem.object_name,
- affected_rows=result_set['rows'],
- execute_time=0, )
+ if result_set["rows"] > 1000:
+ result = ReviewResult(
+ id=line,
+ errlevel=1,
+ stagestatus="影响行数大于1000,请关注",
+ errormessage="影响行数大于1000,请关注",
+ sql=sqlitem.statement,
+ stmt_type=sqlitem.stmt_type,
+ object_owner=sqlitem.object_owner,
+ object_type=sqlitem.object_type,
+ object_name=sqlitem.object_name,
+ affected_rows=result_set["rows"],
+ execute_time=0,
+ )
else:
- result = ReviewResult(id=line, errlevel=0,
- stagestatus='Audit completed',
- errormessage='None',
- sql=sqlitem.statement,
- stmt_type=sqlitem.stmt_type,
- object_owner=sqlitem.object_owner,
- object_type=sqlitem.object_type,
- object_name=sqlitem.object_name,
- affected_rows=result_set['rows'],
- execute_time=0, )
+ result = ReviewResult(
+ id=line,
+ errlevel=0,
+ stagestatus="Audit completed",
+ errormessage="None",
+ sql=sqlitem.statement,
+ stmt_type=sqlitem.stmt_type,
+ object_owner=sqlitem.object_owner,
+ object_type=sqlitem.object_type,
+ object_name=sqlitem.object_name,
+ affected_rows=result_set["rows"],
+ execute_time=0,
+ )
# 其它无法用explain判断的语句
else:
# 对alter table做对象存在性检查
if re.match(r"^alter\s+table\s", sql_lower):
object_name = self.get_sql_first_object_name(sql=sql_nolower)
- if '.' in object_name:
- schema_name = object_name.split('.')[0]
- object_name = object_name.split('.')[1]
+ if "." in object_name:
+ schema_name = object_name.split(".")[0]
+ object_name = object_name.split(".")[1]
if '"' in schema_name:
schema_name = schema_name
if '"' not in object_name:
@@ -744,34 +907,44 @@ def execute_check(self, db_name=None, sql='', close_conn=True):
if '"' not in object_name:
object_name = object_name.upper()
else:
- schema_name = ('"' + db_name + '"')
+ schema_name = '"' + db_name + '"'
if '"' not in object_name:
object_name = object_name.upper()
object_name = f"""{schema_name}.{object_name}"""
- if not self.object_name_check(db_name=db_name,
- object_name=object_name) and object_name not in object_name_list:
- result = ReviewResult(id=line, errlevel=2,
- stagestatus=f"""{object_name}对象不存在!""",
- errormessage=f"""{object_name}对象不存在!""",
- sql=sqlitem.statement)
+ if (
+ not self.object_name_check(
+ db_name=db_name, object_name=object_name
+ )
+ and object_name not in object_name_list
+ ):
+ result = ReviewResult(
+ id=line,
+ errlevel=2,
+ stagestatus=f"""{object_name}对象不存在!""",
+ errormessage=f"""{object_name}对象不存在!""",
+ sql=sqlitem.statement,
+ )
else:
- result = ReviewResult(id=line, errlevel=1,
- stagestatus='当前平台,此语法不支持审核!',
- errormessage='当前平台,此语法不支持审核!',
- sql=sqlitem.statement,
- stmt_type=sqlitem.stmt_type,
- object_owner=sqlitem.object_owner,
- object_type=sqlitem.object_type,
- object_name=sqlitem.object_name,
- affected_rows=0,
- execute_time=0, )
+ result = ReviewResult(
+ id=line,
+ errlevel=1,
+ stagestatus="当前平台,此语法不支持审核!",
+ errormessage="当前平台,此语法不支持审核!",
+ sql=sqlitem.statement,
+ stmt_type=sqlitem.stmt_type,
+ object_owner=sqlitem.object_owner,
+ object_type=sqlitem.object_type,
+ object_name=sqlitem.object_name,
+ affected_rows=0,
+ execute_time=0,
+ )
# 对create做对象存在性检查
elif re.match(r"^create", sql_lower):
object_name = self.get_sql_first_object_name(sql=sql_nolower)
- if '.' in object_name:
- schema_name = object_name.split('.')[0]
- object_name = object_name.split('.')[1]
+ if "." in object_name:
+ schema_name = object_name.split(".")[0]
+ object_name = object_name.split(".")[1]
if '"' in schema_name:
schema_name = schema_name
if '"' not in object_name:
@@ -781,47 +954,62 @@ def execute_check(self, db_name=None, sql='', close_conn=True):
if '"' not in object_name:
object_name = object_name.upper()
else:
- schema_name = ('"' + db_name + '"')
+ schema_name = '"' + db_name + '"'
if '"' not in object_name:
object_name = object_name.upper()
object_name = f"""{schema_name}.{object_name}"""
- if self.object_name_check(db_name=db_name,
- object_name=object_name) or object_name in object_name_list:
- result = ReviewResult(id=line, errlevel=2,
- stagestatus=f"""{object_name}对象已经存在!""",
- errormessage=f"""{object_name}对象已经存在!""",
- sql=sqlitem.statement)
+ if (
+ self.object_name_check(
+ db_name=db_name, object_name=object_name
+ )
+ or object_name in object_name_list
+ ):
+ result = ReviewResult(
+ id=line,
+ errlevel=2,
+ stagestatus=f"""{object_name}对象已经存在!""",
+ errormessage=f"""{object_name}对象已经存在!""",
+ sql=sqlitem.statement,
+ )
else:
object_name_list.add(object_name)
- result = ReviewResult(id=line, errlevel=1,
- stagestatus='当前平台,此语法不支持审核!',
- errormessage='当前平台,此语法不支持审核!',
- sql=sqlitem.statement,
- stmt_type=sqlitem.stmt_type,
- object_owner=sqlitem.object_owner,
- object_type=sqlitem.object_type,
- object_name=sqlitem.object_name,
- affected_rows=0,
- execute_time=0, )
+ result = ReviewResult(
+ id=line,
+ errlevel=1,
+ stagestatus="当前平台,此语法不支持审核!",
+ errormessage="当前平台,此语法不支持审核!",
+ sql=sqlitem.statement,
+ stmt_type=sqlitem.stmt_type,
+ object_owner=sqlitem.object_owner,
+ object_type=sqlitem.object_type,
+ object_name=sqlitem.object_name,
+ affected_rows=0,
+ execute_time=0,
+ )
else:
- result = ReviewResult(id=line, errlevel=1,
- stagestatus='当前平台,此语法不支持审核!',
- errormessage='当前平台,此语法不支持审核!',
- sql=sqlitem.statement,
- stmt_type=sqlitem.stmt_type,
- object_owner=sqlitem.object_owner,
- object_type=sqlitem.object_type,
- object_name=sqlitem.object_name,
- affected_rows=0,
- execute_time=0, )
+ result = ReviewResult(
+ id=line,
+ errlevel=1,
+ stagestatus="当前平台,此语法不支持审核!",
+ errormessage="当前平台,此语法不支持审核!",
+ sql=sqlitem.statement,
+ stmt_type=sqlitem.stmt_type,
+ object_owner=sqlitem.object_owner,
+ object_type=sqlitem.object_type,
+ object_name=sqlitem.object_name,
+ affected_rows=0,
+ execute_time=0,
+ )
# 判断工单类型
- if get_syntax_type(sql=sqlitem.statement, db_type='oracle') == 'DDL':
+ if get_syntax_type(sql=sqlitem.statement, db_type="oracle") == "DDL":
check_result.syntax_type = 1
check_result.rows += [result]
line += 1
except Exception as e:
- logger.warning(f"Oracle 语句执行报错,第{line}个SQL:{sqlitem.statement},错误信息{traceback.format_exc()}")
+ logger.warning(
+ f"Oracle 语句执行报错,第{line}个SQL:{sqlitem.statement},错误信息{traceback.format_exc()}"
+ )
check_result.error = str(e)
finally:
if close_conn:
@@ -836,9 +1024,9 @@ def execute_check(self, db_name=None, sql='', close_conn=True):
def execute_workflow(self, workflow, close_conn=True):
"""执行上线单,返回Review set
- 原来的逻辑是根据 sql_content简单来分割SQL,进而再执行这些SQL
- 新的逻辑变更为根据审核结果中记录的sql来执行,
- 如果是PLSQL存储过程等对象定义操作,还需检查确认新建对象是否编译通过!
+ 原来的逻辑是根据 sql_content简单来分割SQL,进而再执行这些SQL
+ 新的逻辑变更为根据审核结果中记录的sql来执行,
+ 如果是PLSQL存储过程等对象定义操作,还需检查确认新建对象是否编译通过!
"""
review_content = workflow.sqlworkflowcontent.review_content
review_result = json.loads(review_content)
@@ -861,7 +1049,7 @@ def execute_workflow(self, workflow, close_conn=True):
for sqlitem in sqlitemList:
statement = sqlitem.statement
if sqlitem.stmt_type == "SQL":
- statement = statement.rstrip(';')
+ statement = statement.rstrip(";")
# 如果是DDL的工单,获取对象的原定义,并保存到sql_rollback.undo_sql
# 需要授权 grant execute on dbms_metadata to xxxxx
if workflow.syntax_type == 1:
@@ -874,13 +1062,18 @@ def execute_workflow(self, workflow, close_conn=True):
metdata_back_flag = self.metdata_backup(workflow, cursor, statement)
with FuncTimer() as t:
- if statement != '':
+ if statement != "":
cursor.execute(statement)
conn.commit()
rowcount = cursor.rowcount
stagestatus = "Execute Successfully"
- if sqlitem.stmt_type == "PLSQL" and sqlitem.object_name and sqlitem.object_name != 'ANONYMOUS' and sqlitem.object_name != '':
+ if (
+ sqlitem.stmt_type == "PLSQL"
+ and sqlitem.object_name
+ and sqlitem.object_name != "ANONYMOUS"
+ and sqlitem.object_name != ""
+ ):
query_obj_sql = f"""SELECT OBJECT_NAME, STATUS, TO_CHAR(LAST_DDL_TIME, 'YYYY-MM-DD HH24:MI:SS') FROM ALL_OBJECTS
WHERE OWNER = '{sqlitem.object_owner}'
AND OBJECT_NAME = '{sqlitem.object_name}'
@@ -890,49 +1083,69 @@ def execute_workflow(self, workflow, close_conn=True):
if row:
status = row[1]
if status and status == "INVALID":
- stagestatus = "Compile Failed. Object " + sqlitem.object_owner + "." + sqlitem.object_name + " is invalid."
+ stagestatus = (
+ "Compile Failed. Object "
+ + sqlitem.object_owner
+ + "."
+ + sqlitem.object_name
+ + " is invalid."
+ )
else:
- stagestatus = "Compile Failed. Object " + sqlitem.object_owner + "." + sqlitem.object_name + " doesn't exist."
+ stagestatus = (
+ "Compile Failed. Object "
+ + sqlitem.object_owner
+ + "."
+ + sqlitem.object_name
+ + " doesn't exist."
+ )
if stagestatus != "Execute Successfully":
raise Exception(stagestatus)
- execute_result.rows.append(ReviewResult(
- id=line,
- errlevel=0,
- stagestatus=stagestatus,
- errormessage='None',
- sql=statement,
- affected_rows=cursor.rowcount,
- execute_time=t.cost,
- ))
+ execute_result.rows.append(
+ ReviewResult(
+ id=line,
+ errlevel=0,
+ stagestatus=stagestatus,
+ errormessage="None",
+ sql=statement,
+ affected_rows=cursor.rowcount,
+ execute_time=t.cost,
+ )
+ )
line += 1
except Exception as e:
- logger.warning(f"Oracle命令执行报错,工单id:{workflow.id},语句:{statement or sql}, 错误信息:{traceback.format_exc()}")
+ logger.warning(
+ f"Oracle命令执行报错,工单id:{workflow.id},语句:{statement or sql}, 错误信息:{traceback.format_exc()}"
+ )
execute_result.error = str(e)
# conn.rollback()
# 追加当前报错语句信息到执行结果中
- execute_result.rows.append(ReviewResult(
- id=line,
- errlevel=2,
- stagestatus='Execute Failed',
- errormessage=f'异常信息:{e}',
- sql=statement or sql,
- affected_rows=0,
- execute_time=0,
- ))
- line += 1
- # 报错语句后面的语句标记为审核通过、未执行,追加到执行结果中
- for sqlitem in sqlitemList[line - 1:]:
- execute_result.rows.append(ReviewResult(
+ execute_result.rows.append(
+ ReviewResult(
id=line,
- errlevel=0,
- stagestatus='Audit completed',
- errormessage=f'前序语句失败, 未执行',
- sql=sqlitem.statement,
+ errlevel=2,
+ stagestatus="Execute Failed",
+ errormessage=f"异常信息:{e}",
+ sql=statement or sql,
affected_rows=0,
execute_time=0,
- ))
+ )
+ )
+ line += 1
+ # 报错语句后面的语句标记为审核通过、未执行,追加到执行结果中
+ for sqlitem in sqlitemList[line - 1 :]:
+ execute_result.rows.append(
+ ReviewResult(
+ id=line,
+ errlevel=0,
+ stagestatus="Audit completed",
+ errormessage=f"前序语句失败, 未执行",
+ sql=sqlitem.statement,
+ affected_rows=0,
+ execute_time=0,
+ )
+ )
line += 1
finally:
# 备份
@@ -941,9 +1154,16 @@ def execute_workflow(self, workflow, close_conn=True):
cursor.execute(f"select sysdate from dual")
rows = cursor.fetchone()
end_time = rows[0]
- self.backup(workflow, cursor=cursor, begin_time=begin_time, end_time=end_time)
+ self.backup(
+ workflow,
+ cursor=cursor,
+ begin_time=begin_time,
+ end_time=end_time,
+ )
except Exception as e:
- logger.error(f"Oracle工单备份异常,工单id:{workflow.id}, 错误信息:{traceback.format_exc()}")
+ logger.error(
+ f"Oracle工单备份异常,工单id:{workflow.id}, 错误信息:{traceback.format_exc()}"
+ )
if close_conn:
self.close()
return execute_result
@@ -967,32 +1187,34 @@ def backup(self, workflow, cursor, begin_time, end_time):
backup_cursor = conn.cursor()
backup_cursor.execute(f"""create database if not exists ora_backup;""")
backup_cursor.execute(f"use ora_backup;")
- backup_cursor.execute(f"""CREATE TABLE if not exists `sql_rollback` (
+ backup_cursor.execute(
+ f"""CREATE TABLE if not exists `sql_rollback` (
`id` bigint(20) NOT NULL AUTO_INCREMENT,
`redo_sql` mediumtext,
`undo_sql` mediumtext,
`workflow_id` bigint(20) NOT NULL,
PRIMARY KEY (`id`),
key `idx_sql_rollback_01` (`workflow_id`)
- ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;""")
+ ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;"""
+ )
# 使用logminer抓取回滚SQL
- logmnr_start_sql = f'''begin
+ logmnr_start_sql = f"""begin
dbms_logmnr.start_logmnr(
starttime=>to_date('{begin_time}','yyyy-mm-dd hh24:mi:ss'),
endtime=>to_date('{end_time}','yyyy/mm/dd hh24:mi:ss'),
options=>dbms_logmnr.dict_from_online_catalog + dbms_logmnr.continuous_mine);
- end;'''
- undo_sql = f'''select
+ end;"""
+ undo_sql = f"""select
xmlagg(xmlparse(content sql_redo wellformed) order by scn,rs_id,ssn,rownum).getclobval() ,
xmlagg(xmlparse(content sql_undo wellformed) order by scn,rs_id,ssn,rownum).getclobval()
from v$logmnr_contents
where SEG_OWNER not in ('SYS')
and session# = (select sid from v$mystat where rownum = 1)
and serial# = (select serial# from v$session s where s.sid = (select sid from v$mystat where rownum = 1 ))
- group by scn,rs_id,ssn order by scn desc'''
- logmnr_end_sql = f'''begin
+ group by scn,rs_id,ssn order by scn desc"""
+ logmnr_end_sql = f"""begin
dbms_logmnr.end_logmnr;
- end;'''
+ end;"""
cursor.execute(logmnr_start_sql)
cursor.execute(undo_sql)
rows = cursor.fetchall()
@@ -1002,7 +1224,7 @@ def backup(self, workflow, cursor, begin_time, end_time):
redo_sql = f"{row[0]}"
redo_sql = redo_sql.replace("'", "\\'")
if row[1] is None:
- undo_sql = f' '
+ undo_sql = f" "
else:
undo_sql = f"{row[1]}"
undo_sql = undo_sql.replace("'", "\\'")
@@ -1032,19 +1254,21 @@ def metdata_backup(self, workflow, cursor, redo_sql):
backup_cursor = conn.cursor()
backup_cursor.execute(f"""create database if not exists ora_backup;""")
backup_cursor.execute(f"use ora_backup;")
- backup_cursor.execute(f"""CREATE TABLE if not exists `sql_rollback` (
+ backup_cursor.execute(
+ f"""CREATE TABLE if not exists `sql_rollback` (
`id` bigint(20) NOT NULL AUTO_INCREMENT,
`redo_sql` mediumtext,
`undo_sql` mediumtext,
`workflow_id` bigint(20) NOT NULL,
PRIMARY KEY (`id`),
key `idx_sql_rollback_01` (`workflow_id`)
- ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;""")
+ ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;"""
+ )
rows = cursor.fetchall()
if len(rows) > 0:
for row in rows:
if row[0] is None:
- undo_sql = f' '
+ undo_sql = f" "
else:
undo_sql = f"{row[0]}"
undo_sql = undo_sql.replace("'", "\\'")
@@ -1090,7 +1314,7 @@ def get_rollback(self, workflow):
conn.close()
return list_backup_sql
- def sqltuningadvisor(self, db_name=None, sql='', close_conn=True, **kwargs):
+ def sqltuningadvisor(self, db_name=None, sql="", close_conn=True, **kwargs):
"""
add by Jan.song 20200421
使用DBMS_SQLTUNE包做sql tuning支持
@@ -1098,14 +1322,14 @@ def sqltuningadvisor(self, db_name=None, sql='', close_conn=True, **kwargs):
返回 ResultSet
"""
result_set = ResultSet(full_sql=sql)
- task_name = 'sqlaudit' + f'''{threading.currentThread().ident}'''
+ task_name = "sqlaudit" + f"""{threading.currentThread().ident}"""
task_begin = 0
try:
conn = self.get_connection()
cursor = conn.cursor()
- sql = sql.rstrip(';')
+ sql = sql.rstrip(";")
# 创建分析任务
- create_task_sql = f'''DECLARE
+ create_task_sql = f"""DECLARE
my_task_name VARCHAR2(30);
my_sqltext CLOB;
BEGIN
@@ -1118,15 +1342,20 @@ def sqltuningadvisor(self, db_name=None, sql='', close_conn=True, **kwargs):
task_name => '{task_name}',
description => 'tuning');
DBMS_SQLTUNE.EXECUTE_TUNING_TASK( task_name => '{task_name}');
- END;'''
+ END;"""
task_begin = 1
cursor.execute(create_task_sql)
# 获取分析报告
- get_task_sql = f'''select DBMS_SQLTUNE.REPORT_TUNING_TASK( '{task_name}') from dual'''
+ get_task_sql = (
+ f"""select DBMS_SQLTUNE.REPORT_TUNING_TASK( '{task_name}') from dual"""
+ )
cursor.execute(get_task_sql)
fields = cursor.description
if any(x[1] == cx_Oracle.CLOB for x in fields):
- rows = [tuple([(c.read() if type(c) == cx_Oracle.LOB else c) for c in r]) for r in cursor]
+ rows = [
+ tuple([(c.read() if type(c) == cx_Oracle.LOB else c) for c in r])
+ for r in cursor
+ ]
else:
rows = cursor.fetchall()
result_set.column_list = [i[0] for i in fields] if fields else []
@@ -1138,10 +1367,10 @@ def sqltuningadvisor(self, db_name=None, sql='', close_conn=True, **kwargs):
finally:
# 结束分析任务
if task_begin == 1:
- end_sql = f'''DECLARE
+ end_sql = f"""DECLARE
begin
dbms_sqltune.drop_tuning_task('{task_name}');
- end;'''
+ end;"""
cursor.execute(end_sql)
if close_conn:
self.close()
diff --git a/sql/engines/pgsql.py b/sql/engines/pgsql.py
index 9b4de94bf7..0146469743 100644
--- a/sql/engines/pgsql.py
+++ b/sql/engines/pgsql.py
@@ -18,29 +18,35 @@
from .models import ResultSet, ReviewSet, ReviewResult
from sql.utils.data_masking import simple_column_mask
-__author__ = 'hhyo、yyukai'
+__author__ = "hhyo、yyukai"
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
class PgSQLEngine(EngineBase):
test_query = "SELECT 1"
def get_connection(self, db_name=None):
- db_name = db_name or self.db_name or 'postgres'
+ db_name = db_name or self.db_name or "postgres"
if self.conn:
return self.conn
- self.conn = psycopg2.connect(host=self.host, port=self.port, user=self.user,
- password=self.password, dbname=db_name, connect_timeout=10)
+ self.conn = psycopg2.connect(
+ host=self.host,
+ port=self.port,
+ user=self.user,
+ password=self.password,
+ dbname=db_name,
+ connect_timeout=10,
+ )
return self.conn
@property
def name(self):
- return 'PgSQL'
+ return "PgSQL"
@property
def info(self):
- return 'PgSQL engine'
+ return "PgSQL engine"
def get_all_databases(self):
"""
@@ -48,7 +54,11 @@ def get_all_databases(self):
:return:
"""
result = self.query(sql=f"SELECT datname FROM pg_database;")
- db_list = [row[0] for row in result.rows if row[0] not in ['postgres', 'template0', 'template1']]
+ db_list = [
+ row[0]
+ for row in result.rows
+ if row[0] not in ["postgres", "template0", "template1"]
+ ]
result.rows = db_list
return result
@@ -57,10 +67,21 @@ def get_all_schemas(self, db_name, **kwargs):
获取模式列表
:return:
"""
- result = self.query(db_name=db_name, sql=f"select schema_name from information_schema.schemata;")
- schema_list = [row[0] for row in result.rows if row[0] not in ['information_schema',
- 'pg_catalog', 'pg_toast_temp_1',
- 'pg_temp_1', 'pg_toast']]
+ result = self.query(
+ db_name=db_name, sql=f"select schema_name from information_schema.schemata;"
+ )
+ schema_list = [
+ row[0]
+ for row in result.rows
+ if row[0]
+ not in [
+ "information_schema",
+ "pg_catalog",
+ "pg_toast_temp_1",
+ "pg_temp_1",
+ "pg_toast",
+ ]
+ ]
result.rows = schema_list
return result
@@ -71,12 +92,12 @@ def get_all_tables(self, db_name, **kwargs):
:param schema_name:
:return:
"""
- schema_name = kwargs.get('schema_name')
+ schema_name = kwargs.get("schema_name")
sql = f"""SELECT table_name
FROM information_schema.tables
where table_schema ='{schema_name}';"""
result = self.query(db_name=db_name, sql=sql)
- tb_list = [row[0] for row in result.rows if row[0] not in ['test']]
+ tb_list = [row[0] for row in result.rows if row[0] not in ["test"]]
result.rows = tb_list
return result
@@ -88,7 +109,7 @@ def get_all_columns_by_tb(self, db_name, tb_name, **kwargs):
:param schema_name:
:return:
"""
- schema_name = kwargs.get('schema_name')
+ schema_name = kwargs.get("schema_name")
sql = f"""SELECT column_name
FROM information_schema.columns
where table_name='{tb_name}'
@@ -106,7 +127,7 @@ def describe_table(self, db_name, tb_name, **kwargs):
:param schema_name:
:return:
"""
- schema_name = kwargs.get('schema_name')
+ schema_name = kwargs.get("schema_name")
sql = f"""select
col.column_name,
col.data_type,
@@ -126,32 +147,32 @@ def describe_table(self, db_name, tb_name, **kwargs):
result = self.query(db_name=db_name, schema_name=schema_name, sql=sql)
return result
- def query_check(self, db_name=None, sql=''):
+ def query_check(self, db_name=None, sql=""):
# 查询语句的检查、注释去除、切分
- result = {'msg': '', 'bad_query': False, 'filtered_sql': sql, 'has_star': False}
+ result = {"msg": "", "bad_query": False, "filtered_sql": sql, "has_star": False}
# 删除注释语句,进行语法判断,执行第一条有效sql
try:
sql = sqlparse.format(sql, strip_comments=True)
sql = sqlparse.split(sql)[0]
- result['filtered_sql'] = sql.strip()
+ result["filtered_sql"] = sql.strip()
except IndexError:
- result['bad_query'] = True
- result['msg'] = '没有有效的SQL语句'
+ result["bad_query"] = True
+ result["msg"] = "没有有效的SQL语句"
if re.match(r"^select", sql, re.I) is None:
- result['bad_query'] = True
- result['msg'] = '不支持的查询语法类型!'
- if '*' in sql:
- result['has_star'] = True
- result['msg'] = 'SQL语句中含有 * '
+ result["bad_query"] = True
+ result["msg"] = "不支持的查询语法类型!"
+ if "*" in sql:
+ result["has_star"] = True
+ result["msg"] = "SQL语句中含有 * "
return result
- def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs):
- """返回 ResultSet """
- schema_name = kwargs.get('schema_name')
+ def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
+ """返回 ResultSet"""
+ schema_name = kwargs.get("schema_name")
result_set = ResultSet(full_sql=sql)
try:
conn = self.get_connection(db_name=db_name)
- max_execution_time = kwargs.get('max_execution_time', 0)
+ max_execution_time = kwargs.get("max_execution_time", 0)
cursor = conn.cursor()
try:
cursor.execute(f"SET statement_timeout TO {max_execution_time};")
@@ -178,16 +199,16 @@ def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs):
self.close()
return result_set
- def filter_sql(self, sql='', limit_num=0):
+ def filter_sql(self, sql="", limit_num=0):
# 对查询sql增加limit限制,# TODO limit改写待优化
- sql_lower = sql.lower().rstrip(';').strip()
+ sql_lower = sql.lower().rstrip(";").strip()
if re.match(r"^select", sql_lower):
if re.search(r"limit\s+(\d+)$", sql_lower) is None:
if re.search(r"limit\s+\d+\s*,\s*(\d+)$", sql_lower) is None:
return f"{sql.rstrip(';')} limit {limit_num};"
return f"{sql.rstrip(';')};"
- def query_masking(self, db_name=None, sql='', resultset=None):
+ def query_masking(self, db_name=None, sql="", resultset=None):
"""简单字段脱敏规则, 仅对select有效"""
if re.match(r"^select", sql, re.I):
filtered_result = simple_column_mask(self.instance, resultset)
@@ -196,40 +217,49 @@ def query_masking(self, db_name=None, sql='', resultset=None):
filtered_result = resultset
return filtered_result
- def execute_check(self, db_name=None, sql=''):
+ def execute_check(self, db_name=None, sql=""):
"""上线单执行前的检查, 返回Review set"""
config = SysConfig()
check_result = ReviewSet(full_sql=sql)
# 禁用/高危语句检查
line = 1
- critical_ddl_regex = config.get('critical_ddl_regex', '')
+ critical_ddl_regex = config.get("critical_ddl_regex", "")
p = re.compile(critical_ddl_regex)
check_result.syntax_type = 2 # TODO 工单类型 0、其他 1、DDL,2、DML
for statement in sqlparse.split(sql):
statement = sqlparse.format(statement, strip_comments=True)
# 禁用语句
if re.match(r"^select", statement.lower()):
- result = ReviewResult(id=line, errlevel=2,
- stagestatus='驳回不支持语句',
- errormessage='仅支持DML和DDL语句,查询语句请使用SQL查询功能!',
- sql=statement)
+ result = ReviewResult(
+ id=line,
+ errlevel=2,
+ stagestatus="驳回不支持语句",
+ errormessage="仅支持DML和DDL语句,查询语句请使用SQL查询功能!",
+ sql=statement,
+ )
# 高危语句
elif critical_ddl_regex and p.match(statement.strip().lower()):
- result = ReviewResult(id=line, errlevel=2,
- stagestatus='驳回高危SQL',
- errormessage='禁止提交匹配' + critical_ddl_regex + '条件的语句!',
- sql=statement)
+ result = ReviewResult(
+ id=line,
+ errlevel=2,
+ stagestatus="驳回高危SQL",
+ errormessage="禁止提交匹配" + critical_ddl_regex + "条件的语句!",
+ sql=statement,
+ )
# 正常语句
else:
- result = ReviewResult(id=line, errlevel=0,
- stagestatus='Audit completed',
- errormessage='None',
- sql=statement,
- affected_rows=0,
- execute_time=0, )
+ result = ReviewResult(
+ id=line,
+ errlevel=0,
+ stagestatus="Audit completed",
+ errormessage="None",
+ sql=statement,
+ affected_rows=0,
+ execute_time=0,
+ )
# 判断工单类型
- if get_syntax_type(statement) == 'DDL':
+ if get_syntax_type(statement) == "DDL":
check_result.syntax_type = 1
check_result.rows += [result]
line += 1
@@ -256,45 +286,53 @@ def execute_workflow(self, workflow, close_conn=True):
cursor = conn.cursor()
# 逐条执行切分语句,追加到执行结果中
for statement in split_sql:
- statement = statement.rstrip(';')
+ statement = statement.rstrip(";")
with FuncTimer() as t:
cursor.execute(statement)
conn.commit()
- execute_result.rows.append(ReviewResult(
- id=line,
- errlevel=0,
- stagestatus='Execute Successfully',
- errormessage='None',
- sql=statement,
- affected_rows=cursor.rowcount,
- execute_time=t.cost,
- ))
+ execute_result.rows.append(
+ ReviewResult(
+ id=line,
+ errlevel=0,
+ stagestatus="Execute Successfully",
+ errormessage="None",
+ sql=statement,
+ affected_rows=cursor.rowcount,
+ execute_time=t.cost,
+ )
+ )
line += 1
except Exception as e:
- logger.warning(f"PGSQL命令执行报错,语句:{statement or sql}, 错误信息:{traceback.format_exc()}")
+ logger.warning(
+ f"PGSQL命令执行报错,语句:{statement or sql}, 错误信息:{traceback.format_exc()}"
+ )
execute_result.error = str(e)
# 追加当前报错语句信息到执行结果中
- execute_result.rows.append(ReviewResult(
- id=line,
- errlevel=2,
- stagestatus='Execute Failed',
- errormessage=f'异常信息:{e}',
- sql=statement or sql,
- affected_rows=0,
- execute_time=0,
- ))
- line += 1
- # 报错语句后面的语句标记为审核通过、未执行,追加到执行结果中
- for statement in split_sql[line - 1:]:
- execute_result.rows.append(ReviewResult(
+ execute_result.rows.append(
+ ReviewResult(
id=line,
- errlevel=0,
- stagestatus='Audit completed',
- errormessage=f'前序语句失败, 未执行',
- sql=statement,
+ errlevel=2,
+ stagestatus="Execute Failed",
+ errormessage=f"异常信息:{e}",
+ sql=statement or sql,
affected_rows=0,
execute_time=0,
- ))
+ )
+ )
+ line += 1
+ # 报错语句后面的语句标记为审核通过、未执行,追加到执行结果中
+ for statement in split_sql[line - 1 :]:
+ execute_result.rows.append(
+ ReviewResult(
+ id=line,
+ errlevel=0,
+ stagestatus="Audit completed",
+ errormessage=f"前序语句失败, 未执行",
+ sql=statement,
+ affected_rows=0,
+ execute_time=0,
+ )
+ )
line += 1
finally:
if close_conn:
diff --git a/sql/engines/phoenix.py b/sql/engines/phoenix.py
index 8da6e7f025..c68c911d19 100644
--- a/sql/engines/phoenix.py
+++ b/sql/engines/phoenix.py
@@ -8,7 +8,7 @@
from . import EngineBase
from .models import ResultSet, ReviewSet, ReviewResult
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
class PhoenixEngine(EngineBase):
@@ -18,7 +18,7 @@ def get_connection(self, db_name=None):
if self.conn:
return self.conn
- database_url = f'http://{self.host}:{self.port}/'
+ database_url = f"http://{self.host}:{self.port}/"
self.conn = phoenixdb.connect(database_url, autocommit=True)
return self.conn
@@ -50,43 +50,45 @@ def describe_table(self, db_name, tb_name, **kwargs):
result = self.query(sql=sql)
return result
- def query_check(self, db_name=None, sql=''):
+ def query_check(self, db_name=None, sql=""):
# 查询语句的检查、注释去除、切分
- result = {'msg': '', 'bad_query': False, 'filtered_sql': sql, 'has_star': False}
- keyword_warning = ''
- sql_whitelist = ['select', 'explain']
+ result = {"msg": "", "bad_query": False, "filtered_sql": sql, "has_star": False}
+ keyword_warning = ""
+ sql_whitelist = ["select", "explain"]
# 根据白名单list拼接pattern语句
whitelist_pattern = "^" + "|^".join(sql_whitelist)
# 删除注释语句,进行语法判断,执行第一条有效sql
try:
sql = sql.format(sql, strip_comments=True)
sql = sqlparse.split(sql)[0]
- result['filtered_sql'] = sql.strip()
+ result["filtered_sql"] = sql.strip()
# sql_lower = sql.lower()
except IndexError:
- result['bad_query'] = True
- result['msg'] = '没有有效的SQL语句'
+ result["bad_query"] = True
+ result["msg"] = "没有有效的SQL语句"
return result
if re.match(whitelist_pattern, sql) is None:
- result['bad_query'] = True
- result['msg'] = '仅支持{}语法!'.format(','.join(sql_whitelist))
+ result["bad_query"] = True
+ result["msg"] = "仅支持{}语法!".format(",".join(sql_whitelist))
return result
- if result.get('bad_query'):
- result['msg'] = keyword_warning
+ if result.get("bad_query"):
+ result["msg"] = keyword_warning
return result
- def filter_sql(self, sql='', limit_num=0):
+ def filter_sql(self, sql="", limit_num=0):
"""检查是SELECT语句否添加了limit限制关键词"""
- sql = sql.rstrip(';').strip()
+ sql = sql.rstrip(";").strip()
if re.match(r"^select", sql, re.I):
- if not re.compile(r'limit\s+(\d+)\s*((,|offset)\s*\d+)?\s*$', re.I).search(sql):
- sql = f'{sql} limit {limit_num}'
+ if not re.compile(r"limit\s+(\d+)\s*((,|offset)\s*\d+)?\s*$", re.I).search(
+ sql
+ ):
+ sql = f"{sql} limit {limit_num}"
else:
- sql = f'{sql};'
+ sql = f"{sql};"
return sql.strip()
- def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs):
- """返回 ResultSet """
+ def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
+ """返回 ResultSet"""
result_set = ResultSet(full_sql=sql)
try:
conn = self.get_connection()
@@ -109,34 +111,38 @@ def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs):
self.close()
return result_set
- def query_masking(self, db_name=None, sql='', resultset=None):
+ def query_masking(self, db_name=None, sql="", resultset=None):
"""传入 sql语句, db名, 结果集, 返回一个脱敏后的结果集"""
return resultset
- def execute_check(self, db_name=None, sql=''):
+ def execute_check(self, db_name=None, sql=""):
"""上线单执行前的检查, 返回Review set"""
check_result = ReviewSet(full_sql=sql)
# 切分语句,追加到检测结果中,默认全部检测通过
rowid = 1
split_sql = sqlparse.split(sql)
for statement in split_sql:
- check_result.rows.append(ReviewResult(
- id=rowid,
- errlevel=0,
- stagestatus='Audit completed',
- errormessage='None',
- sql=statement,
- affected_rows=0,
- execute_time=0, )
+ check_result.rows.append(
+ ReviewResult(
+ id=rowid,
+ errlevel=0,
+ stagestatus="Audit completed",
+ errormessage="None",
+ sql=statement,
+ affected_rows=0,
+ execute_time=0,
+ )
)
rowid += 1
return check_result
def execute_workflow(self, workflow):
"""PhoenixDB无需备份"""
- return self.execute(db_name=workflow.db_name, sql=workflow.sqlworkflowcontent.sql_content)
+ return self.execute(
+ db_name=workflow.db_name, sql=workflow.sqlworkflowcontent.sql_content
+ )
- def execute(self, db_name=None, sql='', close_conn=True):
+ def execute(self, db_name=None, sql="", close_conn=True):
"""原生执行语句"""
execute_result = ReviewSet(full_sql=sql)
conn = self.get_connection(db_name=db_name)
@@ -149,39 +155,45 @@ def execute(self, db_name=None, sql='', close_conn=True):
except Exception as e:
logger.error(f"Phoenix命令执行报错,语句:{sql}, 错误信息:{traceback.format_exc()}")
execute_result.error = str(e)
- execute_result.rows.append(ReviewResult(
- id=rowid,
- errlevel=2,
- stagestatus='Execute Failed',
- errormessage=f'异常信息:{e}',
- sql=statement,
- affected_rows=0,
- execute_time=0,
- ))
+ execute_result.rows.append(
+ ReviewResult(
+ id=rowid,
+ errlevel=2,
+ stagestatus="Execute Failed",
+ errormessage=f"异常信息:{e}",
+ sql=statement,
+ affected_rows=0,
+ execute_time=0,
+ )
+ )
break
else:
- execute_result.rows.append(ReviewResult(
- id=rowid,
- errlevel=0,
- stagestatus='Execute Successfully',
- errormessage='None',
- sql=statement,
- affected_rows=cursor.rowcount,
- execute_time=0,
- ))
+ execute_result.rows.append(
+ ReviewResult(
+ id=rowid,
+ errlevel=0,
+ stagestatus="Execute Successfully",
+ errormessage="None",
+ sql=statement,
+ affected_rows=cursor.rowcount,
+ execute_time=0,
+ )
+ )
rowid += 1
if execute_result.error:
# 如果失败, 将剩下的部分加入结果集返回
for statement in split_sql[rowid:]:
- execute_result.rows.append(ReviewResult(
- id=rowid,
- errlevel=2,
- stagestatus='Execute Failed',
- errormessage=f'前序语句失败, 未执行',
- sql=statement,
- affected_rows=0,
- execute_time=0,
- ))
+ execute_result.rows.append(
+ ReviewResult(
+ id=rowid,
+ errlevel=2,
+ stagestatus="Execute Failed",
+ errormessage=f"前序语句失败, 未执行",
+ sql=statement,
+ affected_rows=0,
+ execute_time=0,
+ )
+ )
rowid += 1
if close_conn:
diff --git a/sql/engines/redis.py b/sql/engines/redis.py
index af84d76779..eda4779755 100644
--- a/sql/engines/redis.py
+++ b/sql/engines/redis.py
@@ -17,29 +17,41 @@
from . import EngineBase
from .models import ResultSet, ReviewSet, ReviewResult
-__author__ = 'hhyo'
+__author__ = "hhyo"
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
class RedisEngine(EngineBase):
def get_connection(self, db_name=None):
db_name = db_name or self.db_name
- if self.mode == 'cluster':
- return redis.cluster.RedisCluster(host=self.host, port=self.port, password=self.password,
- encoding_errors='ignore', decode_responses=True,
- socket_connect_timeout=10)
+ if self.mode == "cluster":
+ return redis.cluster.RedisCluster(
+ host=self.host,
+ port=self.port,
+ password=self.password,
+ encoding_errors="ignore",
+ decode_responses=True,
+ socket_connect_timeout=10,
+ )
else:
- return redis.Redis(host=self.host, port=self.port, db=db_name, password=self.password,
- encoding_errors='ignore', decode_responses=True, socket_connect_timeout=10)
+ return redis.Redis(
+ host=self.host,
+ port=self.port,
+ db=db_name,
+ password=self.password,
+ encoding_errors="ignore",
+ decode_responses=True,
+ socket_connect_timeout=10,
+ )
@property
def name(self):
- return 'Redis'
+ return "Redis"
@property
def info(self):
- return 'Redis engine'
+ return "Redis engine"
def test_connection(self):
return self.get_all_databases()
@@ -49,44 +61,74 @@ def get_all_databases(self, **kwargs):
获取数据库列表
:return:
"""
- result = ResultSet(full_sql='CONFIG GET databases')
+ result = ResultSet(full_sql="CONFIG GET databases")
conn = self.get_connection()
try:
- rows = conn.config_get('databases')['databases']
+ rows = conn.config_get("databases")["databases"]
except Exception as e:
logger.warning(f"Redis CONFIG GET databases 执行报错,异常信息:{e}")
- dbs = [int(i.split('db')[1]) for i in conn.info('Keyspace').keys() if len(i.split('db')) == 2]
+ dbs = [
+ int(i.split("db")[1])
+ for i in conn.info("Keyspace").keys()
+ if len(i.split("db")) == 2
+ ]
rows = max(dbs, [16])
db_list = [str(x) for x in range(int(rows))]
result.rows = db_list
return result
- def query_check(self, db_name=None, sql='', limit_num=0):
+ def query_check(self, db_name=None, sql="", limit_num=0):
"""提交查询前的检查"""
- result = {'msg': '', 'bad_query': True, 'filtered_sql': sql, 'has_star': False}
- safe_cmd = ["scan", "exists", "ttl", "pttl", "type", "get", "mget", "strlen",
- "hgetall", "hexists", "hget", "hmget", "hkeys", "hvals",
- "smembers", "scard", "sdiff", "sunion", "sismember", "llen", "lrange", "lindex",
- "zrange", "zrangebyscore", "zscore", "zcard", "zcount", "zrank"]
+ result = {"msg": "", "bad_query": True, "filtered_sql": sql, "has_star": False}
+ safe_cmd = [
+ "scan",
+ "exists",
+ "ttl",
+ "pttl",
+ "type",
+ "get",
+ "mget",
+ "strlen",
+ "hgetall",
+ "hexists",
+ "hget",
+ "hmget",
+ "hkeys",
+ "hvals",
+ "smembers",
+ "scard",
+ "sdiff",
+ "sunion",
+ "sismember",
+ "llen",
+ "lrange",
+ "lindex",
+ "zrange",
+ "zrangebyscore",
+ "zscore",
+ "zcard",
+ "zcount",
+ "zrank",
+ ]
# 命令校验,仅可以执行safe_cmd内的命令
for cmd in safe_cmd:
- if re.match(fr'^{cmd}', sql.strip(), re.I):
- result['bad_query'] = False
+ if re.match(rf"^{cmd}", sql.strip(), re.I):
+ result["bad_query"] = False
break
- if result['bad_query']:
- result['msg'] = "禁止执行该命令!"
+ if result["bad_query"]:
+ result["msg"] = "禁止执行该命令!"
return result
- def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs):
- """返回 ResultSet """
+ def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
+ """返回 ResultSet"""
result_set = ResultSet(full_sql=sql)
try:
conn = self.get_connection(db_name=db_name)
rows = conn.execute_command(*shlex.split(sql))
- result_set.column_list = ['Result']
+ result_set.column_list = ["Result"]
if isinstance(rows, list) or isinstance(rows, tuple):
- if re.match(fr'^scan', sql.strip(), re.I):
+ if re.match(rf"^scan", sql.strip(), re.I):
keys = [[row] for row in rows[1]]
keys.insert(0, [rows[0]])
result_set.rows = tuple(keys)
@@ -108,26 +150,28 @@ def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs):
result_set.error = str(e)
return result_set
- def filter_sql(self, sql='', limit_num=0):
+ def filter_sql(self, sql="", limit_num=0):
return sql.strip()
- def query_masking(self, db_name=None, sql='', resultset=None):
+ def query_masking(self, db_name=None, sql="", resultset=None):
"""不做脱敏"""
return resultset
- def execute_check(self, db_name=None, sql=''):
+ def execute_check(self, db_name=None, sql=""):
"""上线单执行前的检查, 返回Review set"""
check_result = ReviewSet(full_sql=sql)
- split_sql = [cmd.strip() for cmd in sql.split('\n') if cmd.strip()]
+ split_sql = [cmd.strip() for cmd in sql.split("\n") if cmd.strip()]
line = 1
for cmd in split_sql:
- result = ReviewResult(id=line,
- errlevel=0,
- stagestatus='Audit completed',
- errormessage='None',
- sql=cmd,
- affected_rows=0,
- execute_time=0, )
+ result = ReviewResult(
+ id=line,
+ errlevel=0,
+ stagestatus="Audit completed",
+ errormessage="None",
+ sql=cmd,
+ affected_rows=0,
+ execute_time=0,
+ )
check_result.rows += [result]
line += 1
return check_result
@@ -135,7 +179,7 @@ def execute_check(self, db_name=None, sql=''):
def execute_workflow(self, workflow):
"""执行上线单,返回Review set"""
sql = workflow.sqlworkflowcontent.sql_content
- split_sql = [cmd.strip() for cmd in sql.split('\n') if cmd.strip()]
+ split_sql = [cmd.strip() for cmd in sql.split("\n") if cmd.strip()]
execute_result = ReviewSet(full_sql=sql)
line = 1
cmd = None
@@ -144,40 +188,48 @@ def execute_workflow(self, workflow):
for cmd in split_sql:
with FuncTimer() as t:
conn.execute_command(*shlex.split(cmd))
- execute_result.rows.append(ReviewResult(
- id=line,
- errlevel=0,
- stagestatus='Execute Successfully',
- errormessage='None',
- sql=cmd,
- affected_rows=0,
- execute_time=t.cost,
- ))
+ execute_result.rows.append(
+ ReviewResult(
+ id=line,
+ errlevel=0,
+ stagestatus="Execute Successfully",
+ errormessage="None",
+ sql=cmd,
+ affected_rows=0,
+ execute_time=t.cost,
+ )
+ )
line += 1
except Exception as e:
- logger.warning(f"Redis命令执行报错,语句:{cmd or sql}, 错误信息:{traceback.format_exc()}")
+ logger.warning(
+ f"Redis命令执行报错,语句:{cmd or sql}, 错误信息:{traceback.format_exc()}"
+ )
# 追加当前报错语句信息到执行结果中
execute_result.error = str(e)
- execute_result.rows.append(ReviewResult(
- id=line,
- errlevel=2,
- stagestatus='Execute Failed',
- errormessage=f'异常信息:{e}',
- sql=cmd,
- affected_rows=0,
- execute_time=0,
- ))
- line += 1
- # 报错语句后面的语句标记为审核通过、未执行,追加到执行结果中
- for statement in split_sql[line - 1:]:
- execute_result.rows.append(ReviewResult(
+ execute_result.rows.append(
+ ReviewResult(
id=line,
- errlevel=0,
- stagestatus='Audit completed',
- errormessage=f'前序语句失败, 未执行',
- sql=statement,
+ errlevel=2,
+ stagestatus="Execute Failed",
+ errormessage=f"异常信息:{e}",
+ sql=cmd,
affected_rows=0,
execute_time=0,
- ))
+ )
+ )
+ line += 1
+ # 报错语句后面的语句标记为审核通过、未执行,追加到执行结果中
+ for statement in split_sql[line - 1 :]:
+ execute_result.rows.append(
+ ReviewResult(
+ id=line,
+ errlevel=0,
+ stagestatus="Audit completed",
+ errormessage=f"前序语句失败, 未执行",
+ sql=statement,
+ affected_rows=0,
+ execute_time=0,
+ )
+ )
line += 1
return execute_result
diff --git a/sql/engines/tests.py b/sql/engines/tests.py
index 14da252e3a..e284ff2d62 100644
--- a/sql/engines/tests.py
+++ b/sql/engines/tests.py
@@ -27,39 +27,44 @@
class TestReviewSet(TestCase):
def test_review_set(self):
new_review_set = ReviewSet()
- new_review_set.rows = [{'id': '1679123'}]
- self.assertIn('1679123', new_review_set.json())
+ new_review_set.rows = [{"id": "1679123"}]
+ self.assertIn("1679123", new_review_set.json())
class TestEngineBase(TestCase):
@classmethod
def setUpClass(cls):
- cls.u1 = User(username='some_user', display='用户1')
+ cls.u1 = User(username="some_user", display="用户1")
cls.u1.save()
- cls.ins1 = Instance(instance_name='some_ins', type='master', db_type='mssql', host='some_host',
- port=1366, user='ins_user', password='some_str')
+ cls.ins1 = Instance(
+ instance_name="some_ins",
+ type="master",
+ db_type="mssql",
+ host="some_host",
+ port=1366,
+ user="ins_user",
+ password="some_str",
+ )
cls.ins1.save()
cls.wf1 = SqlWorkflow.objects.create(
- workflow_name='some_name',
+ workflow_name="some_name",
group_id=1,
- group_name='g1',
+ group_name="g1",
engineer=cls.u1.username,
engineer_display=cls.u1.display,
- audit_auth_groups='some_group',
+ audit_auth_groups="some_group",
create_time=datetime.now() - timedelta(days=1),
- status='workflow_finish',
+ status="workflow_finish",
is_backup=True,
instance=cls.ins1,
- db_name='some_db',
- syntax_type=1
+ db_name="some_db",
+ syntax_type=1,
)
cls.wfc1 = SqlWorkflowContent.objects.create(
workflow=cls.wf1,
- sql_content='some_sql',
- execute_result=json.dumps([{
- 'id': 1,
- 'sql': 'some_content'
- }]))
+ sql_content="some_sql",
+ execute_result=json.dumps([{"id": 1, "sql": "some_content"}]),
+ )
@classmethod
def tearDownClass(cls):
@@ -75,27 +80,35 @@ def test_init_with_ins(self):
class TestMssql(TestCase):
-
@classmethod
def setUpClass(cls):
- cls.ins1 = Instance(instance_name='some_ins', type='slave', db_type='mssql', host='some_host',
- port=1366, user='ins_user', password='some_str')
+ cls.ins1 = Instance(
+ instance_name="some_ins",
+ type="slave",
+ db_type="mssql",
+ host="some_host",
+ port=1366,
+ user="ins_user",
+ password="some_str",
+ )
cls.ins1.save()
cls.engine = MssqlEngine(instance=cls.ins1)
cls.wf = SqlWorkflow.objects.create(
- workflow_name='some_name',
+ workflow_name="some_name",
group_id=1,
- group_name='g1',
- engineer_display='',
- audit_auth_groups='some_group',
+ group_name="g1",
+ engineer_display="",
+ audit_auth_groups="some_group",
create_time=datetime.now() - timedelta(days=1),
- status='workflow_finish',
+ status="workflow_finish",
is_backup=True,
instance=cls.ins1,
- db_name='some_db',
- syntax_type=1
+ db_name="some_db",
+ syntax_type=1,
+ )
+ SqlWorkflowContent.objects.create(
+ workflow=cls.wf, sql_content="insert into some_tb values (1)"
)
- SqlWorkflowContent.objects.create(workflow=cls.wf, sql_content='insert into some_tb values (1)')
@classmethod
def tearDownClass(cls):
@@ -103,91 +116,96 @@ def tearDownClass(cls):
cls.wf.delete()
SqlWorkflowContent.objects.all().delete()
- @patch('sql.engines.mssql.pyodbc.connect')
+ @patch("sql.engines.mssql.pyodbc.connect")
def testGetConnection(self, connect):
new_engine = MssqlEngine(instance=self.ins1)
new_engine.get_connection()
connect.assert_called_once()
- @patch('sql.engines.mssql.pyodbc.connect')
+ @patch("sql.engines.mssql.pyodbc.connect")
def testQuery(self, connect):
cur = Mock()
connect.return_value.cursor = cur
cur.return_value.execute = Mock()
- cur.return_value.fetchmany.return_value = (('v1', 'v2'),)
- cur.return_value.description = (('k1', 'some_other_des'), ('k2', 'some_other_des'))
+ cur.return_value.fetchmany.return_value = (("v1", "v2"),)
+ cur.return_value.description = (
+ ("k1", "some_other_des"),
+ ("k2", "some_other_des"),
+ )
new_engine = MssqlEngine(instance=self.ins1)
- query_result = new_engine.query(sql='some_str', limit_num=100)
+ query_result = new_engine.query(sql="some_str", limit_num=100)
cur.return_value.execute.assert_called()
cur.return_value.fetchmany.assert_called_once_with(100)
connect.return_value.close.assert_called_once()
self.assertIsInstance(query_result, ResultSet)
- @patch.object(MssqlEngine, 'query')
+ @patch.object(MssqlEngine, "query")
def testAllDb(self, mock_query):
db_result = ResultSet()
- db_result.rows = [('db_1',), ('db_2',)]
+ db_result.rows = [("db_1",), ("db_2",)]
mock_query.return_value = db_result
new_engine = MssqlEngine(instance=self.ins1)
dbs = new_engine.get_all_databases()
- self.assertEqual(dbs.rows, ['db_1', 'db_2'])
+ self.assertEqual(dbs.rows, ["db_1", "db_2"])
- @patch.object(MssqlEngine, 'query')
+ @patch.object(MssqlEngine, "query")
def testAllTables(self, mock_query):
table_result = ResultSet()
- table_result.rows = [('tb_1', 'some_des'), ('tb_2', 'some_des')]
+ table_result.rows = [("tb_1", "some_des"), ("tb_2", "some_des")]
mock_query.return_value = table_result
new_engine = MssqlEngine(instance=self.ins1)
- tables = new_engine.get_all_tables('some_db')
- mock_query.assert_called_once_with(db_name='some_db', sql=ANY)
- self.assertEqual(tables.rows, ['tb_1', 'tb_2'])
+ tables = new_engine.get_all_tables("some_db")
+ mock_query.assert_called_once_with(db_name="some_db", sql=ANY)
+ self.assertEqual(tables.rows, ["tb_1", "tb_2"])
- @patch.object(MssqlEngine, 'query')
+ @patch.object(MssqlEngine, "query")
def testAllColumns(self, mock_query):
db_result = ResultSet()
- db_result.rows = [('col_1', 'type'), ('col_2', 'type2')]
+ db_result.rows = [("col_1", "type"), ("col_2", "type2")]
mock_query.return_value = db_result
new_engine = MssqlEngine(instance=self.ins1)
- dbs = new_engine.get_all_columns_by_tb('some_db', 'some_tb')
- self.assertEqual(dbs.rows, ['col_1', 'col_2'])
+ dbs = new_engine.get_all_columns_by_tb("some_db", "some_tb")
+ self.assertEqual(dbs.rows, ["col_1", "col_2"])
- @patch.object(MssqlEngine, 'query')
+ @patch.object(MssqlEngine, "query")
def testDescribe(self, mock_query):
new_engine = MssqlEngine(instance=self.ins1)
- new_engine.describe_table('some_db', 'some_db')
+ new_engine.describe_table("some_db", "some_db")
mock_query.assert_called_once()
def testQueryCheck(self):
new_engine = MssqlEngine(instance=self.ins1)
# 只抽查一个函数
- banned_sql = 'select concat(phone,1) from user_table'
- check_result = new_engine.query_check(db_name='some_db', sql=banned_sql)
- self.assertTrue(check_result.get('bad_query'))
- banned_sql = 'select phone from user_table where phone=concat(phone,1)'
- check_result = new_engine.query_check(db_name='some_db', sql=banned_sql)
- self.assertTrue(check_result.get('bad_query'))
+ banned_sql = "select concat(phone,1) from user_table"
+ check_result = new_engine.query_check(db_name="some_db", sql=banned_sql)
+ self.assertTrue(check_result.get("bad_query"))
+ banned_sql = "select phone from user_table where phone=concat(phone,1)"
+ check_result = new_engine.query_check(db_name="some_db", sql=banned_sql)
+ self.assertTrue(check_result.get("bad_query"))
sp_sql = "sp_helptext '[SomeName].[SomeAction]'"
- check_result = new_engine.query_check(db_name='some_db', sql=sp_sql)
- self.assertFalse(check_result.get('bad_query'))
- self.assertEqual(check_result.get('filtered_sql'), sp_sql)
+ check_result = new_engine.query_check(db_name="some_db", sql=sp_sql)
+ self.assertFalse(check_result.get("bad_query"))
+ self.assertEqual(check_result.get("filtered_sql"), sp_sql)
def test_filter_sql(self):
new_engine = MssqlEngine(instance=self.ins1)
# 只抽查一个函数
- banned_sql = 'select user from user_table'
+ banned_sql = "select user from user_table"
check_result = new_engine.filter_sql(sql=banned_sql, limit_num=10)
self.assertEqual(check_result, "select top 10 user from user_table")
def test_execute_check(self):
new_engine = MssqlEngine(instance=self.ins1)
- test_sql = 'use database\ngo\nsome sql1\nGO\nsome sql2\n\r\nGo\nsome sql3\n\r\ngO\n'
+ test_sql = (
+ "use database\ngo\nsome sql1\nGO\nsome sql2\n\r\nGo\nsome sql3\n\r\ngO\n"
+ )
check_result = new_engine.execute_check(db_name=None, sql=test_sql)
self.assertIsInstance(check_result, ReviewSet)
- self.assertEqual(check_result.rows[1].__dict__['sql'], "use database\n")
- self.assertEqual(check_result.rows[2].__dict__['sql'], "\nsome sql1\n")
- self.assertEqual(check_result.rows[4].__dict__['sql'], "\nsome sql3\n\r\n")
+ self.assertEqual(check_result.rows[1].__dict__["sql"], "use database\n")
+ self.assertEqual(check_result.rows[2].__dict__["sql"], "\nsome sql1\n")
+ self.assertEqual(check_result.rows[4].__dict__["sql"], "\nsome sql3\n\r\n")
- @patch('sql.engines.mssql.MssqlEngine.execute')
+ @patch("sql.engines.mssql.MssqlEngine.execute")
def test_execute_workflow(self, mock_execute):
mock_execute.return_value.error = None
new_engine = MssqlEngine(instance=self.ins1)
@@ -196,48 +214,56 @@ def test_execute_workflow(self, mock_execute):
mock_execute.assert_called()
self.assertEqual(1, mock_execute.call_count)
- @patch('sql.engines.mssql.MssqlEngine.get_connection')
+ @patch("sql.engines.mssql.MssqlEngine.get_connection")
def test_execute(self, mock_connect):
mock_cursor = Mock()
mock_connect.return_value.cursor = mock_cursor
new_engine = MssqlEngine(instance=self.ins1)
- execute_result = new_engine.execute('some_db', 'some_sql')
+ execute_result = new_engine.execute("some_db", "some_sql")
# 验证结果, 无异常
self.assertIsNone(execute_result.error)
- self.assertEqual('some_sql', execute_result.full_sql)
+ self.assertEqual("some_sql", execute_result.full_sql)
self.assertEqual(2, len(execute_result.rows))
mock_cursor.return_value.execute.assert_called()
mock_cursor.return_value.commit.assert_called()
mock_cursor.reset_mock()
# 验证异常
- mock_cursor.return_value.execute.side_effect = Exception('Boom! some exception!')
- execute_result = new_engine.execute('some_db', 'some_sql')
- self.assertIn('Boom! some exception!', execute_result.error)
- self.assertEqual('some_sql', execute_result.full_sql)
+ mock_cursor.return_value.execute.side_effect = Exception(
+ "Boom! some exception!"
+ )
+ execute_result = new_engine.execute("some_db", "some_sql")
+ self.assertIn("Boom! some exception!", execute_result.error)
+ self.assertEqual("some_sql", execute_result.full_sql)
self.assertEqual(2, len(execute_result.rows))
mock_cursor.return_value.commit.assert_not_called()
mock_cursor.return_value.rollback.assert_called()
class TestMysql(TestCase):
-
def setUp(self):
- self.ins1 = Instance(instance_name='some_ins', type='slave', db_type='mysql', host='some_host',
- port=1366, user='ins_user', password='some_str')
+ self.ins1 = Instance(
+ instance_name="some_ins",
+ type="slave",
+ db_type="mysql",
+ host="some_host",
+ port=1366,
+ user="ins_user",
+ password="some_str",
+ )
self.ins1.save()
self.sys_config = SysConfig()
self.wf = SqlWorkflow.objects.create(
- workflow_name='some_name',
+ workflow_name="some_name",
group_id=1,
- group_name='g1',
- engineer_display='',
- audit_auth_groups='some_group',
+ group_name="g1",
+ engineer_display="",
+ audit_auth_groups="some_group",
create_time=datetime.now() - timedelta(days=1),
- status='workflow_finish',
+ status="workflow_finish",
is_backup=True,
instance=self.ins1,
- db_name='some_db',
- syntax_type=1
+ db_name="some_db",
+ syntax_type=1,
)
SqlWorkflowContent.objects.create(workflow=self.wf)
@@ -247,359 +273,408 @@ def tearDown(self):
SqlWorkflow.objects.all().delete()
SqlWorkflowContent.objects.all().delete()
- @patch('MySQLdb.connect')
+ @patch("MySQLdb.connect")
def test_engine_base_info(self, _conn):
new_engine = MysqlEngine(instance=self.ins1)
- self.assertEqual(new_engine.name, 'MySQL')
- self.assertEqual(new_engine.info, 'MySQL engine')
+ self.assertEqual(new_engine.name, "MySQL")
+ self.assertEqual(new_engine.info, "MySQL engine")
- @patch('MySQLdb.connect')
+ @patch("MySQLdb.connect")
def testGetConnection(self, connect):
new_engine = MysqlEngine(instance=self.ins1)
new_engine.get_connection()
connect.assert_called_once()
- @patch('MySQLdb.connect')
+ @patch("MySQLdb.connect")
def testQuery(self, connect):
cur = Mock()
connect.return_value.cursor = cur
cur.return_value.execute = Mock()
- cur.return_value.fetchmany.return_value = (('v1', 'v2'),)
- cur.return_value.description = (('k1', 'some_other_des'), ('k2', 'some_other_des'))
+ cur.return_value.fetchmany.return_value = (("v1", "v2"),)
+ cur.return_value.description = (
+ ("k1", "some_other_des"),
+ ("k2", "some_other_des"),
+ )
new_engine = MysqlEngine(instance=self.ins1)
- query_result = new_engine.query(sql='some_str', limit_num=100)
+ query_result = new_engine.query(sql="some_str", limit_num=100)
cur.return_value.execute.assert_called()
cur.return_value.fetchmany.assert_called_once_with(size=100)
connect.return_value.close.assert_called_once()
self.assertIsInstance(query_result, ResultSet)
- @patch.object(MysqlEngine, 'query')
+ @patch.object(MysqlEngine, "query")
def testAllDb(self, mock_query):
db_result = ResultSet()
- db_result.rows = [('db_1',), ('db_2',)]
+ db_result.rows = [("db_1",), ("db_2",)]
mock_query.return_value = db_result
new_engine = MysqlEngine(instance=self.ins1)
dbs = new_engine.get_all_databases()
- self.assertEqual(dbs.rows, ['db_1', 'db_2'])
+ self.assertEqual(dbs.rows, ["db_1", "db_2"])
- @patch.object(MysqlEngine, 'query')
+ @patch.object(MysqlEngine, "query")
def testAllTables(self, mock_query):
table_result = ResultSet()
- table_result.rows = [('tb_1', 'some_des'), ('tb_2', 'some_des')]
+ table_result.rows = [("tb_1", "some_des"), ("tb_2", "some_des")]
mock_query.return_value = table_result
new_engine = MysqlEngine(instance=self.ins1)
- tables = new_engine.get_all_tables('some_db')
- mock_query.assert_called_once_with(db_name='some_db', sql=ANY)
- self.assertEqual(tables.rows, ['tb_1', 'tb_2'])
+ tables = new_engine.get_all_tables("some_db")
+ mock_query.assert_called_once_with(db_name="some_db", sql=ANY)
+ self.assertEqual(tables.rows, ["tb_1", "tb_2"])
- @patch.object(MysqlEngine, 'query')
+ @patch.object(MysqlEngine, "query")
def testAllColumns(self, mock_query):
db_result = ResultSet()
- db_result.rows = [('col_1', 'type'), ('col_2', 'type2')]
+ db_result.rows = [("col_1", "type"), ("col_2", "type2")]
mock_query.return_value = db_result
new_engine = MysqlEngine(instance=self.ins1)
- dbs = new_engine.get_all_columns_by_tb('some_db', 'some_tb')
- self.assertEqual(dbs.rows, ['col_1', 'col_2'])
+ dbs = new_engine.get_all_columns_by_tb("some_db", "some_tb")
+ self.assertEqual(dbs.rows, ["col_1", "col_2"])
- @patch.object(MysqlEngine, 'query')
+ @patch.object(MysqlEngine, "query")
def testDescribe(self, mock_query):
new_engine = MysqlEngine(instance=self.ins1)
- new_engine.describe_table('some_db', 'some_db')
+ new_engine.describe_table("some_db", "some_db")
mock_query.assert_called_once()
def testQueryCheck(self):
new_engine = MysqlEngine(instance=self.ins1)
- sql_without_limit = '-- 测试\n select user from usertable'
- check_result = new_engine.query_check(db_name='some_db', sql=sql_without_limit)
- self.assertEqual(check_result['filtered_sql'], 'select user from usertable')
+ sql_without_limit = "-- 测试\n select user from usertable"
+ check_result = new_engine.query_check(db_name="some_db", sql=sql_without_limit)
+ self.assertEqual(check_result["filtered_sql"], "select user from usertable")
def test_query_check_wrong_sql(self):
new_engine = MysqlEngine(instance=self.ins1)
- wrong_sql = '-- 测试'
- check_result = new_engine.query_check(db_name='some_db', sql=wrong_sql)
- self.assertDictEqual(check_result,
- {'msg': '不支持的查询语法类型!', 'bad_query': True, 'filtered_sql': '-- 测试', 'has_star': False})
+ wrong_sql = "-- 测试"
+ check_result = new_engine.query_check(db_name="some_db", sql=wrong_sql)
+ self.assertDictEqual(
+ check_result,
+ {
+ "msg": "不支持的查询语法类型!",
+ "bad_query": True,
+ "filtered_sql": "-- 测试",
+ "has_star": False,
+ },
+ )
def test_query_check_update_sql(self):
new_engine = MysqlEngine(instance=self.ins1)
- update_sql = 'update user set id=0'
- check_result = new_engine.query_check(db_name='some_db', sql=update_sql)
- self.assertDictEqual(check_result,
- {'msg': '不支持的查询语法类型!', 'bad_query': True, 'filtered_sql': 'update user set id=0',
- 'has_star': False})
+ update_sql = "update user set id=0"
+ check_result = new_engine.query_check(db_name="some_db", sql=update_sql)
+ self.assertDictEqual(
+ check_result,
+ {
+ "msg": "不支持的查询语法类型!",
+ "bad_query": True,
+ "filtered_sql": "update user set id=0",
+ "has_star": False,
+ },
+ )
def test_filter_sql_with_delimiter(self):
new_engine = MysqlEngine(instance=self.ins1)
- sql_without_limit = 'select user from usertable;'
+ sql_without_limit = "select user from usertable;"
check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=100)
- self.assertEqual(check_result, 'select user from usertable limit 100;')
+ self.assertEqual(check_result, "select user from usertable limit 100;")
def test_filter_sql_without_delimiter(self):
new_engine = MysqlEngine(instance=self.ins1)
- sql_without_limit = 'select user from usertable'
+ sql_without_limit = "select user from usertable"
check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=100)
- self.assertEqual(check_result, 'select user from usertable limit 100;')
+ self.assertEqual(check_result, "select user from usertable limit 100;")
def test_filter_sql_with_limit(self):
new_engine = MysqlEngine(instance=self.ins1)
- sql_without_limit = 'select user from usertable limit 10'
+ sql_without_limit = "select user from usertable limit 10"
check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=1)
- self.assertEqual(check_result, 'select user from usertable limit 1;')
+ self.assertEqual(check_result, "select user from usertable limit 1;")
def test_filter_sql_with_limit_min(self):
new_engine = MysqlEngine(instance=self.ins1)
- sql_without_limit = 'select user from usertable limit 10'
+ sql_without_limit = "select user from usertable limit 10"
check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=100)
- self.assertEqual(check_result, 'select user from usertable limit 10;')
+ self.assertEqual(check_result, "select user from usertable limit 10;")
def test_filter_sql_with_limit_offset(self):
new_engine = MysqlEngine(instance=self.ins1)
- sql_without_limit = 'select user from usertable limit 10 offset 100'
+ sql_without_limit = "select user from usertable limit 10 offset 100"
check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=1)
- self.assertEqual(check_result, 'select user from usertable limit 1 offset 100;')
+ self.assertEqual(check_result, "select user from usertable limit 1 offset 100;")
def test_filter_sql_with_limit_nn(self):
new_engine = MysqlEngine(instance=self.ins1)
- sql_without_limit = 'select user from usertable limit 10, 100'
+ sql_without_limit = "select user from usertable limit 10, 100"
check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=1)
- self.assertEqual(check_result, 'select user from usertable limit 10,1;')
+ self.assertEqual(check_result, "select user from usertable limit 10,1;")
def test_filter_sql_upper(self):
new_engine = MysqlEngine(instance=self.ins1)
- sql_without_limit = 'SELECT USER FROM usertable LIMIT 10, 100'
+ sql_without_limit = "SELECT USER FROM usertable LIMIT 10, 100"
check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=1)
- self.assertEqual(check_result, 'SELECT USER FROM usertable limit 10,1;')
+ self.assertEqual(check_result, "SELECT USER FROM usertable limit 10,1;")
def test_filter_sql_not_select(self):
new_engine = MysqlEngine(instance=self.ins1)
- sql_without_limit = 'show create table usertable;'
+ sql_without_limit = "show create table usertable;"
check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=1)
- self.assertEqual(check_result, 'show create table usertable;')
+ self.assertEqual(check_result, "show create table usertable;")
- @patch('sql.engines.mysql.data_masking', return_value=ResultSet())
+ @patch("sql.engines.mysql.data_masking", return_value=ResultSet())
def test_query_masking(self, _data_masking):
query_result = ResultSet()
new_engine = MysqlEngine(instance=self.ins1)
- masking_result = new_engine.query_masking(db_name='archery', sql='select 1', resultset=query_result)
+ masking_result = new_engine.query_masking(
+ db_name="archery", sql="select 1", resultset=query_result
+ )
self.assertIsInstance(masking_result, ResultSet)
- @patch('sql.engines.mysql.data_masking', return_value=ResultSet())
+ @patch("sql.engines.mysql.data_masking", return_value=ResultSet())
def test_query_masking_not_select(self, _data_masking):
query_result = ResultSet()
new_engine = MysqlEngine(instance=self.ins1)
- masking_result = new_engine.query_masking(db_name='archery', sql='explain select 1', resultset=query_result)
+ masking_result = new_engine.query_masking(
+ db_name="archery", sql="explain select 1", resultset=query_result
+ )
self.assertEqual(masking_result, query_result)
- @patch('sql.engines.mysql.GoInceptionEngine')
+ @patch("sql.engines.mysql.GoInceptionEngine")
def test_execute_check_select_sql(self, _inception_engine):
- self.sys_config.set('goinception', 'true')
- sql = 'select * from user'
- inc_row = ReviewResult(id=1,
- errlevel=0,
- stagestatus='Audit completed',
- errormessage='None',
- sql=sql,
- affected_rows=0,
- execute_time='', )
- row = ReviewResult(id=1, errlevel=2,
- stagestatus='驳回不支持语句',
- errormessage='仅支持DML和DDL语句,查询语句请使用SQL查询功能!',
- sql=sql)
- _inception_engine.return_value.execute_check.return_value = ReviewSet(full_sql=sql, rows=[inc_row])
+ self.sys_config.set("goinception", "true")
+ sql = "select * from user"
+ inc_row = ReviewResult(
+ id=1,
+ errlevel=0,
+ stagestatus="Audit completed",
+ errormessage="None",
+ sql=sql,
+ affected_rows=0,
+ execute_time="",
+ )
+ row = ReviewResult(
+ id=1,
+ errlevel=2,
+ stagestatus="驳回不支持语句",
+ errormessage="仅支持DML和DDL语句,查询语句请使用SQL查询功能!",
+ sql=sql,
+ )
+ _inception_engine.return_value.execute_check.return_value = ReviewSet(
+ full_sql=sql, rows=[inc_row]
+ )
new_engine = MysqlEngine(instance=self.ins1)
- check_result = new_engine.execute_check(db_name='archery', sql=sql)
+ check_result = new_engine.execute_check(db_name="archery", sql=sql)
self.assertIsInstance(check_result, ReviewSet)
self.assertEqual(check_result.rows[0].__dict__, row.__dict__)
- @patch('sql.engines.mysql.GoInceptionEngine')
+ @patch("sql.engines.mysql.GoInceptionEngine")
def test_execute_check_critical_sql(self, _inception_engine):
- self.sys_config.set('goinception', 'true')
- self.sys_config.set('critical_ddl_regex', '^|update')
+ self.sys_config.set("goinception", "true")
+ self.sys_config.set("critical_ddl_regex", "^|update")
self.sys_config.get_all_config()
- sql = 'update user set id=1'
- inc_row = ReviewResult(id=1,
- errlevel=0,
- stagestatus='Audit completed',
- errormessage='None',
- sql=sql,
- affected_rows=0,
- execute_time='', )
- row = ReviewResult(id=1, errlevel=2,
- stagestatus='驳回高危SQL',
- errormessage='禁止提交匹配' + '^|update' + '条件的语句!',
- sql=sql)
- _inception_engine.return_value.execute_check.return_value = ReviewSet(full_sql=sql, rows=[inc_row])
+ sql = "update user set id=1"
+ inc_row = ReviewResult(
+ id=1,
+ errlevel=0,
+ stagestatus="Audit completed",
+ errormessage="None",
+ sql=sql,
+ affected_rows=0,
+ execute_time="",
+ )
+ row = ReviewResult(
+ id=1,
+ errlevel=2,
+ stagestatus="驳回高危SQL",
+ errormessage="禁止提交匹配" + "^|update" + "条件的语句!",
+ sql=sql,
+ )
+ _inception_engine.return_value.execute_check.return_value = ReviewSet(
+ full_sql=sql, rows=[inc_row]
+ )
new_engine = MysqlEngine(instance=self.ins1)
- check_result = new_engine.execute_check(db_name='archery', sql=sql)
+ check_result = new_engine.execute_check(db_name="archery", sql=sql)
self.assertIsInstance(check_result, ReviewSet)
self.assertEqual(check_result.rows[0].__dict__, row.__dict__)
- @patch('sql.engines.mysql.GoInceptionEngine')
+ @patch("sql.engines.mysql.GoInceptionEngine")
def test_execute_check_normal_sql(self, _inception_engine):
- self.sys_config.set('goinception', 'true')
- sql = 'update user set id=1'
- row = ReviewResult(id=1,
- errlevel=0,
- stagestatus='Audit completed',
- errormessage='None',
- sql=sql,
- affected_rows=0,
- execute_time=0, )
- _inception_engine.return_value.execute_check.return_value = ReviewSet(full_sql=sql, rows=[row])
+ self.sys_config.set("goinception", "true")
+ sql = "update user set id=1"
+ row = ReviewResult(
+ id=1,
+ errlevel=0,
+ stagestatus="Audit completed",
+ errormessage="None",
+ sql=sql,
+ affected_rows=0,
+ execute_time=0,
+ )
+ _inception_engine.return_value.execute_check.return_value = ReviewSet(
+ full_sql=sql, rows=[row]
+ )
new_engine = MysqlEngine(instance=self.ins1)
- check_result = new_engine.execute_check(db_name='archery', sql=sql)
+ check_result = new_engine.execute_check(db_name="archery", sql=sql)
self.assertIsInstance(check_result, ReviewSet)
self.assertEqual(check_result.rows[0].__dict__, row.__dict__)
- @patch('sql.engines.mysql.GoInceptionEngine')
+ @patch("sql.engines.mysql.GoInceptionEngine")
def test_execute_check_normal_sql_with_Exception(self, _inception_engine):
- sql = 'update user set id=1'
+ sql = "update user set id=1"
_inception_engine.return_value.execute_check.side_effect = RuntimeError()
new_engine = MysqlEngine(instance=self.ins1)
with self.assertRaises(RuntimeError):
new_engine.execute_check(db_name=0, sql=sql)
- @patch.object(MysqlEngine, 'query')
- @patch('sql.engines.mysql.GoInceptionEngine')
+ @patch.object(MysqlEngine, "query")
+ @patch("sql.engines.mysql.GoInceptionEngine")
def test_execute_workflow(self, _inception_engine, _query):
- self.sys_config.set('goinception', 'true')
- sql = 'update user set id=1'
+ self.sys_config.set("goinception", "true")
+ sql = "update user set id=1"
_inception_engine.return_value.execute.return_value = ReviewSet(full_sql=sql)
- _query.return_value.rows = (('0',),)
+ _query.return_value.rows = (("0",),)
new_engine = MysqlEngine(instance=self.ins1)
execute_result = new_engine.execute_workflow(self.wf)
self.assertIsInstance(execute_result, ReviewSet)
- @patch('MySQLdb.connect.cursor.execute')
- @patch('MySQLdb.connect.cursor')
- @patch('MySQLdb.connect')
+ @patch("MySQLdb.connect.cursor.execute")
+ @patch("MySQLdb.connect.cursor")
+ @patch("MySQLdb.connect")
def test_execute(self, _connect, _cursor, _execute):
new_engine = MysqlEngine(instance=self.ins1)
execute_result = new_engine.execute(self.wf)
self.assertIsInstance(execute_result, ResultSet)
- @patch('MySQLdb.connect')
+ @patch("MySQLdb.connect")
def test_server_version(self, _connect):
- _connect.return_value.get_server_info.return_value = '5.7.20-16log'
+ _connect.return_value.get_server_info.return_value = "5.7.20-16log"
new_engine = MysqlEngine(instance=self.ins1)
server_version = new_engine.server_version
self.assertTupleEqual(server_version, (5, 7, 20))
- @patch.object(MysqlEngine, 'query')
+ @patch.object(MysqlEngine, "query")
def test_get_variables_not_filter(self, _query):
new_engine = MysqlEngine(instance=self.ins1)
new_engine.get_variables()
_query.assert_called_once()
- @patch('MySQLdb.connect')
- @patch.object(MysqlEngine, 'query')
+ @patch("MySQLdb.connect")
+ @patch.object(MysqlEngine, "query")
def test_get_variables_filter(self, _query, _connect):
- _connect.return_value.get_server_info.return_value = '5.7.20-16log'
+ _connect.return_value.get_server_info.return_value = "5.7.20-16log"
new_engine = MysqlEngine(instance=self.ins1)
- new_engine.get_variables(variables=['binlog_format'])
+ new_engine.get_variables(variables=["binlog_format"])
_query.assert_called()
- @patch.object(MysqlEngine, 'query')
+ @patch.object(MysqlEngine, "query")
def test_set_variable(self, _query):
new_engine = MysqlEngine(instance=self.ins1)
- new_engine.set_variable('binlog_format', 'ROW')
+ new_engine.set_variable("binlog_format", "ROW")
_query.assert_called_once_with(sql="set global binlog_format=ROW;")
- @patch('sql.engines.mysql.GoInceptionEngine')
+ @patch("sql.engines.mysql.GoInceptionEngine")
def test_osc_go_inception(self, _inception_engine):
- self.sys_config.set('goinception', 'false')
+ self.sys_config.set("goinception", "false")
_inception_engine.return_value.osc_control.return_value = ReviewSet()
- command = 'get'
- sqlsha1 = 'xxxxx'
+ command = "get"
+ sqlsha1 = "xxxxx"
new_engine = MysqlEngine(instance=self.ins1)
new_engine.osc_control(sqlsha1=sqlsha1, command=command)
- @patch('sql.engines.mysql.GoInceptionEngine')
+ @patch("sql.engines.mysql.GoInceptionEngine")
def test_osc_inception(self, _inception_engine):
- self.sys_config.set('goinception', 'true')
+ self.sys_config.set("goinception", "true")
_inception_engine.return_value.osc_control.return_value = ReviewSet()
- command = 'get'
- sqlsha1 = 'xxxxx'
+ command = "get"
+ sqlsha1 = "xxxxx"
new_engine = MysqlEngine(instance=self.ins1)
new_engine.osc_control(sqlsha1=sqlsha1, command=command)
- @patch.object(MysqlEngine, 'query')
+ @patch.object(MysqlEngine, "query")
def test_kill_connection(self, _query):
new_engine = MysqlEngine(instance=self.ins1)
new_engine.kill_connection(100)
_query.assert_called_once_with(sql="kill 100")
- @patch.object(MysqlEngine, 'query')
+ @patch.object(MysqlEngine, "query")
def test_seconds_behind_master(self, _query):
new_engine = MysqlEngine(instance=self.ins1)
new_engine.seconds_behind_master
- _query.assert_called_once_with(sql="show slave status", close_conn=False,
- cursorclass=MySQLdb.cursors.DictCursor)
-
- @patch.object(MysqlEngine, 'query')
+ _query.assert_called_once_with(
+ sql="show slave status",
+ close_conn=False,
+ cursorclass=MySQLdb.cursors.DictCursor,
+ )
+
+ @patch.object(MysqlEngine, "query")
def test_processlist(self, _query):
new_engine = MysqlEngine(instance=self.ins1)
_query.return_value = ResultSet()
- for command_type in [ 'Query', 'All', 'Not Sleep']:
+ for command_type in ["Query", "All", "Not Sleep"]:
r = new_engine.processlist(command_type)
self.assertIsInstance(r, ResultSet)
-
- @patch.object(MysqlEngine, 'query')
+
+ @patch.object(MysqlEngine, "query")
def test_get_kill_command(self, _query):
new_engine = MysqlEngine(instance=self.ins1)
- _query.return_value.rows = (('kill 1;',),('kill 2;',))
- r = new_engine.get_kill_command([1,2])
- self.assertEqual(r, 'kill 1;kill 2;')
-
- @patch('MySQLdb.connect.cursor.execute')
- @patch('MySQLdb.connect.cursor')
- @patch('MySQLdb.connect')
- @patch.object(MysqlEngine, 'query')
+ _query.return_value.rows = (("kill 1;",), ("kill 2;",))
+ r = new_engine.get_kill_command([1, 2])
+ self.assertEqual(r, "kill 1;kill 2;")
+
+ @patch("MySQLdb.connect.cursor.execute")
+ @patch("MySQLdb.connect.cursor")
+ @patch("MySQLdb.connect")
+ @patch.object(MysqlEngine, "query")
def test_kill(self, _query, _connect, _cursor, _execute):
new_engine = MysqlEngine(instance=self.ins1)
- _query.return_value.rows = (('kill 1;',),('kill 2;',))
- _execute.return_value = ResultSet()
- r = new_engine.kill([1,2])
+ _query.return_value.rows = (("kill 1;",), ("kill 2;",))
+ _execute.return_value = ResultSet()
+ r = new_engine.kill([1, 2])
self.assertIsInstance(r, ResultSet)
-
- @patch.object(MysqlEngine, 'query')
+
+ @patch.object(MysqlEngine, "query")
def test_tablesapce(self, _query):
new_engine = MysqlEngine(instance=self.ins1)
_query.return_value = ResultSet()
r = new_engine.tablesapce()
self.assertIsInstance(r, ResultSet)
-
- @patch.object(MysqlEngine, 'query')
+
+ @patch.object(MysqlEngine, "query")
def test_tablesapce_num(self, _query):
new_engine = MysqlEngine(instance=self.ins1)
_query.return_value = ResultSet()
r = new_engine.tablesapce_num()
self.assertIsInstance(r, ResultSet)
-
- @patch.object(MysqlEngine, 'query')
- @patch('MySQLdb.connect')
+
+ @patch.object(MysqlEngine, "query")
+ @patch("MySQLdb.connect")
def test_trxandlocks(self, _connect, _query):
new_engine = MysqlEngine(instance=self.ins1)
_connect.return_value = Mock()
- for v in ['5.7.0','8.0.1']:
+ for v in ["5.7.0", "8.0.1"]:
_connect.return_value.get_server_info.return_value = v
_query.return_value = ResultSet()
r = new_engine.trxandlocks()
self.assertIsInstance(r, ResultSet)
- @patch.object(MysqlEngine, 'query')
+ @patch.object(MysqlEngine, "query")
def test_get_long_transaction(self, _query):
new_engine = MysqlEngine(instance=self.ins1)
_query.return_value = ResultSet()
r = new_engine.get_long_transaction()
self.assertIsInstance(r, ResultSet)
-
+
class TestRedis(TestCase):
@classmethod
def setUpClass(cls):
- cls.ins = Instance(instance_name='some_ins', type='slave', db_type='redis', mode='standalone',
- host='some_host', port=1366, user='ins_user', password='some_str')
+ cls.ins = Instance(
+ instance_name="some_ins",
+ type="slave",
+ db_type="redis",
+ mode="standalone",
+ host="some_host",
+ port=1366,
+ user="ins_user",
+ password="some_str",
+ )
cls.ins.save()
@classmethod
@@ -608,107 +683,127 @@ def tearDownClass(cls):
SqlWorkflow.objects.all().delete()
SqlWorkflowContent.objects.all().delete()
- @patch('redis.Redis')
+ @patch("redis.Redis")
def test_engine_base_info(self, _conn):
new_engine = RedisEngine(instance=self.ins)
- self.assertEqual(new_engine.name, 'Redis')
- self.assertEqual(new_engine.info, 'Redis engine')
+ self.assertEqual(new_engine.name, "Redis")
+ self.assertEqual(new_engine.info, "Redis engine")
- @patch('redis.Redis')
+ @patch("redis.Redis")
def test_get_connection(self, _conn):
new_engine = RedisEngine(instance=self.ins)
new_engine.get_connection()
_conn.assert_called_once()
- @patch('redis.Redis.execute_command', return_value=[1, 2, 3])
+ @patch("redis.Redis.execute_command", return_value=[1, 2, 3])
def test_query_return_list(self, _execute_command):
new_engine = RedisEngine(instance=self.ins)
- query_result = new_engine.query(db_name=0, sql='keys *', limit_num=100)
+ query_result = new_engine.query(db_name=0, sql="keys *", limit_num=100)
self.assertIsInstance(query_result, ResultSet)
self.assertTupleEqual(query_result.rows, ([1], [2], [3]))
- @patch('redis.Redis.execute_command', return_value='text')
+ @patch("redis.Redis.execute_command", return_value="text")
def test_query_return_str(self, _execute_command):
new_engine = RedisEngine(instance=self.ins)
- query_result = new_engine.query(db_name=0, sql='keys *', limit_num=100)
+ query_result = new_engine.query(db_name=0, sql="keys *", limit_num=100)
self.assertIsInstance(query_result, ResultSet)
- self.assertTupleEqual(query_result.rows, (['text'],))
+ self.assertTupleEqual(query_result.rows, (["text"],))
- @patch('redis.Redis.execute_command', return_value='text')
+ @patch("redis.Redis.execute_command", return_value="text")
def test_query_execute(self, _execute_command):
new_engine = RedisEngine(instance=self.ins)
- query_result = new_engine.query(db_name=0, sql='keys *', limit_num=100)
+ query_result = new_engine.query(db_name=0, sql="keys *", limit_num=100)
self.assertIsInstance(query_result, ResultSet)
- self.assertTupleEqual(query_result.rows, (['text'],))
+ self.assertTupleEqual(query_result.rows, (["text"],))
- @patch('redis.Redis.config_get', return_value={"databases": 4})
+ @patch("redis.Redis.config_get", return_value={"databases": 4})
def test_get_all_databases(self, _config_get):
new_engine = RedisEngine(instance=self.ins)
dbs = new_engine.get_all_databases()
- self.assertListEqual(dbs.rows, ['0', '1', '2', '3'])
+ self.assertListEqual(dbs.rows, ["0", "1", "2", "3"])
def test_query_check_safe_cmd(self):
safe_cmd = "keys 1*"
new_engine = RedisEngine(instance=self.ins)
check_result = new_engine.query_check(db_name=0, sql=safe_cmd)
- self.assertDictEqual(check_result,
- {'msg': '禁止执行该命令!', 'bad_query': True, 'filtered_sql': safe_cmd, 'has_star': False})
+ self.assertDictEqual(
+ check_result,
+ {
+ "msg": "禁止执行该命令!",
+ "bad_query": True,
+ "filtered_sql": safe_cmd,
+ "has_star": False,
+ },
+ )
def test_query_check_danger_cmd(self):
safe_cmd = "keys *"
new_engine = RedisEngine(instance=self.ins)
check_result = new_engine.query_check(db_name=0, sql=safe_cmd)
- self.assertDictEqual(check_result,
- {'msg': '禁止执行该命令!', 'bad_query': True, 'filtered_sql': safe_cmd, 'has_star': False})
+ self.assertDictEqual(
+ check_result,
+ {
+ "msg": "禁止执行该命令!",
+ "bad_query": True,
+ "filtered_sql": safe_cmd,
+ "has_star": False,
+ },
+ )
def test_filter_sql(self):
safe_cmd = "keys 1*"
new_engine = RedisEngine(instance=self.ins)
check_result = new_engine.filter_sql(sql=safe_cmd, limit_num=100)
- self.assertEqual(check_result, 'keys 1*')
+ self.assertEqual(check_result, "keys 1*")
def test_query_masking(self):
query_result = ResultSet()
new_engine = RedisEngine(instance=self.ins)
- masking_result = new_engine.query_masking(db_name=0, sql='', resultset=query_result)
+ masking_result = new_engine.query_masking(
+ db_name=0, sql="", resultset=query_result
+ )
self.assertEqual(masking_result, query_result)
def test_execute_check(self):
- sql = 'set 1 1'
- row = ReviewResult(id=1,
- errlevel=0,
- stagestatus='Audit completed',
- errormessage='None',
- sql=sql,
- affected_rows=0,
- execute_time=0)
+ sql = "set 1 1"
+ row = ReviewResult(
+ id=1,
+ errlevel=0,
+ stagestatus="Audit completed",
+ errormessage="None",
+ sql=sql,
+ affected_rows=0,
+ execute_time=0,
+ )
new_engine = RedisEngine(instance=self.ins)
check_result = new_engine.execute_check(db_name=0, sql=sql)
self.assertIsInstance(check_result, ReviewSet)
self.assertEqual(check_result.rows[0].__dict__, row.__dict__)
- @patch('redis.Redis.execute_command', return_value='text')
+ @patch("redis.Redis.execute_command", return_value="text")
def test_execute_workflow_success(self, _execute_command):
- sql = 'set 1 1'
- row = ReviewResult(id=1,
- errlevel=0,
- stagestatus='Execute Successfully',
- errormessage='None',
- sql=sql,
- affected_rows=0,
- execute_time=0)
+ sql = "set 1 1"
+ row = ReviewResult(
+ id=1,
+ errlevel=0,
+ stagestatus="Execute Successfully",
+ errormessage="None",
+ sql=sql,
+ affected_rows=0,
+ execute_time=0,
+ )
wf = SqlWorkflow.objects.create(
- workflow_name='some_name',
+ workflow_name="some_name",
group_id=1,
- group_name='g1',
- engineer_display='',
- audit_auth_groups='some_group',
+ group_name="g1",
+ engineer_display="",
+ audit_auth_groups="some_group",
create_time=datetime.now() - timedelta(days=1),
- status='workflow_finish',
+ status="workflow_finish",
is_backup=True,
instance=self.ins,
- db_name='some_db',
- syntax_type=1
+ db_name="some_db",
+ syntax_type=1,
)
SqlWorkflowContent.objects.create(workflow=wf, sql_content=sql)
new_engine = RedisEngine(instance=self.ins)
@@ -720,8 +815,15 @@ def test_execute_workflow_success(self, _execute_command):
class TestPgSQL(TestCase):
@classmethod
def setUpClass(cls):
- cls.ins = Instance(instance_name='some_ins', type='slave', db_type='pgsql', host='some_host',
- port=1366, user='ins_user', password='some_str')
+ cls.ins = Instance(
+ instance_name="some_ins",
+ type="slave",
+ db_type="pgsql",
+ host="some_host",
+ port=1366,
+ user="ins_user",
+ password="some_str",
+ )
cls.ins.save()
cls.sys_config = SysConfig()
@@ -730,85 +832,130 @@ def tearDownClass(cls):
cls.ins.delete()
cls.sys_config.purge()
- @patch('psycopg2.connect')
+ @patch("psycopg2.connect")
def test_engine_base_info(self, _conn):
new_engine = PgSQLEngine(instance=self.ins)
- self.assertEqual(new_engine.name, 'PgSQL')
- self.assertEqual(new_engine.info, 'PgSQL engine')
+ self.assertEqual(new_engine.name, "PgSQL")
+ self.assertEqual(new_engine.info, "PgSQL engine")
- @patch('psycopg2.connect')
+ @patch("psycopg2.connect")
def test_get_connection(self, _conn):
new_engine = PgSQLEngine(instance=self.ins)
new_engine.get_connection("some_dbname")
_conn.assert_called_once()
- @patch('psycopg2.connect.cursor.execute')
- @patch('psycopg2.connect.cursor')
- @patch('psycopg2.connect')
+ @patch("psycopg2.connect.cursor.execute")
+ @patch("psycopg2.connect.cursor")
+ @patch("psycopg2.connect")
def test_query(self, _conn, _cursor, _execute):
_conn.return_value.cursor.return_value.fetchmany.return_value = [(1,)]
new_engine = PgSQLEngine(instance=self.ins)
- query_result = new_engine.query(db_name="some_dbname", sql='select 1', limit_num=100, schema_name="some_schema")
+ query_result = new_engine.query(
+ db_name="some_dbname",
+ sql="select 1",
+ limit_num=100,
+ schema_name="some_schema",
+ )
self.assertIsInstance(query_result, ResultSet)
self.assertListEqual(query_result.rows, [(1,)])
- @patch('psycopg2.connect.cursor.execute')
- @patch('psycopg2.connect.cursor')
- @patch('psycopg2.connect')
+ @patch("psycopg2.connect.cursor.execute")
+ @patch("psycopg2.connect.cursor")
+ @patch("psycopg2.connect")
def test_query_not_limit(self, _conn, _cursor, _execute):
_conn.return_value.cursor.return_value.fetchall.return_value = [(1,)]
new_engine = PgSQLEngine(instance=self.ins)
- query_result = new_engine.query(db_name="some_dbname", sql='select 1', limit_num=0, schema_name="some_schema")
+ query_result = new_engine.query(
+ db_name="some_dbname",
+ sql="select 1",
+ limit_num=0,
+ schema_name="some_schema",
+ )
self.assertIsInstance(query_result, ResultSet)
self.assertListEqual(query_result.rows, [(1,)])
- @patch('sql.engines.pgsql.PgSQLEngine.query',
- return_value=ResultSet(rows=[('postgres',), ('archery',), ('template1',), ('template0',)]))
+ @patch(
+ "sql.engines.pgsql.PgSQLEngine.query",
+ return_value=ResultSet(
+ rows=[("postgres",), ("archery",), ("template1",), ("template0",)]
+ ),
+ )
def test_get_all_databases(self, query):
new_engine = PgSQLEngine(instance=self.ins)
dbs = new_engine.get_all_databases()
- self.assertListEqual(dbs.rows, ['archery'])
-
- @patch('sql.engines.pgsql.PgSQLEngine.query',
- return_value=ResultSet(rows=[('information_schema',), ('archery',), ('pg_catalog',)]))
+ self.assertListEqual(dbs.rows, ["archery"])
+
+ @patch(
+ "sql.engines.pgsql.PgSQLEngine.query",
+ return_value=ResultSet(
+ rows=[("information_schema",), ("archery",), ("pg_catalog",)]
+ ),
+ )
def test_get_all_schemas(self, _query):
new_engine = PgSQLEngine(instance=self.ins)
- schemas = new_engine.get_all_schemas(db_name='archery')
- self.assertListEqual(schemas.rows, ['archery'])
+ schemas = new_engine.get_all_schemas(db_name="archery")
+ self.assertListEqual(schemas.rows, ["archery"])
- @patch('sql.engines.pgsql.PgSQLEngine.query', return_value=ResultSet(rows=[('test',), ('test2',)]))
+ @patch(
+ "sql.engines.pgsql.PgSQLEngine.query",
+ return_value=ResultSet(rows=[("test",), ("test2",)]),
+ )
def test_get_all_tables(self, _query):
new_engine = PgSQLEngine(instance=self.ins)
- tables = new_engine.get_all_tables(db_name='archery', schema_name='archery')
- self.assertListEqual(tables.rows, ['test2'])
+ tables = new_engine.get_all_tables(db_name="archery", schema_name="archery")
+ self.assertListEqual(tables.rows, ["test2"])
- @patch('sql.engines.pgsql.PgSQLEngine.query',
- return_value=ResultSet(rows=[('id',), ('name',)]))
+ @patch(
+ "sql.engines.pgsql.PgSQLEngine.query",
+ return_value=ResultSet(rows=[("id",), ("name",)]),
+ )
def test_get_all_columns_by_tb(self, _query):
new_engine = PgSQLEngine(instance=self.ins)
- columns = new_engine.get_all_columns_by_tb(db_name='archery', tb_name='test2', schema_name='archery')
- self.assertListEqual(columns.rows, ['id', 'name'])
-
- @patch('sql.engines.pgsql.PgSQLEngine.query',
- return_value=ResultSet(rows=[('postgres',), ('archery',), ('template1',), ('template0',)]))
+ columns = new_engine.get_all_columns_by_tb(
+ db_name="archery", tb_name="test2", schema_name="archery"
+ )
+ self.assertListEqual(columns.rows, ["id", "name"])
+
+ @patch(
+ "sql.engines.pgsql.PgSQLEngine.query",
+ return_value=ResultSet(
+ rows=[("postgres",), ("archery",), ("template1",), ("template0",)]
+ ),
+ )
def test_describe_table(self, _query):
new_engine = PgSQLEngine(instance=self.ins)
- describe = new_engine.describe_table(db_name='archery', schema_name='archery', tb_name='text')
+ describe = new_engine.describe_table(
+ db_name="archery", schema_name="archery", tb_name="text"
+ )
self.assertIsInstance(describe, ResultSet)
def test_query_check_disable_sql(self):
sql = "update xxx set a=1 "
new_engine = PgSQLEngine(instance=self.ins)
- check_result = new_engine.query_check(db_name='archery', sql=sql)
- self.assertDictEqual(check_result,
- {'msg': '不支持的查询语法类型!', 'bad_query': True, 'filtered_sql': sql.strip(), 'has_star': False})
+ check_result = new_engine.query_check(db_name="archery", sql=sql)
+ self.assertDictEqual(
+ check_result,
+ {
+ "msg": "不支持的查询语法类型!",
+ "bad_query": True,
+ "filtered_sql": sql.strip(),
+ "has_star": False,
+ },
+ )
def test_query_check_star_sql(self):
sql = "select * from xx "
new_engine = PgSQLEngine(instance=self.ins)
- check_result = new_engine.query_check(db_name='archery', sql=sql)
- self.assertDictEqual(check_result,
- {'msg': 'SQL语句中含有 * ', 'bad_query': False, 'filtered_sql': sql.strip(), 'has_star': True})
+ check_result = new_engine.query_check(db_name="archery", sql=sql)
+ self.assertDictEqual(
+ check_result,
+ {
+ "msg": "SQL语句中含有 * ",
+ "bad_query": False,
+ "filtered_sql": sql.strip(),
+ "has_star": True,
+ },
+ )
def test_filter_sql_with_delimiter(self):
sql = "select * from xx;"
@@ -831,72 +978,84 @@ def test_filter_sql_with_limit(self):
def test_query_masking(self):
query_result = ResultSet()
new_engine = PgSQLEngine(instance=self.ins)
- masking_result = new_engine.query_masking(db_name=0, sql='', resultset=query_result)
+ masking_result = new_engine.query_masking(
+ db_name=0, sql="", resultset=query_result
+ )
self.assertEqual(masking_result, query_result)
def test_execute_check_select_sql(self):
- sql = 'select * from user;'
- row = ReviewResult(id=1, errlevel=2,
- stagestatus='驳回不支持语句',
- errormessage='仅支持DML和DDL语句,查询语句请使用SQL查询功能!',
- sql=sql)
+ sql = "select * from user;"
+ row = ReviewResult(
+ id=1,
+ errlevel=2,
+ stagestatus="驳回不支持语句",
+ errormessage="仅支持DML和DDL语句,查询语句请使用SQL查询功能!",
+ sql=sql,
+ )
new_engine = PgSQLEngine(instance=self.ins)
- check_result = new_engine.execute_check(db_name='archery', sql=sql)
+ check_result = new_engine.execute_check(db_name="archery", sql=sql)
self.assertIsInstance(check_result, ReviewSet)
self.assertEqual(check_result.rows[0].__dict__, row.__dict__)
def test_execute_check_critical_sql(self):
- self.sys_config.set('critical_ddl_regex', '^|update')
+ self.sys_config.set("critical_ddl_regex", "^|update")
self.sys_config.get_all_config()
- sql = 'update user set id=1'
- row = ReviewResult(id=1, errlevel=2,
- stagestatus='驳回高危SQL',
- errormessage='禁止提交匹配' + '^|update' + '条件的语句!',
- sql=sql)
+ sql = "update user set id=1"
+ row = ReviewResult(
+ id=1,
+ errlevel=2,
+ stagestatus="驳回高危SQL",
+ errormessage="禁止提交匹配" + "^|update" + "条件的语句!",
+ sql=sql,
+ )
new_engine = PgSQLEngine(instance=self.ins)
- check_result = new_engine.execute_check(db_name='archery', sql=sql)
+ check_result = new_engine.execute_check(db_name="archery", sql=sql)
self.assertIsInstance(check_result, ReviewSet)
self.assertEqual(check_result.rows[0].__dict__, row.__dict__)
def test_execute_check_normal_sql(self):
self.sys_config.purge()
- sql = 'alter table tb set id=1'
- row = ReviewResult(id=1,
- errlevel=0,
- stagestatus='Audit completed',
- errormessage='None',
- sql=sql,
- affected_rows=0,
- execute_time=0, )
+ sql = "alter table tb set id=1"
+ row = ReviewResult(
+ id=1,
+ errlevel=0,
+ stagestatus="Audit completed",
+ errormessage="None",
+ sql=sql,
+ affected_rows=0,
+ execute_time=0,
+ )
new_engine = PgSQLEngine(instance=self.ins)
- check_result = new_engine.execute_check(db_name='archery', sql=sql)
+ check_result = new_engine.execute_check(db_name="archery", sql=sql)
self.assertIsInstance(check_result, ReviewSet)
self.assertEqual(check_result.rows[0].__dict__, row.__dict__)
- @patch('psycopg2.connect.cursor.execute')
- @patch('psycopg2.connect.cursor')
- @patch('psycopg2.connect')
+ @patch("psycopg2.connect.cursor.execute")
+ @patch("psycopg2.connect.cursor")
+ @patch("psycopg2.connect")
def test_execute_workflow_success(self, _conn, _cursor, _execute):
- sql = 'update user set id=1'
- row = ReviewResult(id=1,
- errlevel=0,
- stagestatus='Execute Successfully',
- errormessage='None',
- sql=sql,
- affected_rows=0,
- execute_time=0)
+ sql = "update user set id=1"
+ row = ReviewResult(
+ id=1,
+ errlevel=0,
+ stagestatus="Execute Successfully",
+ errormessage="None",
+ sql=sql,
+ affected_rows=0,
+ execute_time=0,
+ )
wf = SqlWorkflow.objects.create(
- workflow_name='some_name',
+ workflow_name="some_name",
group_id=1,
- group_name='g1',
- engineer_display='',
- audit_auth_groups='some_group',
+ group_name="g1",
+ engineer_display="",
+ audit_auth_groups="some_group",
create_time=datetime.now() - timedelta(days=1),
- status='workflow_finish',
+ status="workflow_finish",
is_backup=True,
instance=self.ins,
- db_name='some_db',
- syntax_type=1
+ db_name="some_db",
+ syntax_type=1,
)
SqlWorkflowContent.objects.create(workflow=wf, sql_content=sql)
new_engine = PgSQLEngine(instance=self.ins)
@@ -904,41 +1063,44 @@ def test_execute_workflow_success(self, _conn, _cursor, _execute):
self.assertIsInstance(execute_result, ReviewSet)
self.assertEqual(execute_result.rows[0].__dict__.keys(), row.__dict__.keys())
- @patch('psycopg2.connect.cursor.execute')
- @patch('psycopg2.connect.cursor')
- @patch('psycopg2.connect', return_value=RuntimeError)
+ @patch("psycopg2.connect.cursor.execute")
+ @patch("psycopg2.connect.cursor")
+ @patch("psycopg2.connect", return_value=RuntimeError)
def test_execute_workflow_exception(self, _conn, _cursor, _execute):
- sql = 'update user set id=1'
- row = ReviewResult(id=1,
- errlevel=2,
- stagestatus='Execute Failed',
- errormessage=f'异常信息:{f"Oracle命令执行报错,语句:{sql}"}',
- sql=sql,
- affected_rows=0,
- execute_time=0, )
+ sql = "update user set id=1"
+ row = ReviewResult(
+ id=1,
+ errlevel=2,
+ stagestatus="Execute Failed",
+ errormessage=f'异常信息:{f"Oracle命令执行报错,语句:{sql}"}',
+ sql=sql,
+ affected_rows=0,
+ execute_time=0,
+ )
wf = SqlWorkflow.objects.create(
- workflow_name='some_name',
+ workflow_name="some_name",
group_id=1,
- group_name='g1',
- engineer_display='',
- audit_auth_groups='some_group',
+ group_name="g1",
+ engineer_display="",
+ audit_auth_groups="some_group",
create_time=datetime.now() - timedelta(days=1),
- status='workflow_finish',
+ status="workflow_finish",
is_backup=True,
instance=self.ins,
- db_name='some_db',
- syntax_type=1
+ db_name="some_db",
+ syntax_type=1,
)
SqlWorkflowContent.objects.create(workflow=wf, sql_content=sql)
with self.assertRaises(AttributeError):
new_engine = PgSQLEngine(instance=self.ins)
execute_result = new_engine.execute_workflow(workflow=wf)
self.assertIsInstance(execute_result, ReviewSet)
- self.assertEqual(execute_result.rows[0].__dict__.keys(), row.__dict__.keys())
+ self.assertEqual(
+ execute_result.rows[0].__dict__.keys(), row.__dict__.keys()
+ )
class TestModel(TestCase):
-
def setUp(self):
pass
@@ -963,23 +1125,34 @@ def test_result_set_rows_shadow(self):
class TestGoInception(TestCase):
def setUp(self):
- self.ins = Instance.objects.create(instance_name='some_ins', type='slave', db_type='mysql',
- host='some_host',
- port=3306, user='ins_user', password='some_str')
- self.ins_inc = Instance.objects.create(instance_name='some_ins_inc', type='slave', db_type='goinception',
- host='some_host', port=4000)
+ self.ins = Instance.objects.create(
+ instance_name="some_ins",
+ type="slave",
+ db_type="mysql",
+ host="some_host",
+ port=3306,
+ user="ins_user",
+ password="some_str",
+ )
+ self.ins_inc = Instance.objects.create(
+ instance_name="some_ins_inc",
+ type="slave",
+ db_type="goinception",
+ host="some_host",
+ port=4000,
+ )
self.wf = SqlWorkflow.objects.create(
- workflow_name='some_name',
+ workflow_name="some_name",
group_id=1,
- group_name='g1',
- engineer_display='',
- audit_auth_groups='some_group',
+ group_name="g1",
+ engineer_display="",
+ audit_auth_groups="some_group",
create_time=datetime.now() - timedelta(days=1),
- status='workflow_finish',
+ status="workflow_finish",
is_backup=True,
instance=self.ins,
- db_name='some_db',
- syntax_type=1
+ db_name="some_db",
+ syntax_type=1,
)
SqlWorkflowContent.objects.create(workflow=self.wf)
@@ -989,119 +1162,186 @@ def tearDown(self):
SqlWorkflow.objects.all().delete()
SqlWorkflowContent.objects.all().delete()
- @patch('MySQLdb.connect')
+ @patch("MySQLdb.connect")
def test_get_connection(self, _connect):
new_engine = GoInceptionEngine()
new_engine.get_connection()
_connect.assert_called_once()
- @patch('sql.engines.goinception.GoInceptionEngine.query')
+ @patch("sql.engines.goinception.GoInceptionEngine.query")
def test_execute_check_normal_sql(self, _query):
- sql = 'update user set id=100'
- row = [1, 'CHECKED', 0, 'Audit completed', 'None', 'use archery', 0, "'0_0_0'", 'None', '0', '', '']
+ sql = "update user set id=100"
+ row = [
+ 1,
+ "CHECKED",
+ 0,
+ "Audit completed",
+ "None",
+ "use archery",
+ 0,
+ "'0_0_0'",
+ "None",
+ "0",
+ "",
+ "",
+ ]
_query.return_value = ResultSet(full_sql=sql, rows=[row])
new_engine = GoInceptionEngine()
check_result = new_engine.execute_check(instance=self.ins, db_name=0, sql=sql)
self.assertIsInstance(check_result, ReviewSet)
- @patch('sql.engines.goinception.GoInceptionEngine.query')
+ @patch("sql.engines.goinception.GoInceptionEngine.query")
def test_execute_exception(self, _query):
- sql = 'update user set id=100'
- row = [1, 'CHECKED', 1, 'Execute failed', 'None', 'use archery', 0, "'0_0_0'", 'None', '0', '', '']
- column_list = ['order_id', 'stage', 'error_level', 'stage_status', 'error_message', 'sql',
- 'affected_rows', 'sequence', 'backup_dbname', 'execute_time', 'sqlsha1', 'backup_time']
- _query.return_value = ResultSet(full_sql=sql, rows=[row], column_list=column_list)
+ sql = "update user set id=100"
+ row = [
+ 1,
+ "CHECKED",
+ 1,
+ "Execute failed",
+ "None",
+ "use archery",
+ 0,
+ "'0_0_0'",
+ "None",
+ "0",
+ "",
+ "",
+ ]
+ column_list = [
+ "order_id",
+ "stage",
+ "error_level",
+ "stage_status",
+ "error_message",
+ "sql",
+ "affected_rows",
+ "sequence",
+ "backup_dbname",
+ "execute_time",
+ "sqlsha1",
+ "backup_time",
+ ]
+ _query.return_value = ResultSet(
+ full_sql=sql, rows=[row], column_list=column_list
+ )
new_engine = GoInceptionEngine()
execute_result = new_engine.execute(workflow=self.wf)
self.assertIsInstance(execute_result, ReviewSet)
- @patch('sql.engines.goinception.GoInceptionEngine.query')
+ @patch("sql.engines.goinception.GoInceptionEngine.query")
def test_execute_finish(self, _query):
- sql = 'update user set id=100'
- row = [1, 'CHECKED', 0, 'Execute Successfully', 'None', 'use archery', 0, "'0_0_0'", 'None', '0', '', '']
- column_list = ['order_id', 'stage', 'error_level', 'stage_status', 'error_message', 'sql',
- 'affected_rows', 'sequence', 'backup_dbname', 'execute_time', 'sqlsha1', 'backup_time']
- _query.return_value = ResultSet(full_sql=sql, rows=[row], column_list=column_list)
+ sql = "update user set id=100"
+ row = [
+ 1,
+ "CHECKED",
+ 0,
+ "Execute Successfully",
+ "None",
+ "use archery",
+ 0,
+ "'0_0_0'",
+ "None",
+ "0",
+ "",
+ "",
+ ]
+ column_list = [
+ "order_id",
+ "stage",
+ "error_level",
+ "stage_status",
+ "error_message",
+ "sql",
+ "affected_rows",
+ "sequence",
+ "backup_dbname",
+ "execute_time",
+ "sqlsha1",
+ "backup_time",
+ ]
+ _query.return_value = ResultSet(
+ full_sql=sql, rows=[row], column_list=column_list
+ )
new_engine = GoInceptionEngine()
execute_result = new_engine.execute(workflow=self.wf)
self.assertIsInstance(execute_result, ReviewSet)
- @patch('MySQLdb.connect.cursor.execute')
- @patch('MySQLdb.connect.cursor')
- @patch('MySQLdb.connect')
+ @patch("MySQLdb.connect.cursor.execute")
+ @patch("MySQLdb.connect.cursor")
+ @patch("MySQLdb.connect")
def test_query(self, _conn, _cursor, _execute):
_conn.return_value.cursor.return_value.fetchall.return_value = [(1,)]
new_engine = GoInceptionEngine()
- query_result = new_engine.query(db_name=0, sql='select 1', limit_num=100)
+ query_result = new_engine.query(db_name=0, sql="select 1", limit_num=100)
self.assertIsInstance(query_result, ResultSet)
- @patch('MySQLdb.connect.cursor.execute')
- @patch('MySQLdb.connect.cursor')
- @patch('MySQLdb.connect')
+ @patch("MySQLdb.connect.cursor.execute")
+ @patch("MySQLdb.connect.cursor")
+ @patch("MySQLdb.connect")
def test_query_not_limit(self, _conn, _cursor, _execute):
_conn.return_value.cursor.return_value.fetchall.return_value = [(1,)]
new_engine = GoInceptionEngine(instance=self.ins)
- query_result = new_engine.query(db_name=0, sql='select 1', limit_num=0)
+ query_result = new_engine.query(db_name=0, sql="select 1", limit_num=0)
self.assertIsInstance(query_result, ResultSet)
- @patch('sql.engines.goinception.GoInceptionEngine.query')
+ @patch("sql.engines.goinception.GoInceptionEngine.query")
def test_osc_get(self, _query):
new_engine = GoInceptionEngine()
- command = 'get'
- sqlsha1 = 'xxxxx'
+ command = "get"
+ sqlsha1 = "xxxxx"
sql = f"inception get osc_percent '{sqlsha1}';"
_query.return_value = ResultSet(full_sql=sql, rows=[], column_list=[])
new_engine.osc_control(sqlsha1=sqlsha1, command=command)
_query.assert_called_once_with(sql=sql)
- @patch('sql.engines.goinception.GoInceptionEngine.query')
+ @patch("sql.engines.goinception.GoInceptionEngine.query")
def test_osc_pause(self, _query):
new_engine = GoInceptionEngine()
- command = 'pause'
- sqlsha1 = 'xxxxx'
+ command = "pause"
+ sqlsha1 = "xxxxx"
sql = f"inception {command} osc '{sqlsha1}';"
_query.return_value = ResultSet(full_sql=sql, rows=[], column_list=[])
new_engine.osc_control(sqlsha1=sqlsha1, command=command)
_query.assert_called_once_with(sql=sql)
- @patch('sql.engines.goinception.GoInceptionEngine.query')
+ @patch("sql.engines.goinception.GoInceptionEngine.query")
def test_osc_resume(self, _query):
new_engine = GoInceptionEngine()
- command = 'resume'
- sqlsha1 = 'xxxxx'
+ command = "resume"
+ sqlsha1 = "xxxxx"
sql = f"inception {command} osc '{sqlsha1}';"
_query.return_value = ResultSet(full_sql=sql, rows=[], column_list=[])
new_engine.osc_control(sqlsha1=sqlsha1, command=command)
_query.assert_called_once_with(sql=sql)
- @patch('sql.engines.goinception.GoInceptionEngine.query')
+ @patch("sql.engines.goinception.GoInceptionEngine.query")
def test_osc_kill(self, _query):
new_engine = GoInceptionEngine()
- command = 'kill'
- sqlsha1 = 'xxxxx'
+ command = "kill"
+ sqlsha1 = "xxxxx"
sql = f"inception kill osc '{sqlsha1}';"
_query.return_value = ResultSet(full_sql=sql, rows=[], column_list=[])
new_engine.osc_control(sqlsha1=sqlsha1, command=command)
_query.assert_called_once_with(sql=sql)
- @patch('sql.engines.goinception.GoInceptionEngine.query')
+ @patch("sql.engines.goinception.GoInceptionEngine.query")
def test_get_variables(self, _query):
new_engine = GoInceptionEngine(instance=self.ins_inc)
new_engine.get_variables()
sql = f"inception get variables;"
_query.assert_called_once_with(sql=sql)
- @patch('sql.engines.goinception.GoInceptionEngine.query')
+ @patch("sql.engines.goinception.GoInceptionEngine.query")
def test_get_variables_filter(self, _query):
new_engine = GoInceptionEngine(instance=self.ins_inc)
- new_engine.get_variables(variables=['inception_osc_on'])
+ new_engine.get_variables(variables=["inception_osc_on"])
sql = f"inception get variables like 'inception_osc_on';"
_query.assert_called_once_with(sql=sql)
- @patch('sql.engines.goinception.GoInceptionEngine.query')
+ @patch("sql.engines.goinception.GoInceptionEngine.query")
def test_set_variable(self, _query):
new_engine = GoInceptionEngine(instance=self.ins)
- new_engine.set_variable('inception_osc_on', 'on')
+ new_engine.set_variable("inception_osc_on", "on")
_query.assert_called_once_with(sql="inception set inception_osc_on=on;")
@@ -1109,21 +1349,28 @@ class TestOracle(TestCase):
"""Oracle 测试"""
def setUp(self):
- self.ins = Instance.objects.create(instance_name='some_ins', type='slave', db_type='oracle',
- host='some_host', port=3306, user='ins_user', password='some_str',
- sid='some_id')
+ self.ins = Instance.objects.create(
+ instance_name="some_ins",
+ type="slave",
+ db_type="oracle",
+ host="some_host",
+ port=3306,
+ user="ins_user",
+ password="some_str",
+ sid="some_id",
+ )
self.wf = SqlWorkflow.objects.create(
- workflow_name='some_name',
+ workflow_name="some_name",
group_id=1,
- group_name='g1',
- engineer_display='',
- audit_auth_groups='some_group',
+ group_name="g1",
+ engineer_display="",
+ audit_auth_groups="some_group",
create_time=datetime.now() - timedelta(days=1),
- status='workflow_finish',
+ status="workflow_finish",
is_backup=True,
instance=self.ins,
- db_name='some_db',
- syntax_type=1
+ db_name="some_db",
+ syntax_type=1,
)
SqlWorkflowContent.objects.create(workflow=self.wf)
self.sys_config = SysConfig()
@@ -1134,8 +1381,8 @@ def tearDown(self):
SqlWorkflow.objects.all().delete()
SqlWorkflowContent.objects.all().delete()
- @patch('cx_Oracle.makedsn')
- @patch('cx_Oracle.connect')
+ @patch("cx_Oracle.makedsn")
+ @patch("cx_Oracle.connect")
def test_get_connection(self, _connect, _makedsn):
# 填写 sid 测试
new_engine = OracleEngine(self.ins)
@@ -1145,8 +1392,8 @@ def test_get_connection(self, _connect, _makedsn):
# 填写 service_name 测试
_connect.reset_mock()
_makedsn.reset_mock()
- self.ins.service_name = 'some_service'
- self.ins.sid = ''
+ self.ins.service_name = "some_service"
+ self.ins.sid = ""
self.ins.save()
new_engine = OracleEngine(self.ins)
new_engine.get_connection()
@@ -1155,364 +1402,475 @@ def test_get_connection(self, _connect, _makedsn):
# 都不填写, 检测 ValueError
_connect.reset_mock()
_makedsn.reset_mock()
- self.ins.service_name = ''
- self.ins.sid = ''
+ self.ins.service_name = ""
+ self.ins.sid = ""
self.ins.save()
new_engine = OracleEngine(self.ins)
with self.assertRaises(ValueError):
new_engine.get_connection()
- @patch('cx_Oracle.connect')
+ @patch("cx_Oracle.connect")
def test_engine_base_info(self, _conn):
new_engine = OracleEngine(instance=self.ins)
- self.assertEqual(new_engine.name, 'Oracle')
- self.assertEqual(new_engine.info, 'Oracle engine')
- _conn.return_value.version = '12.1.0.2.0'
- self.assertTupleEqual(new_engine.server_version, ('12', '1', '0'))
-
- @patch('cx_Oracle.connect.cursor.execute')
- @patch('cx_Oracle.connect.cursor')
- @patch('cx_Oracle.connect')
+ self.assertEqual(new_engine.name, "Oracle")
+ self.assertEqual(new_engine.info, "Oracle engine")
+ _conn.return_value.version = "12.1.0.2.0"
+ self.assertTupleEqual(new_engine.server_version, ("12", "1", "0"))
+
+ @patch("cx_Oracle.connect.cursor.execute")
+ @patch("cx_Oracle.connect.cursor")
+ @patch("cx_Oracle.connect")
def test_query(self, _conn, _cursor, _execute):
_conn.return_value.cursor.return_value.fetchmany.return_value = [(1,)]
new_engine = OracleEngine(instance=self.ins)
- query_result = new_engine.query(db_name='archery', sql='select 1', limit_num=100)
+ query_result = new_engine.query(
+ db_name="archery", sql="select 1", limit_num=100
+ )
self.assertIsInstance(query_result, ResultSet)
self.assertListEqual(query_result.rows, [(1,)])
- @patch('cx_Oracle.connect.cursor.execute')
- @patch('cx_Oracle.connect.cursor')
- @patch('cx_Oracle.connect')
+ @patch("cx_Oracle.connect.cursor.execute")
+ @patch("cx_Oracle.connect.cursor")
+ @patch("cx_Oracle.connect")
def test_query_not_limit(self, _conn, _cursor, _execute):
_conn.return_value.cursor.return_value.fetchall.return_value = [(1,)]
new_engine = OracleEngine(instance=self.ins)
- query_result = new_engine.query(db_name=0, sql='select 1', limit_num=0)
+ query_result = new_engine.query(db_name=0, sql="select 1", limit_num=0)
self.assertIsInstance(query_result, ResultSet)
self.assertListEqual(query_result.rows, [(1,)])
- @patch('sql.engines.oracle.OracleEngine.query',
- return_value=ResultSet(rows=[('AUD_SYS',), ('archery',), ('ANONYMOUS',)]))
+ @patch(
+ "sql.engines.oracle.OracleEngine.query",
+ return_value=ResultSet(rows=[("AUD_SYS",), ("archery",), ("ANONYMOUS",)]),
+ )
def test_get_all_databases(self, _query):
new_engine = OracleEngine(instance=self.ins)
dbs = new_engine.get_all_databases()
- self.assertListEqual(dbs.rows, ['archery'])
+ self.assertListEqual(dbs.rows, ["archery"])
- @patch('sql.engines.oracle.OracleEngine.query',
- return_value=ResultSet(rows=[('AUD_SYS',), ('archery',), ('ANONYMOUS',)]))
+ @patch(
+ "sql.engines.oracle.OracleEngine.query",
+ return_value=ResultSet(rows=[("AUD_SYS",), ("archery",), ("ANONYMOUS",)]),
+ )
def test__get_all_databases(self, _query):
new_engine = OracleEngine(instance=self.ins)
dbs = new_engine._get_all_databases()
- self.assertListEqual(dbs.rows, ['AUD_SYS', 'archery', 'ANONYMOUS'])
+ self.assertListEqual(dbs.rows, ["AUD_SYS", "archery", "ANONYMOUS"])
- @patch('sql.engines.oracle.OracleEngine.query',
- return_value=ResultSet(rows=[('archery',)]))
+ @patch(
+ "sql.engines.oracle.OracleEngine.query",
+ return_value=ResultSet(rows=[("archery",)]),
+ )
def test__get_all_instances(self, _query):
new_engine = OracleEngine(instance=self.ins)
dbs = new_engine._get_all_instances()
- self.assertListEqual(dbs.rows, ['archery'])
+ self.assertListEqual(dbs.rows, ["archery"])
- @patch('sql.engines.oracle.OracleEngine.query',
- return_value=ResultSet(rows=[('ANONYMOUS',), ('archery',), ('SYSTEM',)]))
+ @patch(
+ "sql.engines.oracle.OracleEngine.query",
+ return_value=ResultSet(rows=[("ANONYMOUS",), ("archery",), ("SYSTEM",)]),
+ )
def test_get_all_schemas(self, _query):
new_engine = OracleEngine(instance=self.ins)
schemas = new_engine._get_all_schemas()
- self.assertListEqual(schemas.rows, ['archery'])
+ self.assertListEqual(schemas.rows, ["archery"])
- @patch('sql.engines.oracle.OracleEngine.query', return_value=ResultSet(rows=[('test',), ('test2',)]))
+ @patch(
+ "sql.engines.oracle.OracleEngine.query",
+ return_value=ResultSet(rows=[("test",), ("test2",)]),
+ )
def test_get_all_tables(self, _query):
new_engine = OracleEngine(instance=self.ins)
- tables = new_engine.get_all_tables(db_name='archery')
- self.assertListEqual(tables.rows, ['test2'])
+ tables = new_engine.get_all_tables(db_name="archery")
+ self.assertListEqual(tables.rows, ["test2"])
- @patch('sql.engines.oracle.OracleEngine.query',
- return_value=ResultSet(rows=[('id',), ('name',)]))
+ @patch(
+ "sql.engines.oracle.OracleEngine.query",
+ return_value=ResultSet(rows=[("id",), ("name",)]),
+ )
def test_get_all_columns_by_tb(self, _query):
new_engine = OracleEngine(instance=self.ins)
- columns = new_engine.get_all_columns_by_tb(db_name='archery', tb_name='test2')
- self.assertListEqual(columns.rows, ['id', 'name'])
+ columns = new_engine.get_all_columns_by_tb(db_name="archery", tb_name="test2")
+ self.assertListEqual(columns.rows, ["id", "name"])
- @patch('sql.engines.oracle.OracleEngine.query',
- return_value=ResultSet(rows=[('archery',), ('template1',), ('template0',)]))
+ @patch(
+ "sql.engines.oracle.OracleEngine.query",
+ return_value=ResultSet(rows=[("archery",), ("template1",), ("template0",)]),
+ )
def test_describe_table(self, _query):
new_engine = OracleEngine(instance=self.ins)
- describe = new_engine.describe_table(db_name='archery', tb_name='text')
+ describe = new_engine.describe_table(db_name="archery", tb_name="text")
self.assertIsInstance(describe, ResultSet)
def test_query_check_disable_sql(self):
sql = "update xxx set a=1;"
new_engine = OracleEngine(instance=self.ins)
- check_result = new_engine.query_check(db_name='archery', sql=sql)
- self.assertDictEqual(check_result,
- {'msg': '不支持语法!', 'bad_query': True, 'filtered_sql': sql.strip(';'),
- 'has_star': False})
+ check_result = new_engine.query_check(db_name="archery", sql=sql)
+ self.assertDictEqual(
+ check_result,
+ {
+ "msg": "不支持语法!",
+ "bad_query": True,
+ "filtered_sql": sql.strip(";"),
+ "has_star": False,
+ },
+ )
- @patch('sql.engines.oracle.OracleEngine.explain_check', return_value={'msg': '', 'rows': 0})
+ @patch(
+ "sql.engines.oracle.OracleEngine.explain_check",
+ return_value={"msg": "", "rows": 0},
+ )
def test_query_check_star_sql(self, _explain_check):
sql = "select * from xx;"
new_engine = OracleEngine(instance=self.ins)
- check_result = new_engine.query_check(db_name='archery', sql=sql)
- self.assertDictEqual(check_result,
- {'msg': '禁止使用 * 关键词\n', 'bad_query': False, 'filtered_sql': sql.strip(';'),
- 'has_star': True})
+ check_result = new_engine.query_check(db_name="archery", sql=sql)
+ self.assertDictEqual(
+ check_result,
+ {
+ "msg": "禁止使用 * 关键词\n",
+ "bad_query": False,
+ "filtered_sql": sql.strip(";"),
+ "has_star": True,
+ },
+ )
def test_query_check_IndexError(self):
sql = ""
new_engine = OracleEngine(instance=self.ins)
- check_result = new_engine.query_check(db_name='archery', sql=sql)
- self.assertDictEqual(check_result,
- {'msg': '没有有效的SQL语句', 'bad_query': True, 'filtered_sql': sql.strip(), 'has_star': False})
+ check_result = new_engine.query_check(db_name="archery", sql=sql)
+ self.assertDictEqual(
+ check_result,
+ {
+ "msg": "没有有效的SQL语句",
+ "bad_query": True,
+ "filtered_sql": sql.strip(),
+ "has_star": False,
+ },
+ )
def test_filter_sql_with_delimiter(self):
sql = "select * from xx;"
new_engine = OracleEngine(instance=self.ins)
check_result = new_engine.filter_sql(sql=sql, limit_num=100)
- self.assertEqual(check_result, "select sql_audit.* from (select * from xx) sql_audit where rownum <= 100")
+ self.assertEqual(
+ check_result,
+ "select sql_audit.* from (select * from xx) sql_audit where rownum <= 100",
+ )
def test_filter_sql_with_delimiter_and_where(self):
sql = "select * from xx where id>1;"
new_engine = OracleEngine(instance=self.ins)
check_result = new_engine.filter_sql(sql=sql, limit_num=100)
- self.assertEqual(check_result,
- "select sql_audit.* from (select * from xx where id>1) sql_audit where rownum <= 100")
+ self.assertEqual(
+ check_result,
+ "select sql_audit.* from (select * from xx where id>1) sql_audit where rownum <= 100",
+ )
def test_filter_sql_without_delimiter(self):
sql = "select * from xx;"
new_engine = OracleEngine(instance=self.ins)
check_result = new_engine.filter_sql(sql=sql, limit_num=100)
- self.assertEqual(check_result, "select sql_audit.* from (select * from xx) sql_audit where rownum <= 100")
+ self.assertEqual(
+ check_result,
+ "select sql_audit.* from (select * from xx) sql_audit where rownum <= 100",
+ )
def test_filter_sql_with_limit(self):
sql = "select * from xx limit 10;"
new_engine = OracleEngine(instance=self.ins)
check_result = new_engine.filter_sql(sql=sql, limit_num=1)
- self.assertEqual(check_result,
- "select sql_audit.* from (select * from xx limit 10) sql_audit where rownum <= 1")
+ self.assertEqual(
+ check_result,
+ "select sql_audit.* from (select * from xx limit 10) sql_audit where rownum <= 1",
+ )
def test_query_masking(self):
query_result = ResultSet()
new_engine = OracleEngine(instance=self.ins)
- masking_result = new_engine.query_masking(sql='select 1 from dual', resultset=query_result)
+ masking_result = new_engine.query_masking(
+ sql="select 1 from dual", resultset=query_result
+ )
self.assertEqual(masking_result, query_result)
def test_execute_check_select_sql(self):
- sql = 'select * from user;'
- row = ReviewResult(id=1, errlevel=2,
- stagestatus='驳回不支持语句',
- errormessage='仅支持DML和DDL语句,查询语句请使用SQL查询功能!',
- sql=sqlparse.format(sql, strip_comments=True, reindent=True, keyword_case='lower'))
+ sql = "select * from user;"
+ row = ReviewResult(
+ id=1,
+ errlevel=2,
+ stagestatus="驳回不支持语句",
+ errormessage="仅支持DML和DDL语句,查询语句请使用SQL查询功能!",
+ sql=sqlparse.format(
+ sql, strip_comments=True, reindent=True, keyword_case="lower"
+ ),
+ )
new_engine = OracleEngine(instance=self.ins)
- check_result = new_engine.execute_check(db_name='archery', sql=sql)
+ check_result = new_engine.execute_check(db_name="archery", sql=sql)
self.assertIsInstance(check_result, ReviewSet)
self.assertEqual(check_result.rows[0].__dict__, row.__dict__)
def test_execute_check_critical_sql(self):
- self.sys_config.set('critical_ddl_regex', '^|update')
+ self.sys_config.set("critical_ddl_regex", "^|update")
self.sys_config.get_all_config()
- sql = 'update user set id=1'
- row = ReviewResult(id=1, errlevel=2,
- stagestatus='驳回高危SQL',
- errormessage='禁止提交匹配' + '^|update' + '条件的语句!',
- sql=sqlparse.format(sql, strip_comments=True, reindent=True, keyword_case='lower'))
+ sql = "update user set id=1"
+ row = ReviewResult(
+ id=1,
+ errlevel=2,
+ stagestatus="驳回高危SQL",
+ errormessage="禁止提交匹配" + "^|update" + "条件的语句!",
+ sql=sqlparse.format(
+ sql, strip_comments=True, reindent=True, keyword_case="lower"
+ ),
+ )
new_engine = OracleEngine(instance=self.ins)
- check_result = new_engine.execute_check(db_name='archery', sql=sql)
+ check_result = new_engine.execute_check(db_name="archery", sql=sql)
self.assertIsInstance(check_result, ReviewSet)
self.assertEqual(check_result.rows[0].__dict__, row.__dict__)
- @patch('sql.engines.oracle.OracleEngine.explain_check', return_value={'msg': '', 'rows': 0})
- @patch('sql.engines.oracle.OracleEngine.get_sql_first_object_name', return_value='tb')
- @patch('sql.engines.oracle.OracleEngine.object_name_check', return_value=True)
- def test_execute_check_normal_sql(self, _explain_check, _get_sql_first_object_name, _object_name_check):
+ @patch(
+ "sql.engines.oracle.OracleEngine.explain_check",
+ return_value={"msg": "", "rows": 0},
+ )
+ @patch(
+ "sql.engines.oracle.OracleEngine.get_sql_first_object_name", return_value="tb"
+ )
+ @patch("sql.engines.oracle.OracleEngine.object_name_check", return_value=True)
+ def test_execute_check_normal_sql(
+ self, _explain_check, _get_sql_first_object_name, _object_name_check
+ ):
self.sys_config.purge()
- sql = 'alter table tb set id=1'
- row = ReviewResult(id=1,
- errlevel=1,
- stagestatus='当前平台,此语法不支持审核!',
- errormessage='当前平台,此语法不支持审核!',
- sql=sqlparse.format(sql, strip_comments=True, reindent=True, keyword_case='lower'),
- affected_rows=0,
- execute_time=0,
- stmt_type='SQL',
- object_owner='',
- object_type='',
- object_name='',
- )
+ sql = "alter table tb set id=1"
+ row = ReviewResult(
+ id=1,
+ errlevel=1,
+ stagestatus="当前平台,此语法不支持审核!",
+ errormessage="当前平台,此语法不支持审核!",
+ sql=sqlparse.format(
+ sql, strip_comments=True, reindent=True, keyword_case="lower"
+ ),
+ affected_rows=0,
+ execute_time=0,
+ stmt_type="SQL",
+ object_owner="",
+ object_type="",
+ object_name="",
+ )
new_engine = OracleEngine(instance=self.ins)
- check_result = new_engine.execute_check(db_name='archery', sql=sql)
+ check_result = new_engine.execute_check(db_name="archery", sql=sql)
self.assertIsInstance(check_result, ReviewSet)
self.assertEqual(check_result.rows[0].__dict__, row.__dict__)
- @patch('cx_Oracle.connect.cursor.execute')
- @patch('cx_Oracle.connect.cursor')
- @patch('cx_Oracle.connect')
+ @patch("cx_Oracle.connect.cursor.execute")
+ @patch("cx_Oracle.connect.cursor")
+ @patch("cx_Oracle.connect")
def test_execute_workflow_success(self, _conn, _cursor, _execute):
- sql = 'update user set id=1'
- review_row = ReviewResult(id=1,
- errlevel=0,
- stagestatus='Execute Successfully',
- errormessage='None',
- sql=sql,
- affected_rows=0,
- execute_time=0,
- stmt_type='SQL',
- object_owner='',
- object_type='',
- object_name='', )
- execute_row = ReviewResult(id=1,
- errlevel=0,
- stagestatus='Execute Successfully',
- errormessage='None',
- sql=sql,
- affected_rows=0,
- execute_time=0)
+ sql = "update user set id=1"
+ review_row = ReviewResult(
+ id=1,
+ errlevel=0,
+ stagestatus="Execute Successfully",
+ errormessage="None",
+ sql=sql,
+ affected_rows=0,
+ execute_time=0,
+ stmt_type="SQL",
+ object_owner="",
+ object_type="",
+ object_name="",
+ )
+ execute_row = ReviewResult(
+ id=1,
+ errlevel=0,
+ stagestatus="Execute Successfully",
+ errormessage="None",
+ sql=sql,
+ affected_rows=0,
+ execute_time=0,
+ )
wf = SqlWorkflow.objects.create(
- workflow_name='some_name',
+ workflow_name="some_name",
group_id=1,
- group_name='g1',
- engineer_display='',
- audit_auth_groups='some_group',
+ group_name="g1",
+ engineer_display="",
+ audit_auth_groups="some_group",
create_time=datetime.now() - timedelta(days=1),
- status='workflow_finish',
+ status="workflow_finish",
is_backup=True,
instance=self.ins,
- db_name='some_db',
- syntax_type=1
+ db_name="some_db",
+ syntax_type=1,
+ )
+ SqlWorkflowContent.objects.create(
+ workflow=wf,
+ sql_content=sql,
+ review_content=ReviewSet(rows=[review_row]).json(),
)
- SqlWorkflowContent.objects.create(workflow=wf, sql_content=sql,
- review_content=ReviewSet(rows=[review_row]).json())
new_engine = OracleEngine(instance=self.ins)
execute_result = new_engine.execute_workflow(workflow=wf)
self.assertIsInstance(execute_result, ReviewSet)
- self.assertEqual(execute_result.rows[0].__dict__.keys(), execute_row.__dict__.keys())
+ self.assertEqual(
+ execute_result.rows[0].__dict__.keys(), execute_row.__dict__.keys()
+ )
- @patch('cx_Oracle.connect.cursor.execute')
- @patch('cx_Oracle.connect.cursor')
- @patch('cx_Oracle.connect', return_value=RuntimeError)
+ @patch("cx_Oracle.connect.cursor.execute")
+ @patch("cx_Oracle.connect.cursor")
+ @patch("cx_Oracle.connect", return_value=RuntimeError)
def test_execute_workflow_exception(self, _conn, _cursor, _execute):
- sql = 'update user set id=1'
- row = ReviewResult(id=1,
- errlevel=2,
- stagestatus='Execute Failed',
- errormessage=f'异常信息:{f"Oracle命令执行报错,语句:{sql}"}',
- sql=sql,
- affected_rows=0,
- execute_time=0,
- stmt_type='SQL',
- object_owner='',
- object_type='',
- object_name='',
- )
+ sql = "update user set id=1"
+ row = ReviewResult(
+ id=1,
+ errlevel=2,
+ stagestatus="Execute Failed",
+ errormessage=f'异常信息:{f"Oracle命令执行报错,语句:{sql}"}',
+ sql=sql,
+ affected_rows=0,
+ execute_time=0,
+ stmt_type="SQL",
+ object_owner="",
+ object_type="",
+ object_name="",
+ )
wf = SqlWorkflow.objects.create(
- workflow_name='some_name',
+ workflow_name="some_name",
group_id=1,
- group_name='g1',
- engineer_display='',
- audit_auth_groups='some_group',
+ group_name="g1",
+ engineer_display="",
+ audit_auth_groups="some_group",
create_time=datetime.now() - timedelta(days=1),
- status='workflow_finish',
+ status="workflow_finish",
is_backup=True,
instance=self.ins,
- db_name='some_db',
- syntax_type=1
+ db_name="some_db",
+ syntax_type=1,
+ )
+ SqlWorkflowContent.objects.create(
+ workflow=wf, sql_content=sql, review_content=ReviewSet(rows=[row]).json()
)
- SqlWorkflowContent.objects.create(workflow=wf, sql_content=sql, review_content=ReviewSet(rows=[row]).json())
with self.assertRaises(AttributeError):
new_engine = OracleEngine(instance=self.ins)
execute_result = new_engine.execute_workflow(workflow=wf)
self.assertIsInstance(execute_result, ReviewSet)
- self.assertEqual(execute_result.rows[0].__dict__.keys(), row.__dict__.keys())
+ self.assertEqual(
+ execute_result.rows[0].__dict__.keys(), row.__dict__.keys()
+ )
class MongoTest(TestCase):
def setUp(self) -> None:
- self.ins = Instance.objects.create(instance_name='some_ins', type='slave', db_type='mongo',
- host='some_host', port=3306, user='ins_user')
+ self.ins = Instance.objects.create(
+ instance_name="some_ins",
+ type="slave",
+ db_type="mongo",
+ host="some_host",
+ port=3306,
+ user="ins_user",
+ )
self.engine = MongoEngine(instance=self.ins)
def tearDown(self) -> None:
self.ins.delete()
- @patch('sql.engines.mongo.pymongo')
+ @patch("sql.engines.mongo.pymongo")
def test_get_connection(self, mock_pymongo):
_ = self.engine.get_connection()
mock_pymongo.MongoClient.assert_called_once()
- @patch('sql.engines.mongo.MongoEngine.get_connection')
+ @patch("sql.engines.mongo.MongoEngine.get_connection")
def test_query(self, mock_get_connection):
# TODO 正常查询还没做
test_sql = """db.job.find().count()"""
- self.assertIsInstance(self.engine.query('some_db', test_sql), ResultSet)
+ self.assertIsInstance(self.engine.query("some_db", test_sql), ResultSet)
- @patch('sql.engines.mongo.MongoEngine.get_all_tables')
+ @patch("sql.engines.mongo.MongoEngine.get_all_tables")
def test_query_check(self, mock_get_all_tables):
test_sql = """db.job.find().count()"""
- mock_get_all_tables.return_value.rows = ("job")
- check_result = self.engine.query_check('some_db', sql=test_sql)
+ mock_get_all_tables.return_value.rows = "job"
+ check_result = self.engine.query_check("some_db", sql=test_sql)
mock_get_all_tables.assert_called_once()
- self.assertEqual(False, check_result.get('bad_query'))
+ self.assertEqual(False, check_result.get("bad_query"))
- @patch('sql.engines.mongo.MongoEngine.get_connection')
+ @patch("sql.engines.mongo.MongoEngine.get_connection")
def test_get_all_databases(self, mock_get_connection):
db_list = self.engine.get_all_databases()
self.assertIsInstance(db_list, ResultSet)
# mock_get_connection.return_value.list_database_names.assert_called_once()
- @patch('sql.engines.mongo.MongoEngine.get_connection')
+ @patch("sql.engines.mongo.MongoEngine.get_connection")
def test_get_all_tables(self, mock_get_connection):
mock_db = Mock()
# 下面是查表示例返回结果
- mock_db.list_collection_names.return_value = ['u', 'v', 'w']
- mock_get_connection.return_value = {'some_db': mock_db}
- table_list = self.engine.get_all_tables('some_db')
+ mock_db.list_collection_names.return_value = ["u", "v", "w"]
+ mock_get_connection.return_value = {"some_db": mock_db}
+ table_list = self.engine.get_all_tables("some_db")
mock_db.list_collection_names.assert_called_once()
- self.assertEqual(table_list.rows, ['u', 'v', 'w'])
+ self.assertEqual(table_list.rows, ["u", "v", "w"])
def test_filter_sql(self):
sql = """explain db.job.find().count()"""
check_result = self.engine.filter_sql(sql, 0)
- self.assertEqual(check_result, 'db.job.find().count().explain()')
+ self.assertEqual(check_result, "db.job.find().count().explain()")
- @patch('sql.engines.mongo.MongoEngine.exec_cmd')
+ @patch("sql.engines.mongo.MongoEngine.exec_cmd")
def test_get_slave(self, mock_exec_cmd):
mock_exec_cmd.return_value = "172.30.2.123:27017"
flag = self.engine.get_slave()
self.assertEqual(True, flag)
- @patch('sql.engines.mongo.MongoEngine.get_all_columns_by_tb')
+ @patch("sql.engines.mongo.MongoEngine.get_all_columns_by_tb")
def test_parse_tuple(self, mock_get_all_columns_by_tb):
cols = ["_id", "title", "tags", "likes"]
mock_get_all_columns_by_tb.return_value.rows = cols
- cursor = [{'_id': {'$oid': '5f10162029684728e70045ab'}, 'title': 'MongoDB', 'tags': 'mongodb', 'likes': 100}]
- rows, columns = self.engine.parse_tuple(cursor, 'some_db', 'job')
- alldata = json.dumps(cursor[0], ensure_ascii=False, indent=2, separators=(",", ":"))
- rerows = (alldata, "ObjectId('5f10162029684728e70045ab')", 'MongoDB', 'mongodb', '100')
- self.assertEqual(columns, ['mongodballdata', '_id', 'title', 'tags', 'likes'])
+ cursor = [
+ {
+ "_id": {"$oid": "5f10162029684728e70045ab"},
+ "title": "MongoDB",
+ "tags": "mongodb",
+ "likes": 100,
+ }
+ ]
+ rows, columns = self.engine.parse_tuple(cursor, "some_db", "job")
+ alldata = json.dumps(
+ cursor[0], ensure_ascii=False, indent=2, separators=(",", ":")
+ )
+ rerows = (
+ alldata,
+ "ObjectId('5f10162029684728e70045ab')",
+ "MongoDB",
+ "mongodb",
+ "100",
+ )
+ self.assertEqual(columns, ["mongodballdata", "_id", "title", "tags", "likes"])
self.assertEqual(rows[0], rerows)
- @patch('sql.engines.mongo.MongoEngine.get_table_conut')
- @patch('sql.engines.mongo.MongoEngine.get_all_tables')
+ @patch("sql.engines.mongo.MongoEngine.get_table_conut")
+ @patch("sql.engines.mongo.MongoEngine.get_all_tables")
def test_execute_check(self, mock_get_all_tables, mock_get_table_conut):
- sql = '''db.job.createIndex({"skuId":1},{background:true});'''
- mock_get_all_tables.return_value.rows = ("job")
+ sql = """db.job.createIndex({"skuId":1},{background:true});"""
+ mock_get_all_tables.return_value.rows = "job"
mock_get_table_conut.return_value = 1000
- row = ReviewResult(id=1, errlevel=0,
- stagestatus='Audit completed',
- errormessage='检测通过',
- affected_rows=1000,
- sql=sql,
- execute_time=0)
- check_result = self.engine.execute_check('some_db', sql)
- self.assertEqual(check_result.rows[0].__dict__["errormessage"], row.__dict__["errormessage"])
-
- @patch('sql.engines.mongo.MongoEngine.exec_cmd')
- @patch('sql.engines.mongo.MongoEngine.get_master')
+ row = ReviewResult(
+ id=1,
+ errlevel=0,
+ stagestatus="Audit completed",
+ errormessage="检测通过",
+ affected_rows=1000,
+ sql=sql,
+ execute_time=0,
+ )
+ check_result = self.engine.execute_check("some_db", sql)
+ self.assertEqual(
+ check_result.rows[0].__dict__["errormessage"], row.__dict__["errormessage"]
+ )
+
+ @patch("sql.engines.mongo.MongoEngine.exec_cmd")
+ @patch("sql.engines.mongo.MongoEngine.get_master")
def test_execute(self, mock_get_master, mock_exec_cmd):
- sql = '''db.job.find().createIndex({"skuId":1},{background:true})'''
- mock_exec_cmd.return_value = '''{
+ sql = """db.job.find().createIndex({"skuId":1},{background:true})"""
+ mock_exec_cmd.return_value = """{
"createdCollectionAutomatically" : false,
"numIndexesBefore" : 2,
"numIndexesAfter" : 3,
"ok" : 1
- }'''
+ }"""
check_result = self.engine.execute("some_db", sql)
mock_get_master.assert_called_once()
@@ -1520,74 +1878,92 @@ def test_execute(self, mock_get_master, mock_exec_cmd):
def test_fill_query_columns(self):
columns = ["_id", "title", "tags", "likes"]
- cursor = [{"_id": {"$oid": "5f10162029684728e70045ab"}, "title": "MongoDB", "text": "archery", "likes": 100},
- {"_id": {"$oid": "7f10162029684728e70045ab"}, "author": "archery"}]
+ cursor = [
+ {
+ "_id": {"$oid": "5f10162029684728e70045ab"},
+ "title": "MongoDB",
+ "text": "archery",
+ "likes": 100,
+ },
+ {"_id": {"$oid": "7f10162029684728e70045ab"}, "author": "archery"},
+ ]
cols = self.engine.fill_query_columns(cursor, columns=columns)
self.assertEqual(cols, ["_id", "title", "tags", "likes", "text", "author"])
- @patch('sql.engines.mongo.MongoEngine.get_connection')
+ @patch("sql.engines.mongo.MongoEngine.get_connection")
def test_current_op(self, mock_get_connection):
- class Aggregate():
+ class Aggregate:
def __enter__(self):
- yield {'client':'single_client'}
- yield {'clientMetadata':{'mongos':{'client':'sharding_client'}}}
- def __exit__(self,*arg, **kwargs):
+ yield {"client": "single_client"}
+ yield {"clientMetadata": {"mongos": {"client": "sharding_client"}}}
+
+ def __exit__(self, *arg, **kwargs):
pass
+
mock_conn = Mock()
mock_conn.admin.aggregate.return_value = Aggregate()
mock_get_connection.return_value = mock_conn
- command_types = ['Full','All','Inner','Active']
+ command_types = ["Full", "All", "Inner", "Active"]
for command_type in command_types:
result_set = self.engine.current_op(command_type)
self.assertIsInstance(result_set, ResultSet)
- @patch('sql.engines.mongo.MongoEngine.get_connection')
- def test_get_kill_command(self, mock_get_connection):
- class Aggregate():
+ @patch("sql.engines.mongo.MongoEngine.get_connection")
+ def test_get_kill_command(self, mock_get_connection):
+ class Aggregate:
def __enter__(self):
- yield {'opid': 111}
- yield {'opid': 'shard1: 111'}
- def __exit__(self,*arg, **kwargs):
+ yield {"opid": 111}
+ yield {"opid": "shard1: 111"}
+
+ def __exit__(self, *arg, **kwargs):
pass
+
mock_conn = Mock()
mock_conn.admin.aggregate.return_value = Aggregate()
mock_get_connection.return_value = mock_conn
- kill_command1 = self.engine.get_kill_command([111,222])
- kill_command2 = self.engine.get_kill_command(['shard1: 111','shard2: 222'])
- self.assertEqual(kill_command1, 'db.killOp(111);')
+ kill_command1 = self.engine.get_kill_command([111, 222])
+ kill_command2 = self.engine.get_kill_command(["shard1: 111", "shard2: 222"])
+ self.assertEqual(kill_command1, "db.killOp(111);")
self.assertEqual(kill_command2, 'db.killOp("shard1: 111");')
- @patch('sql.engines.mongo.MongoEngine.get_connection')
+ @patch("sql.engines.mongo.MongoEngine.get_connection")
def test_kill_op(self, mock_get_connection):
def command(self, *arg, **kwargs):
pass
+
mock_conn = Mock()
mock_conn.admin.command.return_value = command
mock_get_connection.return_value = mock_conn
- self.engine.kill_op([111,222])
- self.engine.kill_op(['shards: 111','shards: 222'])
+ self.engine.kill_op([111, 222])
+ self.engine.kill_op(["shards: 111", "shards: 222"])
mock_conn.admin.command.assert_called()
class TestClickHouse(TestCase):
-
def setUp(self):
- self.ins1 = Instance(instance_name='some_ins', type='slave', db_type='clickhouse', host='some_host',
- port=9000, user='ins_user', password='some_str')
+ self.ins1 = Instance(
+ instance_name="some_ins",
+ type="slave",
+ db_type="clickhouse",
+ host="some_host",
+ port=9000,
+ user="ins_user",
+ password="some_str",
+ )
self.ins1.save()
self.sys_config = SysConfig()
self.wf = SqlWorkflow.objects.create(
- workflow_name='some_name',
+ workflow_name="some_name",
group_id=1,
- group_name='g1',
- engineer_display='',
- audit_auth_groups='some_group',
+ group_name="g1",
+ engineer_display="",
+ audit_auth_groups="some_group",
create_time=datetime.now() - timedelta(days=1),
- status='workflow_finish',
+ status="workflow_finish",
is_backup=False,
instance=self.ins1,
- db_name='some_db',
- syntax_type=1
+ db_name="some_db",
+ syntax_type=1,
)
SqlWorkflowContent.objects.create(workflow=self.wf)
@@ -1597,205 +1973,234 @@ def tearDown(self):
SqlWorkflow.objects.all().delete()
SqlWorkflowContent.objects.all().delete()
- @patch.object(ClickHouseEngine, 'query')
+ @patch.object(ClickHouseEngine, "query")
def test_server_version(self, mock_query):
result = ResultSet()
- result.rows = [('ClickHouse 22.1.3.7',)]
+ result.rows = [("ClickHouse 22.1.3.7",)]
mock_query.return_value = result
new_engine = ClickHouseEngine(instance=self.ins1)
server_version = new_engine.server_version
self.assertTupleEqual(server_version, (22, 1, 3))
- @patch.object(ClickHouseEngine, 'query')
+ @patch.object(ClickHouseEngine, "query")
def test_table_engine(self, mock_query):
- table_name = 'default.tb_test'
+ table_name = "default.tb_test"
result = ResultSet()
- result.rows = [('MergeTree',)]
+ result.rows = [("MergeTree",)]
mock_query.return_value = result
new_engine = ClickHouseEngine(instance=self.ins1)
table_engine = new_engine.get_table_engine(table_name)
- self.assertDictEqual(table_engine, {'status': 1, 'engine': 'MergeTree'})
+ self.assertDictEqual(table_engine, {"status": 1, "engine": "MergeTree"})
- @patch('clickhouse_driver.connect')
+ @patch("clickhouse_driver.connect")
def test_engine_base_info(self, _conn):
new_engine = ClickHouseEngine(instance=self.ins1)
- self.assertEqual(new_engine.name, 'ClickHouse')
- self.assertEqual(new_engine.info, 'ClickHouse engine')
+ self.assertEqual(new_engine.name, "ClickHouse")
+ self.assertEqual(new_engine.info, "ClickHouse engine")
- @patch.object(ClickHouseEngine, 'get_connection')
+ @patch.object(ClickHouseEngine, "get_connection")
def testGetConnection(self, connect):
new_engine = ClickHouseEngine(instance=self.ins1)
new_engine.get_connection()
connect.assert_called_once()
- @patch.object(ClickHouseEngine, 'query')
+ @patch.object(ClickHouseEngine, "query")
def testQuery(self, mock_query):
result = ResultSet()
- result.rows = [('v1', 'v2'), ]
+ result.rows = [
+ ("v1", "v2"),
+ ]
mock_query.return_value = result
new_engine = ClickHouseEngine(instance=self.ins1)
- query_result = new_engine.query(sql='some_sql', limit_num=100)
- self.assertListEqual(query_result.rows, [('v1', 'v2'), ])
+ query_result = new_engine.query(sql="some_sql", limit_num=100)
+ self.assertListEqual(
+ query_result.rows,
+ [
+ ("v1", "v2"),
+ ],
+ )
- @patch.object(ClickHouseEngine, 'query')
+ @patch.object(ClickHouseEngine, "query")
def testAllDb(self, mock_query):
db_result = ResultSet()
- db_result.rows = [('db_1',), ('db_2',)]
+ db_result.rows = [("db_1",), ("db_2",)]
mock_query.return_value = db_result
new_engine = ClickHouseEngine(instance=self.ins1)
dbs = new_engine.get_all_databases()
- self.assertEqual(dbs.rows, ['db_1', 'db_2'])
+ self.assertEqual(dbs.rows, ["db_1", "db_2"])
- @patch.object(ClickHouseEngine, 'query')
+ @patch.object(ClickHouseEngine, "query")
def testAllTables(self, mock_query):
table_result = ResultSet()
- table_result.rows = [('tb_1', 'some_des'), ('tb_2', 'some_des')]
+ table_result.rows = [("tb_1", "some_des"), ("tb_2", "some_des")]
mock_query.return_value = table_result
new_engine = ClickHouseEngine(instance=self.ins1)
- tables = new_engine.get_all_tables('some_db')
- mock_query.assert_called_once_with(db_name='some_db', sql=ANY)
- self.assertEqual(tables.rows, ['tb_1', 'tb_2'])
+ tables = new_engine.get_all_tables("some_db")
+ mock_query.assert_called_once_with(db_name="some_db", sql=ANY)
+ self.assertEqual(tables.rows, ["tb_1", "tb_2"])
- @patch.object(ClickHouseEngine, 'query')
+ @patch.object(ClickHouseEngine, "query")
def testAllColumns(self, mock_query):
db_result = ResultSet()
- db_result.rows = [('col_1', 'type'), ('col_2', 'type2')]
+ db_result.rows = [("col_1", "type"), ("col_2", "type2")]
mock_query.return_value = db_result
new_engine = ClickHouseEngine(instance=self.ins1)
- dbs = new_engine.get_all_columns_by_tb('some_db', 'some_tb')
- self.assertEqual(dbs.rows, ['col_1', 'col_2'])
+ dbs = new_engine.get_all_columns_by_tb("some_db", "some_tb")
+ self.assertEqual(dbs.rows, ["col_1", "col_2"])
- @patch.object(ClickHouseEngine, 'query')
+ @patch.object(ClickHouseEngine, "query")
def testDescribe(self, mock_query):
new_engine = ClickHouseEngine(instance=self.ins1)
- new_engine.describe_table('some_db', 'some_db')
+ new_engine.describe_table("some_db", "some_db")
mock_query.assert_called_once()
def test_query_check_wrong_sql(self):
new_engine = ClickHouseEngine(instance=self.ins1)
- wrong_sql = '-- 测试'
- check_result = new_engine.query_check(db_name='some_db', sql=wrong_sql)
- self.assertDictEqual(check_result,
- {'msg': '不支持的查询语法类型!', 'bad_query': True, 'filtered_sql': '-- 测试', 'has_star': False})
+ wrong_sql = "-- 测试"
+ check_result = new_engine.query_check(db_name="some_db", sql=wrong_sql)
+ self.assertDictEqual(
+ check_result,
+ {
+ "msg": "不支持的查询语法类型!",
+ "bad_query": True,
+ "filtered_sql": "-- 测试",
+ "has_star": False,
+ },
+ )
def test_query_check_update_sql(self):
new_engine = ClickHouseEngine(instance=self.ins1)
- update_sql = 'update user set id=0'
- check_result = new_engine.query_check(db_name='some_db', sql=update_sql)
- self.assertDictEqual(check_result,
- {'msg': '不支持的查询语法类型!', 'bad_query': True, 'filtered_sql': 'update user set id=0',
- 'has_star': False})
+ update_sql = "update user set id=0"
+ check_result = new_engine.query_check(db_name="some_db", sql=update_sql)
+ self.assertDictEqual(
+ check_result,
+ {
+ "msg": "不支持的查询语法类型!",
+ "bad_query": True,
+ "filtered_sql": "update user set id=0",
+ "has_star": False,
+ },
+ )
- @patch.object(ClickHouseEngine, 'query')
+ @patch.object(ClickHouseEngine, "query")
def test_explain_check(self, mock_query):
result = ResultSet()
- result.rows = [('ClickHouse 20.1.3.7',)]
+ result.rows = [("ClickHouse 20.1.3.7",)]
mock_query.return_value = result
new_engine = ClickHouseEngine(instance=self.ins1)
server_version = new_engine.server_version
sql = "insert into tb_test(note) values ('xbb');"
check_result = ReviewSet(full_sql=sql)
- explain_result = new_engine.explain_check(check_result, db_name='some_db', line=1, statement=sql)
+ explain_result = new_engine.explain_check(
+ check_result, db_name="some_db", line=1, statement=sql
+ )
self.assertEqual(explain_result.stagestatus, "Audit completed")
def test_execute_check_select_sql(self):
new_engine = ClickHouseEngine(instance=self.ins1)
- select_sql = 'select id,name from tb_test'
- check_result = new_engine.execute_check(db_name='some_db', sql=select_sql)
- self.assertEqual(check_result.rows[0].errormessage, "仅支持DML和DDL语句,查询语句请使用SQL查询功能!")
+ select_sql = "select id,name from tb_test"
+ check_result = new_engine.execute_check(db_name="some_db", sql=select_sql)
+ self.assertEqual(
+ check_result.rows[0].errormessage, "仅支持DML和DDL语句,查询语句请使用SQL查询功能!"
+ )
- @patch.object(ClickHouseEngine, 'query')
+ @patch.object(ClickHouseEngine, "query")
def test_execute_check_alter_sql(self, mock_query):
- table_name = 'default.tb_test'
+ table_name = "default.tb_test"
result = ResultSet()
- result.rows = [('Log',)]
+ result.rows = [("Log",)]
mock_query.return_value = result
new_engine = ClickHouseEngine(instance=self.ins1)
table_engine = new_engine.get_table_engine(table_name)
alter_sql = "alter table default.tb_test add column remark String"
- check_result = new_engine.execute_check(db_name='some_db', sql=alter_sql)
- self.assertEqual(check_result.rows[0].errormessage, "ALTER TABLE仅支持*MergeTree,Merge以及Distributed等引擎表!")
+ check_result = new_engine.execute_check(db_name="some_db", sql=alter_sql)
+ self.assertEqual(
+ check_result.rows[0].errormessage,
+ "ALTER TABLE仅支持*MergeTree,Merge以及Distributed等引擎表!",
+ )
def test_filter_sql_with_delimiter(self):
new_engine = ClickHouseEngine(instance=self.ins1)
- sql_without_limit = 'select user from usertable;'
+ sql_without_limit = "select user from usertable;"
check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=100)
- self.assertEqual(check_result, 'select user from usertable limit 100;')
+ self.assertEqual(check_result, "select user from usertable limit 100;")
def test_filter_sql_without_delimiter(self):
new_engine = ClickHouseEngine(instance=self.ins1)
- sql_without_limit = 'select user from usertable'
+ sql_without_limit = "select user from usertable"
check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=100)
- self.assertEqual(check_result, 'select user from usertable limit 100;')
+ self.assertEqual(check_result, "select user from usertable limit 100;")
def test_filter_sql_with_limit(self):
new_engine = ClickHouseEngine(instance=self.ins1)
- sql_without_limit = 'select user from usertable limit 10'
+ sql_without_limit = "select user from usertable limit 10"
check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=1)
- self.assertEqual(check_result, 'select user from usertable limit 1;')
+ self.assertEqual(check_result, "select user from usertable limit 1;")
def test_filter_sql_with_limit_min(self):
new_engine = ClickHouseEngine(instance=self.ins1)
- sql_without_limit = 'select user from usertable limit 10'
+ sql_without_limit = "select user from usertable limit 10"
check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=100)
- self.assertEqual(check_result, 'select user from usertable limit 10;')
+ self.assertEqual(check_result, "select user from usertable limit 10;")
def test_filter_sql_with_limit_offset(self):
new_engine = ClickHouseEngine(instance=self.ins1)
- sql_without_limit = 'select user from usertable limit 10 offset 100'
+ sql_without_limit = "select user from usertable limit 10 offset 100"
check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=1)
- self.assertEqual(check_result, 'select user from usertable limit 1 offset 100;')
+ self.assertEqual(check_result, "select user from usertable limit 1 offset 100;")
def test_filter_sql_with_limit_nn(self):
new_engine = ClickHouseEngine(instance=self.ins1)
- sql_without_limit = 'select user from usertable limit 10, 100'
+ sql_without_limit = "select user from usertable limit 10, 100"
check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=1)
- self.assertEqual(check_result, 'select user from usertable limit 10,1;')
+ self.assertEqual(check_result, "select user from usertable limit 10,1;")
def test_filter_sql_upper(self):
new_engine = ClickHouseEngine(instance=self.ins1)
- sql_without_limit = 'SELECT USER FROM usertable LIMIT 10, 100'
+ sql_without_limit = "SELECT USER FROM usertable LIMIT 10, 100"
check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=1)
- self.assertEqual(check_result, 'SELECT USER FROM usertable limit 10,1;')
+ self.assertEqual(check_result, "SELECT USER FROM usertable limit 10,1;")
def test_filter_sql_not_select(self):
new_engine = ClickHouseEngine(instance=self.ins1)
- sql_without_limit = 'show create table usertable;'
+ sql_without_limit = "show create table usertable;"
check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=1)
- self.assertEqual(check_result, 'show create table usertable;')
+ self.assertEqual(check_result, "show create table usertable;")
- @patch('clickhouse_driver.connect.cursor.execute')
- @patch('clickhouse_driver.connect.cursor')
- @patch('clickhouse_driver.connect')
+ @patch("clickhouse_driver.connect.cursor.execute")
+ @patch("clickhouse_driver.connect.cursor")
+ @patch("clickhouse_driver.connect")
def test_execute(self, _connect, _cursor, _execute):
new_engine = ClickHouseEngine(instance=self.ins1)
execute_result = new_engine.execute(self.wf)
self.assertIsInstance(execute_result, ResultSet)
- @patch('clickhouse_driver.connect.cursor.execute')
- @patch('clickhouse_driver.connect.cursor')
- @patch('clickhouse_driver.connect')
+ @patch("clickhouse_driver.connect.cursor.execute")
+ @patch("clickhouse_driver.connect.cursor")
+ @patch("clickhouse_driver.connect")
def test_execute_workflow_success(self, _conn, _cursor, _execute):
sql = "insert into tb_test values('test')"
- row = ReviewResult(id=1,
- errlevel=0,
- stagestatus='Execute Successfully',
- errormessage='None',
- sql=sql,
- affected_rows=0,
- execute_time=0)
+ row = ReviewResult(
+ id=1,
+ errlevel=0,
+ stagestatus="Execute Successfully",
+ errormessage="None",
+ sql=sql,
+ affected_rows=0,
+ execute_time=0,
+ )
wf = SqlWorkflow.objects.create(
- workflow_name='some_name',
+ workflow_name="some_name",
group_id=1,
- group_name='g1',
- engineer_display='',
- audit_auth_groups='some_group',
+ group_name="g1",
+ engineer_display="",
+ audit_auth_groups="some_group",
create_time=datetime.now() - timedelta(days=1),
- status='workflow_finish',
+ status="workflow_finish",
is_backup=False,
instance=self.ins1,
- db_name='some_db',
- syntax_type=1
+ db_name="some_db",
+ syntax_type=1,
)
SqlWorkflowContent.objects.create(workflow=wf, sql_content=sql)
new_engine = ClickHouseEngine(instance=self.ins1)
@@ -1806,22 +2211,29 @@ def test_execute_workflow_success(self, _conn, _cursor, _execute):
class ODPSTest(TestCase):
def setUp(self) -> None:
- self.ins = Instance.objects.create(instance_name='some_ins', type='slave', db_type='odps',
- host='some_host', port=9200, user='ins_user', db_name='some_db')
+ self.ins = Instance.objects.create(
+ instance_name="some_ins",
+ type="slave",
+ db_type="odps",
+ host="some_host",
+ port=9200,
+ user="ins_user",
+ db_name="some_db",
+ )
self.engine = ODPSEngine(instance=self.ins)
def tearDown(self) -> None:
self.ins.delete()
- @patch('sql.engines.odps.ODPSEngine.get_connection')
+ @patch("sql.engines.odps.ODPSEngine.get_connection")
def test_get_connection(self, mock_odps):
_ = self.engine.get_connection()
mock_odps.assert_called_once()
- @patch('sql.engines.odps.ODPSEngine.get_connection')
+ @patch("sql.engines.odps.ODPSEngine.get_connection")
def test_query(self, mock_get_connection):
test_sql = """select 123"""
- self.assertIsInstance(self.engine.query('some_db', test_sql), ResultSet)
+ self.assertIsInstance(self.engine.query("some_db", test_sql), ResultSet)
def test_query_check(self):
test_sql = """select 123; -- this is comment
@@ -1843,20 +2255,20 @@ def test_query_check_error(self):
self.assertIsInstance(check_result, dict)
self.assertEqual(True, check_result.get("bad_query"))
- @patch('sql.engines.odps.ODPSEngine.get_connection')
+ @patch("sql.engines.odps.ODPSEngine.get_connection")
def test_get_all_databases(self, mock_get_connection):
mock_conn = Mock()
mock_conn.exist_project.return_value = True
- mock_conn.project = 'some_db'
+ mock_conn.project = "some_db"
mock_get_connection.return_value = mock_conn
result = self.engine.get_all_databases()
self.assertIsInstance(result, ResultSet)
- self.assertEqual(result.rows, ['some_db'])
+ self.assertEqual(result.rows, ["some_db"])
- @patch('sql.engines.odps.ODPSEngine.get_connection')
+ @patch("sql.engines.odps.ODPSEngine.get_connection")
def test_get_all_tables(self, mock_get_connection):
# 下面是查表示例返回结果
class T:
@@ -1864,33 +2276,35 @@ def __init__(self, name):
self.name = name
mock_conn = Mock()
- mock_conn.list_tables.return_value = [T('u'), T('v'), T('w')]
+ mock_conn.list_tables.return_value = [T("u"), T("v"), T("w")]
mock_get_connection.return_value = mock_conn
- table_list = self.engine.get_all_tables('some_db')
+ table_list = self.engine.get_all_tables("some_db")
- self.assertEqual(table_list.rows, ['u', 'v', 'w'])
+ self.assertEqual(table_list.rows, ["u", "v", "w"])
- @patch('sql.engines.odps.ODPSEngine.get_all_columns_by_tb')
+ @patch("sql.engines.odps.ODPSEngine.get_all_columns_by_tb")
def test_describe_table(self, mock_get_all_columns_by_tb):
- self.engine.describe_table('some_db', 'some_table')
+ self.engine.describe_table("some_db", "some_table")
mock_get_all_columns_by_tb.assert_called_once()
- @patch('sql.engines.odps.ODPSEngine.get_connection')
+ @patch("sql.engines.odps.ODPSEngine.get_connection")
def test_get_all_columns_by_tb(self, mock_get_connection):
mock_conn = Mock()
mock_cols = Mock()
mock_col = Mock()
- mock_col.name, mock_col.type, mock_col.comment = 'XiaoMing', 'string', 'name'
+ mock_col.name, mock_col.type, mock_col.comment = "XiaoMing", "string", "name"
mock_cols.schema.columns = [mock_col]
mock_conn.get_table.return_value = mock_cols
mock_get_connection.return_value = mock_conn
- result = self.engine.get_all_columns_by_tb('some_db', 'some_table')
+ result = self.engine.get_all_columns_by_tb("some_db", "some_table")
mock_get_connection.assert_called_once()
mock_conn.get_table.assert_called_once()
- self.assertEqual(result.rows, [['XiaoMing', 'string', 'name']])
- self.assertEqual(result.column_list, ['COLUMN_NAME', 'COLUMN_TYPE', 'COLUMN_COMMENT'])
+ self.assertEqual(result.rows, [["XiaoMing", "string", "name"]])
+ self.assertEqual(
+ result.column_list, ["COLUMN_NAME", "COLUMN_TYPE", "COLUMN_COMMENT"]
+ )
diff --git a/sql/form.py b/sql/form.py
index 7459dd9f6f..3187b08833 100644
--- a/sql/form.py
+++ b/sql/form.py
@@ -18,16 +18,18 @@ class Meta:
model = Tunnel
fields = "__all__"
widgets = {
- 'PKey': Textarea(attrs={'cols': 40, 'rows': 8}),
+ "PKey": Textarea(attrs={"cols": 40, "rows": 8}),
}
def clean(self):
cleaned_data = super().clean()
- if cleaned_data.get('pkey_path'):
+ if cleaned_data.get("pkey_path"):
try:
- pkey_path = cleaned_data.get('pkey_path').read()
+ pkey_path = cleaned_data.get("pkey_path").read()
if pkey_path:
- cleaned_data['pkey'] = str(pkey_path, 'utf-8').replace(r'\r', '').replace(r'\n', '')
+ cleaned_data["pkey"] = (
+ str(pkey_path, "utf-8").replace(r"\r", "").replace(r"\n", "")
+ )
except IOError:
raise ValidationError("秘钥文件不存在, 请勾选秘钥路径的清除选项再进行保存")
@@ -35,4 +37,7 @@ def clean(self):
class InstanceForm(ModelForm):
class Media:
model = Instance
- js = ('jquery/jquery.min.js', 'dist/js/utils.js', )
+ js = (
+ "jquery/jquery.min.js",
+ "dist/js/utils.js",
+ )
diff --git a/sql/instance.py b/sql/instance.py
index 6b75676e0a..9e31fca3ec 100644
--- a/sql/instance.py
+++ b/sql/instance.py
@@ -19,30 +19,30 @@
from .models import Instance, ParamTemplate, ParamHistory
-@permission_required('sql.menu_instance_list', raise_exception=True)
+@permission_required("sql.menu_instance_list", raise_exception=True)
def lists(request):
"""获取实例列表"""
- limit = int(request.POST.get('limit'))
- offset = int(request.POST.get('offset'))
- type = request.POST.get('type')
- db_type = request.POST.get('db_type')
- tags = request.POST.getlist('tags[]')
+ limit = int(request.POST.get("limit"))
+ offset = int(request.POST.get("offset"))
+ type = request.POST.get("type")
+ db_type = request.POST.get("db_type")
+ tags = request.POST.getlist("tags[]")
limit = offset + limit
- search = request.POST.get('search', '')
- sortName = str(request.POST.get('sortName'))
- sortOrder = str(request.POST.get('sortOrder')).lower()
+ search = request.POST.get("search", "")
+ sortName = str(request.POST.get("sortName"))
+ sortOrder = str(request.POST.get("sortOrder")).lower()
# 组合筛选项
filter_dict = dict()
# 过滤搜索
if search:
- filter_dict['instance_name__icontains'] = search
+ filter_dict["instance_name__icontains"] = search
# 过滤实例类型
if type:
- filter_dict['type'] = type
+ filter_dict["type"] = type
# 过滤数据库类型
if db_type:
- filter_dict['db_type'] = db_type
+ filter_dict["db_type"] = db_type
instances = Instance.objects.filter(**filter_dict)
# 过滤标签,返回同时包含全部标签的实例,TODO 循环会生成多表JOIN,如果数据量大会存在效率问题
@@ -51,41 +51,57 @@ def lists(request):
instances = instances.filter(instance_tag=tag, instance_tag__active=True)
count = instances.count()
- if sortName == 'instance_name':
- instances = instances.order_by(getattr(Convert(sortName, 'gbk'), sortOrder)())[offset:limit]
+ if sortName == "instance_name":
+ instances = instances.order_by(getattr(Convert(sortName, "gbk"), sortOrder)())[
+ offset:limit
+ ]
else:
- instances = instances.order_by('-' + sortName if sortOrder == 'desc' else sortName)[offset:limit]
- instances = instances.values("id", "instance_name", "db_type", "type", "host", "port", "user")
+ instances = instances.order_by(
+ "-" + sortName if sortOrder == "desc" else sortName
+ )[offset:limit]
+ instances = instances.values(
+ "id", "instance_name", "db_type", "type", "host", "port", "user"
+ )
# QuerySet 序列化
rows = [row for row in instances]
result = {"total": count, "rows": rows}
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
-@permission_required('sql.param_view', raise_exception=True)
+@permission_required("sql.param_view", raise_exception=True)
def param_list(request):
"""
获取实例参数列表
:param request:
:return:
"""
- instance_id = request.POST.get('instance_id')
- editable = True if request.POST.get('editable') else False
- search = request.POST.get('search', '')
+ instance_id = request.POST.get("instance_id")
+ editable = True if request.POST.get("editable") else False
+ search = request.POST.get("search", "")
try:
ins = Instance.objects.get(id=instance_id)
except Instance.DoesNotExist:
- result = {'status': 1, 'msg': '实例不存在', 'data': []}
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result = {"status": 1, "msg": "实例不存在", "data": []}
+ return HttpResponse(json.dumps(result), content_type="application/json")
# 获取已配置参数列表
cnf_params = dict()
- for param in ParamTemplate.objects.filter(db_type=ins.db_type, variable_name__contains=search).values(
- 'id', 'variable_name', 'default_value', 'valid_values', 'description', 'editable'):
- param['variable_name'] = param['variable_name'].lower()
- cnf_params[param['variable_name']] = param
+ for param in ParamTemplate.objects.filter(
+ db_type=ins.db_type, variable_name__contains=search
+ ).values(
+ "id",
+ "variable_name",
+ "default_value",
+ "valid_values",
+ "description",
+ "editable",
+ ):
+ param["variable_name"] = param["variable_name"].lower()
+ cnf_params[param["variable_name"]] = param
# 获取实例参数列表
engine = get_engine(instance=ins)
ins_variables = engine.get_variables()
@@ -94,73 +110,85 @@ def param_list(request):
for variable in ins_variables.rows:
variable_name = variable[0].lower()
row = {
- 'variable_name': variable_name,
- 'runtime_value': variable[1],
- 'editable': False,
+ "variable_name": variable_name,
+ "runtime_value": variable[1],
+ "editable": False,
}
if variable_name in cnf_params.keys():
row = dict(row, **cnf_params[variable_name])
rows.append(row)
# 过滤参数
if editable:
- rows = [row for row in rows if row['editable']]
+ rows = [row for row in rows if row["editable"]]
else:
- rows = [row for row in rows if not row['editable']]
- return HttpResponse(json.dumps(rows, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ rows = [row for row in rows if not row["editable"]]
+ return HttpResponse(
+ json.dumps(rows, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
-@permission_required('sql.param_view', raise_exception=True)
+@permission_required("sql.param_view", raise_exception=True)
def param_history(request):
"""实例参数修改历史"""
- limit = int(request.POST.get('limit'))
- offset = int(request.POST.get('offset'))
+ limit = int(request.POST.get("limit"))
+ offset = int(request.POST.get("offset"))
limit = offset + limit
- instance_id = request.POST.get('instance_id')
- search = request.POST.get('search', '')
+ instance_id = request.POST.get("instance_id")
+ search = request.POST.get("search", "")
phs = ParamHistory.objects.filter(instance__id=instance_id)
# 过滤搜索条件
if search:
phs = ParamHistory.objects.filter(variable_name__contains=search)
count = phs.count()
- phs = phs[offset:limit].values("instance__instance_name", "variable_name", "old_var", "new_var",
- "user_display", "create_time")
+ phs = phs[offset:limit].values(
+ "instance__instance_name",
+ "variable_name",
+ "old_var",
+ "new_var",
+ "user_display",
+ "create_time",
+ )
# QuerySet 序列化
rows = [row for row in phs]
result = {"total": count, "rows": rows}
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
-@permission_required('sql.param_edit', raise_exception=True)
+@permission_required("sql.param_edit", raise_exception=True)
def param_edit(request):
user = request.user
- instance_id = request.POST.get('instance_id')
- variable_name = request.POST.get('variable_name')
- variable_value = request.POST.get('runtime_value')
+ instance_id = request.POST.get("instance_id")
+ variable_name = request.POST.get("variable_name")
+ variable_value = request.POST.get("runtime_value")
try:
ins = Instance.objects.get(id=instance_id)
except Instance.DoesNotExist:
- result = {'status': 1, 'msg': '实例不存在', 'data': []}
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result = {"status": 1, "msg": "实例不存在", "data": []}
+ return HttpResponse(json.dumps(result), content_type="application/json")
# 修改参数
engine = get_engine(instance=ins)
# 校验是否配置模板
if not ParamTemplate.objects.filter(variable_name=variable_name).exists():
- result = {'status': 1, 'msg': '请先在参数模板中配置该参数!', 'data': []}
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result = {"status": 1, "msg": "请先在参数模板中配置该参数!", "data": []}
+ return HttpResponse(json.dumps(result), content_type="application/json")
# 获取当前运行参数值
runtime_value = engine.get_variables(variables=[variable_name]).rows[0][1]
if variable_value == runtime_value:
- result = {'status': 1, 'msg': '参数值与实际运行值一致,未调整!', 'data': []}
- return HttpResponse(json.dumps(result), content_type='application/json')
- set_result = engine.set_variable(variable_name=variable_name, variable_value=variable_value)
+ result = {"status": 1, "msg": "参数值与实际运行值一致,未调整!", "data": []}
+ return HttpResponse(json.dumps(result), content_type="application/json")
+ set_result = engine.set_variable(
+ variable_name=variable_name, variable_value=variable_value
+ )
if set_result.error:
- result = {'status': 1, 'msg': f'设置错误,错误信息:{set_result.error}', 'data': []}
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result = {"status": 1, "msg": f"设置错误,错误信息:{set_result.error}", "data": []}
+ return HttpResponse(json.dumps(result), content_type="application/json")
# 修改成功的保存修改记录
else:
ParamHistory.objects.create(
@@ -170,27 +198,31 @@ def param_edit(request):
new_var=variable_value,
set_sql=set_result.full_sql,
user_name=user.username,
- user_display=user.display
+ user_display=user.display,
)
- result = {'status': 0, 'msg': '修改成功,请手动持久化到配置文件!', 'data': []}
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result = {"status": 0, "msg": "修改成功,请手动持久化到配置文件!", "data": []}
+ return HttpResponse(json.dumps(result), content_type="application/json")
-@permission_required('sql.menu_schemasync', raise_exception=True)
+@permission_required("sql.menu_schemasync", raise_exception=True)
def schemasync(request):
"""对比实例schema信息"""
- instance_name = request.POST.get('instance_name')
- db_name = request.POST.get('db_name')
- target_instance_name = request.POST.get('target_instance_name')
- target_db_name = request.POST.get('target_db_name')
- sync_auto_inc = True if request.POST.get('sync_auto_inc') == 'true' else False
- sync_comments = True if request.POST.get('sync_comments') == 'true' else False
- result = {'status': 0, 'msg': 'ok', 'data': {'diff_stdout': '', 'patch_stdout': '', 'revert_stdout': ''}}
+ instance_name = request.POST.get("instance_name")
+ db_name = request.POST.get("db_name")
+ target_instance_name = request.POST.get("target_instance_name")
+ target_db_name = request.POST.get("target_db_name")
+ sync_auto_inc = True if request.POST.get("sync_auto_inc") == "true" else False
+ sync_comments = True if request.POST.get("sync_comments") == "true" else False
+ result = {
+ "status": 0,
+ "msg": "ok",
+ "data": {"diff_stdout": "", "patch_stdout": "", "revert_stdout": ""},
+ }
# 循环对比全部数据库
- if db_name == 'all' or target_db_name == 'all':
- db_name = '*'
- target_db_name = '*'
+ if db_name == "all" or target_db_name == "all":
+ db_name = "*"
+ target_db_name = "*"
# 取出该实例的连接方式
instance_info = Instance.objects.get(instance_name=instance_name)
@@ -200,7 +232,7 @@ def schemasync(request):
schema_sync = SchemaSync()
# 准备参数
tag = int(time.time())
- output_directory = os.path.join(settings.BASE_DIR, 'downloads/schemasync/')
+ output_directory = os.path.join(settings.BASE_DIR, "downloads/schemasync/")
os.makedirs(output_directory, exist_ok=True)
db_name = shlex.quote(db_name)
target_db_name = shlex.quote(target_db_name)
@@ -209,50 +241,74 @@ def schemasync(request):
"sync-comments": sync_comments,
"tag": tag,
"output-directory": output_directory,
- "source": r"mysql://{user}:{pwd}@{host}:{port}/{database}".format(user=shlex.quote(str(instance_info.user)),
- pwd=shlex.quote(str(instance_info.password)),
- host=shlex.quote(str(instance_info.host)),
- port=shlex.quote(str(instance_info.port)),
- database=db_name),
- "target": r"mysql://{user}:{pwd}@{host}:{port}/{database}".format(user=shlex.quote(str(target_instance_info.user)),
- pwd=shlex.quote(str(target_instance_info.password)),
- host=shlex.quote(str(target_instance_info.host)),
- port=shlex.quote(str(target_instance_info.port)),
- database=target_db_name)
+ "source": r"mysql://{user}:{pwd}@{host}:{port}/{database}".format(
+ user=shlex.quote(str(instance_info.user)),
+ pwd=shlex.quote(str(instance_info.password)),
+ host=shlex.quote(str(instance_info.host)),
+ port=shlex.quote(str(instance_info.port)),
+ database=db_name,
+ ),
+ "target": r"mysql://{user}:{pwd}@{host}:{port}/{database}".format(
+ user=shlex.quote(str(target_instance_info.user)),
+ pwd=shlex.quote(str(target_instance_info.password)),
+ host=shlex.quote(str(target_instance_info.host)),
+ port=shlex.quote(str(target_instance_info.port)),
+ database=target_db_name,
+ ),
}
# 参数检查
args_check_result = schema_sync.check_args(args)
- if args_check_result['status'] == 1:
- return HttpResponse(json.dumps(args_check_result), content_type='application/json')
+ if args_check_result["status"] == 1:
+ return HttpResponse(
+ json.dumps(args_check_result), content_type="application/json"
+ )
# 参数转换
cmd_args = schema_sync.generate_args2cmd(args, shell=True)
# 执行命令
try:
stdout, stderr = schema_sync.execute_cmd(cmd_args, shell=True).communicate()
- diff_stdout = f'{stdout}{stderr}'
+ diff_stdout = f"{stdout}{stderr}"
except RuntimeError as e:
diff_stdout = str(e)
# 非全部数据库对比可以读取对比结果并在前端展示
- if db_name != '*':
+ if db_name != "*":
date = time.strftime("%Y%m%d", time.localtime())
- patch_sql_file = '%s%s_%s.%s.patch.sql' % (output_directory, target_db_name, tag, date)
- revert_sql_file = '%s%s_%s.%s.revert.sql' % (output_directory, target_db_name, tag, date)
+ patch_sql_file = "%s%s_%s.%s.patch.sql" % (
+ output_directory,
+ target_db_name,
+ tag,
+ date,
+ )
+ revert_sql_file = "%s%s_%s.%s.revert.sql" % (
+ output_directory,
+ target_db_name,
+ tag,
+ date,
+ )
try:
- with open(patch_sql_file, 'r') as f:
+ with open(patch_sql_file, "r") as f:
patch_sql = f.read()
except FileNotFoundError as e:
patch_sql = str(e)
try:
- with open(revert_sql_file, 'r') as f:
+ with open(revert_sql_file, "r") as f:
revert_sql = f.read()
except FileNotFoundError as e:
revert_sql = str(e)
- result['data'] = {'diff_stdout': diff_stdout, 'patch_stdout': patch_sql, 'revert_stdout': revert_sql}
+ result["data"] = {
+ "diff_stdout": diff_stdout,
+ "patch_stdout": patch_sql,
+ "revert_stdout": revert_sql,
+ }
else:
- result['data'] = {'diff_stdout': diff_stdout, 'patch_stdout': '', 'revert_stdout': ''}
+ result["data"] = {
+ "diff_stdout": diff_stdout,
+ "patch_stdout": "",
+ "revert_stdout": "",
+ }
- return HttpResponse(json.dumps(result), content_type='application/json')
+ return HttpResponse(json.dumps(result), content_type="application/json")
@cache_page(60 * 5, key_prefix="insRes")
@@ -262,73 +318,79 @@ def instance_resource(request):
:param request:
:return:
"""
- instance_id = request.GET.get('instance_id')
- instance_name = request.GET.get('instance_name')
- db_name = request.GET.get('db_name', '')
- schema_name = request.GET.get('schema_name', '')
- tb_name = request.GET.get('tb_name', '')
+ instance_id = request.GET.get("instance_id")
+ instance_name = request.GET.get("instance_name")
+ db_name = request.GET.get("db_name", "")
+ schema_name = request.GET.get("schema_name", "")
+ tb_name = request.GET.get("tb_name", "")
- resource_type = request.GET.get('resource_type')
+ resource_type = request.GET.get("resource_type")
if instance_id:
instance = Instance.objects.get(id=instance_id)
else:
try:
instance = Instance.objects.get(instance_name=instance_name)
except Instance.DoesNotExist:
- result = {'status': 1, 'msg': '实例不存在', 'data': []}
- return HttpResponse(json.dumps(result), content_type='application/json')
- result = {'status': 0, 'msg': 'ok', 'data': []}
+ result = {"status": 1, "msg": "实例不存在", "data": []}
+ return HttpResponse(json.dumps(result), content_type="application/json")
+ result = {"status": 0, "msg": "ok", "data": []}
try:
# escape
- db_name = MySQLdb.escape_string(db_name).decode('utf-8')
- schema_name = MySQLdb.escape_string(schema_name).decode('utf-8')
- tb_name = MySQLdb.escape_string(tb_name).decode('utf-8')
+ db_name = MySQLdb.escape_string(db_name).decode("utf-8")
+ schema_name = MySQLdb.escape_string(schema_name).decode("utf-8")
+ tb_name = MySQLdb.escape_string(tb_name).decode("utf-8")
query_engine = get_engine(instance=instance)
- if resource_type == 'database':
+ if resource_type == "database":
resource = query_engine.get_all_databases()
- elif resource_type == 'schema' and db_name:
+ elif resource_type == "schema" and db_name:
resource = query_engine.get_all_schemas(db_name=db_name)
- elif resource_type == 'table' and db_name:
- resource = query_engine.get_all_tables(db_name=db_name, schema_name=schema_name)
- elif resource_type == 'column' and db_name and tb_name:
- resource = query_engine.get_all_columns_by_tb(db_name=db_name, tb_name=tb_name, schema_name=schema_name)
+ elif resource_type == "table" and db_name:
+ resource = query_engine.get_all_tables(
+ db_name=db_name, schema_name=schema_name
+ )
+ elif resource_type == "column" and db_name and tb_name:
+ resource = query_engine.get_all_columns_by_tb(
+ db_name=db_name, tb_name=tb_name, schema_name=schema_name
+ )
else:
- raise TypeError('不支持的资源类型或者参数不完整!')
+ raise TypeError("不支持的资源类型或者参数不完整!")
except Exception as msg:
- result['status'] = 1
- result['msg'] = str(msg)
+ result["status"] = 1
+ result["msg"] = str(msg)
else:
if resource.error:
- result['status'] = 1
- result['msg'] = resource.error
+ result["status"] = 1
+ result["msg"] = resource.error
else:
- result['data'] = resource.rows
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["data"] = resource.rows
+ return HttpResponse(json.dumps(result), content_type="application/json")
def describe(request):
"""获取表结构"""
- instance_name = request.POST.get('instance_name')
+ instance_name = request.POST.get("instance_name")
try:
instance = Instance.objects.get(instance_name=instance_name)
except Instance.DoesNotExist:
- result = {'status': 1, 'msg': '实例不存在', 'data': []}
- return HttpResponse(json.dumps(result), content_type='application/json')
- db_name = request.POST.get('db_name')
- schema_name = request.POST.get('schema_name')
- tb_name = request.POST.get('tb_name')
- result = {'status': 0, 'msg': 'ok', 'data': []}
+ result = {"status": 1, "msg": "实例不存在", "data": []}
+ return HttpResponse(json.dumps(result), content_type="application/json")
+ db_name = request.POST.get("db_name")
+ schema_name = request.POST.get("schema_name")
+ tb_name = request.POST.get("tb_name")
+ result = {"status": 0, "msg": "ok", "data": []}
try:
query_engine = get_engine(instance=instance)
- query_result = query_engine.describe_table(db_name, tb_name, schema_name=schema_name)
- result['data'] = query_result.__dict__
+ query_result = query_engine.describe_table(
+ db_name, tb_name, schema_name=schema_name
+ )
+ result["data"] = query_result.__dict__
except Exception as msg:
- result['status'] = 1
- result['msg'] = str(msg)
- if result['data']['error']:
- result['status'] = 1
- result['msg'] = result['data']['error']
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["status"] = 1
+ result["msg"] = str(msg)
+ if result["data"]["error"]:
+ result["status"] = 1
+ result["msg"] = result["data"]["error"]
+ return HttpResponse(json.dumps(result), content_type="application/json")
diff --git a/sql/instance_account.py b/sql/instance_account.py
index 7af01a3436..06b71fe961 100644
--- a/sql/instance_account.py
+++ b/sql/instance_account.py
@@ -13,23 +13,25 @@
from .models import Instance, InstanceAccount
-@permission_required('sql.menu_instance_account', raise_exception=True)
+@permission_required("sql.menu_instance_account", raise_exception=True)
def users(request):
"""获取实例用户列表"""
- instance_id = request.POST.get('instance_id')
- saved = True if request.POST.get('saved') == 'true' else False # 平台是否保存
+ instance_id = request.POST.get("instance_id")
+ saved = True if request.POST.get("saved") == "true" else False # 平台是否保存
if not instance_id:
- return JsonResponse({'status': 0, 'msg': '', 'data': []})
+ return JsonResponse({"status": 0, "msg": "", "data": []})
try:
- instance = user_instances(request.user, db_type=['mysql']).get(id=instance_id)
+ instance = user_instances(request.user, db_type=["mysql"]).get(id=instance_id)
except Instance.DoesNotExist:
- return JsonResponse({'status': 1, 'msg': '你所在组未关联该实例', 'data': []})
+ return JsonResponse({"status": 1, "msg": "你所在组未关联该实例", "data": []})
# 获取已录入用户
cnf_users = dict()
- for user in InstanceAccount.objects.filter(instance=instance).values('id', 'user', 'host', 'remark'):
- user['saved'] = True
+ for user in InstanceAccount.objects.filter(instance=instance).values(
+ "id", "user", "host", "remark"
+ ):
+ user["saved"] = True
cnf_users[f"`{user['user']}`@`{user['host']}`"] = user
# 获取所有用户
query_engine = get_engine(instance=instance)
@@ -39,21 +41,23 @@ def users(request):
sql_get_user = "select concat('`', user, '`', '@', '`', host,'`') as query,user,host,account_locked from mysql.user;"
else:
sql_get_user = "select concat('`', user, '`', '@', '`', host,'`') as query,user,host from mysql.user;"
- query_result = query_engine.query('mysql', sql_get_user)
+ query_result = query_engine.query("mysql", sql_get_user)
if not query_result.error:
db_users = query_result.rows
# 获取用户权限信息
rows = []
for db_user in db_users:
user_host = db_user[0]
- user_priv = query_engine.query('mysql', 'show grants for {};'.format(user_host), close_conn=False).rows
+ user_priv = query_engine.query(
+ "mysql", "show grants for {};".format(user_host), close_conn=False
+ ).rows
row = {
- 'user_host': user_host,
- 'user': db_user[1],
- 'host': db_user[2],
- 'privileges': user_priv,
- 'saved': False,
- 'is_locked': db_user[3] if server_version >= (5, 7, 6) else None
+ "user_host": user_host,
+ "user": db_user[1],
+ "host": db_user[2],
+ "privileges": user_priv,
+ "saved": False,
+ "is_locked": db_user[3] if server_version >= (5, 7, 6) else None,
}
# 合并数据
if user_host in cnf_users.keys():
@@ -61,123 +65,129 @@ def users(request):
rows.append(row)
# 过滤参数
if saved:
- rows = [row for row in rows if row['saved']]
+ rows = [row for row in rows if row["saved"]]
- result = {'status': 0, 'msg': 'ok', 'rows': rows}
+ result = {"status": 0, "msg": "ok", "rows": rows}
else:
- result = {'status': 1, 'msg': query_result.error}
+ result = {"status": 1, "msg": query_result.error}
# 关闭连接
query_engine.close()
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
-@permission_required('sql.instance_account_manage', raise_exception=True)
+@permission_required("sql.instance_account_manage", raise_exception=True)
def create(request):
"""创建数据库账号"""
- instance_id = request.POST.get('instance_id', 0)
- user = request.POST.get('user')
- host = request.POST.get('host')
- password1 = request.POST.get('password1')
- password2 = request.POST.get('password2')
- remark = request.POST.get('remark', '')
+ instance_id = request.POST.get("instance_id", 0)
+ user = request.POST.get("user")
+ host = request.POST.get("host")
+ password1 = request.POST.get("password1")
+ password2 = request.POST.get("password2")
+ remark = request.POST.get("remark", "")
try:
- instance = user_instances(request.user, db_type=['mysql']).get(id=instance_id)
+ instance = user_instances(request.user, db_type=["mysql"]).get(id=instance_id)
except Instance.DoesNotExist:
- return JsonResponse({'status': 1, 'msg': '你所在组未关联该实例', 'data': []})
+ return JsonResponse({"status": 1, "msg": "你所在组未关联该实例", "data": []})
if not all([user, host, password1, password2]):
- return JsonResponse({'status': 1, 'msg': '参数不完整,请确认后提交', 'data': []})
+ return JsonResponse({"status": 1, "msg": "参数不完整,请确认后提交", "data": []})
if password1 != password2:
- return JsonResponse({'status': 1, 'msg': '两次输入密码不一致', 'data': []})
+ return JsonResponse({"status": 1, "msg": "两次输入密码不一致", "data": []})
# TODO 目前使用系统自带验证,后续实现验证器校验
try:
validate_password(password1, user=None, password_validators=None)
except ValidationError as msg:
- return JsonResponse({'status': 1, 'msg': f'{msg}', 'data': []})
+ return JsonResponse({"status": 1, "msg": f"{msg}", "data": []})
# escape
- user = MySQLdb.escape_string(user).decode('utf-8')
- host = MySQLdb.escape_string(host).decode('utf-8')
- password1 = MySQLdb.escape_string(password1).decode('utf-8')
+ user = MySQLdb.escape_string(user).decode("utf-8")
+ host = MySQLdb.escape_string(host).decode("utf-8")
+ password1 = MySQLdb.escape_string(password1).decode("utf-8")
engine = get_engine(instance=instance)
# 在一个事务内执行
hosts = host.split("|")
- create_user_cmd = ''
+ create_user_cmd = ""
accounts = []
for host in hosts:
create_user_cmd += f"create user '{user}'@'{host}' identified by '{password1}';"
- accounts.append(InstanceAccount(instance=instance, user=user, host=host, password=password1, remark=remark))
- exec_result = engine.execute(db_name='mysql', sql=create_user_cmd)
+ accounts.append(
+ InstanceAccount(
+ instance=instance,
+ user=user,
+ host=host,
+ password=password1,
+ remark=remark,
+ )
+ )
+ exec_result = engine.execute(db_name="mysql", sql=create_user_cmd)
if exec_result.error:
- return JsonResponse({'status': 1, 'msg': exec_result.error})
+ return JsonResponse({"status": 1, "msg": exec_result.error})
# 保存到数据库
else:
InstanceAccount.objects.bulk_create(accounts)
- return JsonResponse({'status': 0, 'msg': '', 'data': []})
+ return JsonResponse({"status": 0, "msg": "", "data": []})
-@permission_required('sql.instance_account_manage', raise_exception=True)
+@permission_required("sql.instance_account_manage", raise_exception=True)
def edit(request):
"""修改、录入数据库账号"""
- instance_id = request.POST.get('instance_id', 0)
- user = request.POST.get('user')
- host = request.POST.get('host')
- password = request.POST.get('password')
- remark = request.POST.get('remark', '')
+ instance_id = request.POST.get("instance_id", 0)
+ user = request.POST.get("user")
+ host = request.POST.get("host")
+ password = request.POST.get("password")
+ remark = request.POST.get("remark", "")
try:
- instance = user_instances(request.user, db_type=['mysql']).get(id=instance_id)
+ instance = user_instances(request.user, db_type=["mysql"]).get(id=instance_id)
except Instance.DoesNotExist:
- return JsonResponse({'status': 1, 'msg': '你所在组未关联该实例', 'data': []})
+ return JsonResponse({"status": 1, "msg": "你所在组未关联该实例", "data": []})
if not all([user, host]):
- return JsonResponse({'status': 1, 'msg': '参数不完整,请确认后提交', 'data': []})
+ return JsonResponse({"status": 1, "msg": "参数不完整,请确认后提交", "data": []})
# 保存到数据库
if password:
InstanceAccount.objects.update_or_create(
- instance=instance, user=user, host=host,
- defaults={
- "password": password,
- "remark": remark
- }
+ instance=instance,
+ user=user,
+ host=host,
+ defaults={"password": password, "remark": remark},
)
else:
InstanceAccount.objects.update_or_create(
- instance=instance, user=user, host=host,
- defaults={
- "remark": remark
- }
+ instance=instance, user=user, host=host, defaults={"remark": remark}
)
- return JsonResponse({'status': 0, 'msg': '', 'data': []})
+ return JsonResponse({"status": 0, "msg": "", "data": []})
-@permission_required('sql.instance_account_manage', raise_exception=True)
+@permission_required("sql.instance_account_manage", raise_exception=True)
def grant(request):
"""获取用户权限变更语句,并执行权限变更"""
- instance_id = request.POST.get('instance_id', 0)
- user_host = request.POST.get('user_host')
- op_type = int(request.POST.get('op_type'))
- priv_type = int(request.POST.get('priv_type'))
- privs = json.loads(request.POST.get('privs'))
- grant_sql = ''
+ instance_id = request.POST.get("instance_id", 0)
+ user_host = request.POST.get("user_host")
+ op_type = int(request.POST.get("op_type"))
+ priv_type = int(request.POST.get("priv_type"))
+ privs = json.loads(request.POST.get("privs"))
+ grant_sql = ""
# escape
- user_host = MySQLdb.escape_string(user_host).decode('utf-8')
+ user_host = MySQLdb.escape_string(user_host).decode("utf-8")
# 全局权限
if priv_type == 0:
- global_privs = privs['global_privs']
+ global_privs = privs["global_privs"]
if not all([global_privs]):
- return JsonResponse({'status': 1, 'msg': '信息不完整,请确认后提交', 'data': []})
- global_privs = ['GRANT OPTION' if g == 'GRANT' else g for g in global_privs]
+ return JsonResponse({"status": 1, "msg": "信息不完整,请确认后提交", "data": []})
+ global_privs = ["GRANT OPTION" if g == "GRANT" else g for g in global_privs]
if op_type == 0:
grant_sql = f"GRANT {','.join(global_privs)} ON *.* TO {user_host};"
elif op_type == 1:
@@ -185,37 +195,41 @@ def grant(request):
# 库权限
elif priv_type == 1:
- db_privs = privs['db_privs']
- db_name = request.POST.getlist('db_name[]')
+ db_privs = privs["db_privs"]
+ db_name = request.POST.getlist("db_name[]")
if not all([db_privs, db_name]):
- return JsonResponse({'status': 1, 'msg': '信息不完整,请确认后提交', 'data': []})
+ return JsonResponse({"status": 1, "msg": "信息不完整,请确认后提交", "data": []})
for db in db_name:
- db_privs = ['GRANT OPTION' if d == 'GRANT' else d for d in db_privs]
+ db_privs = ["GRANT OPTION" if d == "GRANT" else d for d in db_privs]
if op_type == 0:
grant_sql += f"GRANT {','.join(db_privs)} ON `{db}`.* TO {user_host};"
elif op_type == 1:
- grant_sql += f"REVOKE {','.join(db_privs)} ON `{db}`.* FROM {user_host};"
+ grant_sql += (
+ f"REVOKE {','.join(db_privs)} ON `{db}`.* FROM {user_host};"
+ )
# 表权限
elif priv_type == 2:
- tb_privs = privs['tb_privs']
- db_name = request.POST.get('db_name')
- tb_name = request.POST.getlist('tb_name[]')
+ tb_privs = privs["tb_privs"]
+ db_name = request.POST.get("db_name")
+ tb_name = request.POST.getlist("tb_name[]")
if not all([tb_privs, db_name, tb_name]):
- return JsonResponse({'status': 1, 'msg': '信息不完整,请确认后提交', 'data': []})
+ return JsonResponse({"status": 1, "msg": "信息不完整,请确认后提交", "data": []})
for tb in tb_name:
- tb_privs = ['GRANT OPTION' if t == 'GRANT' else t for t in tb_privs]
+ tb_privs = ["GRANT OPTION" if t == "GRANT" else t for t in tb_privs]
if op_type == 0:
- grant_sql += f"GRANT {','.join(tb_privs)} ON `{db_name}`.`{tb}` TO {user_host};"
+ grant_sql += (
+ f"GRANT {','.join(tb_privs)} ON `{db_name}`.`{tb}` TO {user_host};"
+ )
elif op_type == 1:
grant_sql += f"REVOKE {','.join(tb_privs)} ON `{db_name}`.`{tb}` FROM {user_host};"
# 列权限
elif priv_type == 3:
- col_privs = privs['col_privs']
- db_name = request.POST.get('db_name')
- tb_name = request.POST.get('tb_name')
- col_name = request.POST.getlist('col_name[]')
+ col_privs = privs["col_privs"]
+ db_name = request.POST.get("db_name")
+ tb_name = request.POST.get("tb_name")
+ col_name = request.POST.getlist("col_name[]")
if not all([col_privs, db_name, tb_name, col_name]):
- return JsonResponse({'status': 1, 'msg': '信息不完整,请确认后提交', 'data': []})
+ return JsonResponse({"status": 1, "msg": "信息不完整,请确认后提交", "data": []})
for priv in col_privs:
if op_type == 0:
grant_sql += f"GRANT {priv}(`{'`,`'.join(col_name)}`) ON `{db_name}`.`{tb_name}` TO {user_host};"
@@ -224,116 +238,118 @@ def grant(request):
# 执行变更语句
try:
- instance = user_instances(request.user, db_type=['mysql']).get(id=instance_id)
+ instance = user_instances(request.user, db_type=["mysql"]).get(id=instance_id)
except Instance.DoesNotExist:
- return JsonResponse({'status': 1, 'msg': '你所在组未关联该实例', 'data': []})
+ return JsonResponse({"status": 1, "msg": "你所在组未关联该实例", "data": []})
engine = get_engine(instance=instance)
- exec_result = engine.execute(db_name='mysql', sql=grant_sql)
+ exec_result = engine.execute(db_name="mysql", sql=grant_sql)
if exec_result.error:
- return JsonResponse({'status': 1, 'msg': exec_result.error})
- return JsonResponse({'status': 0, 'msg': '', 'data': grant_sql})
+ return JsonResponse({"status": 1, "msg": exec_result.error})
+ return JsonResponse({"status": 0, "msg": "", "data": grant_sql})
-@permission_required('sql.instance_account_manage', raise_exception=True)
+@permission_required("sql.instance_account_manage", raise_exception=True)
def reset_pwd(request):
"""创建数据库账号"""
- instance_id = request.POST.get('instance_id', 0)
- user_host = request.POST.get('user_host')
- user = request.POST.get('user')
- host = request.POST.get('host')
- reset_pwd1 = request.POST.get('reset_pwd1')
- reset_pwd2 = request.POST.get('reset_pwd2')
+ instance_id = request.POST.get("instance_id", 0)
+ user_host = request.POST.get("user_host")
+ user = request.POST.get("user")
+ host = request.POST.get("host")
+ reset_pwd1 = request.POST.get("reset_pwd1")
+ reset_pwd2 = request.POST.get("reset_pwd2")
if not all([user, host, reset_pwd1, reset_pwd2]):
- return JsonResponse({'status': 1, 'msg': '参数不完整,请确认后提交', 'data': []})
+ return JsonResponse({"status": 1, "msg": "参数不完整,请确认后提交", "data": []})
if reset_pwd1 != reset_pwd2:
- return JsonResponse({'status': 1, 'msg': '两次输入密码不一致', 'data': []})
+ return JsonResponse({"status": 1, "msg": "两次输入密码不一致", "data": []})
try:
- instance = user_instances(request.user, db_type=['mysql']).get(id=instance_id)
+ instance = user_instances(request.user, db_type=["mysql"]).get(id=instance_id)
except Instance.DoesNotExist:
- return JsonResponse({'status': 1, 'msg': '你所在组未关联该实例', 'data': []})
+ return JsonResponse({"status": 1, "msg": "你所在组未关联该实例", "data": []})
# escape
- user_host = MySQLdb.escape_string(user_host).decode('utf-8')
- reset_pwd1 = MySQLdb.escape_string(reset_pwd1).decode('utf-8')
+ user_host = MySQLdb.escape_string(user_host).decode("utf-8")
+ reset_pwd1 = MySQLdb.escape_string(reset_pwd1).decode("utf-8")
# TODO 目前使用系统自带验证,后续实现验证器校验
try:
validate_password(reset_pwd1, user=None, password_validators=None)
except ValidationError as msg:
- return JsonResponse({'status': 1, 'msg': f'{msg}', 'data': []})
+ return JsonResponse({"status": 1, "msg": f"{msg}", "data": []})
engine = get_engine(instance=instance)
- exec_result = engine.execute(db_name='mysql',
- sql=f"ALTER USER {user_host} IDENTIFIED BY '{reset_pwd1}';")
+ exec_result = engine.execute(
+ db_name="mysql", sql=f"ALTER USER {user_host} IDENTIFIED BY '{reset_pwd1}';"
+ )
if exec_result.error:
- result = {'status': 1, 'msg': exec_result.error}
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result = {"status": 1, "msg": exec_result.error}
+ return HttpResponse(json.dumps(result), content_type="application/json")
# 保存到数据库
else:
- InstanceAccount.objects.update_or_create(instance=instance, user=user, host=host,
- defaults={'password': reset_pwd1})
- return JsonResponse({'status': 0, 'msg': '', 'data': []})
+ InstanceAccount.objects.update_or_create(
+ instance=instance, user=user, host=host, defaults={"password": reset_pwd1}
+ )
+ return JsonResponse({"status": 0, "msg": "", "data": []})
-@permission_required('sql.instance_account_manage', raise_exception=True)
+@permission_required("sql.instance_account_manage", raise_exception=True)
def lock(request):
"""锁定/解锁账号"""
- instance_id = request.POST.get('instance_id', 0)
- user_host = request.POST.get('user_host')
- is_locked = request.POST.get('is_locked')
- lock_sql = ''
+ instance_id = request.POST.get("instance_id", 0)
+ user_host = request.POST.get("user_host")
+ is_locked = request.POST.get("is_locked")
+ lock_sql = ""
if not all([user_host]):
- return JsonResponse({'status': 1, 'msg': '参数不完整,请确认后提交', 'data': []})
+ return JsonResponse({"status": 1, "msg": "参数不完整,请确认后提交", "data": []})
try:
- instance = user_instances(request.user, db_type=['mysql']).get(id=instance_id)
+ instance = user_instances(request.user, db_type=["mysql"]).get(id=instance_id)
except Instance.DoesNotExist:
- return JsonResponse({'status': 1, 'msg': '你所在组未关联该实例', 'data': []})
+ return JsonResponse({"status": 1, "msg": "你所在组未关联该实例", "data": []})
# escape
- user_host = MySQLdb.escape_string(user_host).decode('utf-8')
+ user_host = MySQLdb.escape_string(user_host).decode("utf-8")
- if is_locked == 'N':
+ if is_locked == "N":
lock_sql = f"ALTER USER {user_host} ACCOUNT LOCK;"
- elif is_locked == 'Y':
+ elif is_locked == "Y":
lock_sql = f"ALTER USER {user_host} ACCOUNT UNLOCK;"
engine = get_engine(instance=instance)
- exec_result = engine.execute(db_name='mysql', sql=lock_sql)
+ exec_result = engine.execute(db_name="mysql", sql=lock_sql)
if exec_result.error:
- return JsonResponse({'status': 1, 'msg': exec_result.error})
- return JsonResponse({'status': 0, 'msg': '', 'data': []})
+ return JsonResponse({"status": 1, "msg": exec_result.error})
+ return JsonResponse({"status": 0, "msg": "", "data": []})
-@permission_required('sql.instance_account_manage', raise_exception=True)
+@permission_required("sql.instance_account_manage", raise_exception=True)
def delete(request):
"""删除账号"""
- instance_id = request.POST.get('instance_id', 0)
- user_host = request.POST.get('user_host')
- user = request.POST.get('user')
- host = request.POST.get('host')
+ instance_id = request.POST.get("instance_id", 0)
+ user_host = request.POST.get("user_host")
+ user = request.POST.get("user")
+ host = request.POST.get("host")
if not all([user_host]):
- return JsonResponse({'status': 1, 'msg': '参数不完整,请确认后提交', 'data': []})
+ return JsonResponse({"status": 1, "msg": "参数不完整,请确认后提交", "data": []})
try:
- instance = user_instances(request.user, db_type=['mysql']).get(id=instance_id)
+ instance = user_instances(request.user, db_type=["mysql"]).get(id=instance_id)
except Instance.DoesNotExist:
- return JsonResponse({'status': 1, 'msg': '你所在组未关联该实例', 'data': []})
+ return JsonResponse({"status": 1, "msg": "你所在组未关联该实例", "data": []})
# escape
- user_host = MySQLdb.escape_string(user_host).decode('utf-8')
+ user_host = MySQLdb.escape_string(user_host).decode("utf-8")
engine = get_engine(instance=instance)
- exec_result = engine.execute(db_name='mysql', sql=f"DROP USER {user_host};")
+ exec_result = engine.execute(db_name="mysql", sql=f"DROP USER {user_host};")
if exec_result.error:
- return JsonResponse({'status': 1, 'msg': exec_result.error})
+ return JsonResponse({"status": 1, "msg": exec_result.error})
# 删除数据库对应记录
else:
InstanceAccount.objects.filter(instance=instance, user=user, host=host).delete()
- return JsonResponse({'status': 0, 'msg': '', 'data': []})
+ return JsonResponse({"status": 0, "msg": "", "data": []})
diff --git a/sql/instance_database.py b/sql/instance_database.py
index 3deb528197..15d1572c35 100644
--- a/sql/instance_database.py
+++ b/sql/instance_database.py
@@ -17,28 +17,29 @@
from sql.models import Instance, InstanceDatabase, Users
from sql.utils.resource_group import user_instances
-__author__ = 'hhyo'
+__author__ = "hhyo"
-@permission_required('sql.menu_database', raise_exception=True)
+@permission_required("sql.menu_database", raise_exception=True)
def databases(request):
"""获取实例数据库列表"""
- instance_id = request.POST.get('instance_id')
- saved = True if request.POST.get('saved') == 'true' else False # 平台是否保存
+ instance_id = request.POST.get("instance_id")
+ saved = True if request.POST.get("saved") == "true" else False # 平台是否保存
if not instance_id:
- return JsonResponse({'status': 0, 'msg': '', 'data': []})
+ return JsonResponse({"status": 0, "msg": "", "data": []})
try:
- instance = user_instances(request.user, db_type=['mysql']).get(id=instance_id)
+ instance = user_instances(request.user, db_type=["mysql"]).get(id=instance_id)
except Instance.DoesNotExist:
- return JsonResponse({'status': 1, 'msg': '你所在组未关联该实例', 'data': []})
+ return JsonResponse({"status": 1, "msg": "你所在组未关联该实例", "data": []})
# 获取已录入数据库
cnf_dbs = dict()
- for db in InstanceDatabase.objects.filter(
- instance=instance).values('id', 'db_name', 'owner', 'owner_display', 'remark'):
- db['saved'] = True
+ for db in InstanceDatabase.objects.filter(instance=instance).values(
+ "id", "db_name", "owner", "owner_display", "remark"
+ ):
+ db["saved"] = True
cnf_dbs[f"{db['db_name']}"] = db
# 获取所有数据库
@@ -46,7 +47,9 @@ def databases(request):
FROM information_schema.SCHEMATA
WHERE SCHEMA_NAME NOT IN ('information_schema', 'performance_schema', 'mysql', 'test', 'sys');"""
query_engine = get_engine(instance=instance)
- query_result = query_engine.query('information_schema', sql_get_db, close_conn=False)
+ query_result = query_engine.query(
+ "information_schema", sql_get_db, close_conn=False
+ )
if not query_result.error:
dbs = query_result.rows
# 获取数据库关联用户信息
@@ -57,13 +60,15 @@ def databases(request):
from information_schema.SCHEMA_PRIVILEGES
where TABLE_SCHEMA='{db_name}'
group by TABLE_SCHEMA;"""
- bind_users = query_engine.query('information_schema', sql_get_bind_users, close_conn=False).rows
+ bind_users = query_engine.query(
+ "information_schema", sql_get_bind_users, close_conn=False
+ ).rows
row = {
- 'db_name': db_name,
- 'charset': db[1],
- 'collation': db[2],
- 'grantees': bind_users[0][0].split(',') if bind_users else [],
- 'saved': False
+ "db_name": db_name,
+ "charset": db[1],
+ "collation": db[2],
+ "grantees": bind_users[0][0].split(",") if bind_users else [],
+ "saved": False,
}
# 合并数据
if db_name in cnf_dbs.keys():
@@ -71,81 +76,91 @@ def databases(request):
rows.append(row)
# 过滤参数
if saved:
- rows = [row for row in rows if row['saved']]
+ rows = [row for row in rows if row["saved"]]
- result = {'status': 0, 'msg': 'ok', 'rows': rows}
+ result = {"status": 0, "msg": "ok", "rows": rows}
else:
- result = {'status': 1, 'msg': query_result.error}
+ result = {"status": 1, "msg": query_result.error}
# 关闭连接
query_engine.close()
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
-@permission_required('sql.menu_database', raise_exception=True)
+@permission_required("sql.menu_database", raise_exception=True)
def create(request):
"""创建数据库"""
- instance_id = request.POST.get('instance_id', 0)
- db_name = request.POST.get('db_name')
- owner = request.POST.get('owner', '')
- remark = request.POST.get('remark', '')
+ instance_id = request.POST.get("instance_id", 0)
+ db_name = request.POST.get("db_name")
+ owner = request.POST.get("owner", "")
+ remark = request.POST.get("remark", "")
if not all([db_name]):
- return JsonResponse({'status': 1, 'msg': '参数不完整,请确认后提交', 'data': []})
+ return JsonResponse({"status": 1, "msg": "参数不完整,请确认后提交", "data": []})
try:
- instance = user_instances(request.user, db_type=['mysql']).get(id=instance_id)
+ instance = user_instances(request.user, db_type=["mysql"]).get(id=instance_id)
except Instance.DoesNotExist:
- return JsonResponse({'status': 1, 'msg': '你所在组未关联该实例', 'data': []})
+ return JsonResponse({"status": 1, "msg": "你所在组未关联该实例", "data": []})
try:
owner_display = Users.objects.get(username=owner).display
except Users.DoesNotExist:
- return JsonResponse({'status': 1, 'msg': '负责人不存在', 'data': []})
+ return JsonResponse({"status": 1, "msg": "负责人不存在", "data": []})
# escape
- db_name = MySQLdb.escape_string(db_name).decode('utf-8')
+ db_name = MySQLdb.escape_string(db_name).decode("utf-8")
engine = get_engine(instance=instance)
- exec_result = engine.execute(db_name='information_schema', sql=f"create database {db_name};")
+ exec_result = engine.execute(
+ db_name="information_schema", sql=f"create database {db_name};"
+ )
if exec_result.error:
- return JsonResponse({'status': 1, 'msg': exec_result.error})
+ return JsonResponse({"status": 1, "msg": exec_result.error})
# 保存到数据库
else:
InstanceDatabase.objects.create(
- instance=instance, db_name=db_name, owner=owner, owner_display=owner_display, remark=remark)
+ instance=instance,
+ db_name=db_name,
+ owner=owner,
+ owner_display=owner_display,
+ remark=remark,
+ )
# 清空实例资源缓存
r = get_redis_connection("default")
- for key in r.scan_iter(match='*insRes*', count=2000):
+ for key in r.scan_iter(match="*insRes*", count=2000):
r.delete(key)
- return JsonResponse({'status': 0, 'msg': '', 'data': []})
+ return JsonResponse({"status": 0, "msg": "", "data": []})
-@permission_required('sql.menu_database', raise_exception=True)
+@permission_required("sql.menu_database", raise_exception=True)
def edit(request):
"""编辑/录入数据库"""
- instance_id = request.POST.get('instance_id', 0)
- db_name = request.POST.get('db_name')
- owner = request.POST.get('owner', '')
- remark = request.POST.get('remark', '')
+ instance_id = request.POST.get("instance_id", 0)
+ db_name = request.POST.get("db_name")
+ owner = request.POST.get("owner", "")
+ remark = request.POST.get("remark", "")
if not all([db_name]):
- return JsonResponse({'status': 1, 'msg': '参数不完整,请确认后提交', 'data': []})
+ return JsonResponse({"status": 1, "msg": "参数不完整,请确认后提交", "data": []})
try:
- instance = user_instances(request.user, db_type=['mysql']).get(id=instance_id)
+ instance = user_instances(request.user, db_type=["mysql"]).get(id=instance_id)
except Instance.DoesNotExist:
- return JsonResponse({'status': 1, 'msg': '你所在组未关联该实例', 'data': []})
+ return JsonResponse({"status": 1, "msg": "你所在组未关联该实例", "data": []})
try:
owner_display = Users.objects.get(username=owner).display
except Users.DoesNotExist:
- return JsonResponse({'status': 1, 'msg': '负责人不存在', 'data': []})
+ return JsonResponse({"status": 1, "msg": "负责人不存在", "data": []})
# 更新或者录入信息
InstanceDatabase.objects.update_or_create(
instance=instance,
db_name=db_name,
- defaults={"owner": owner, "owner_display": owner_display, "remark": remark})
- return JsonResponse({'status': 0, 'msg': '', 'data': []})
+ defaults={"owner": owner, "owner_display": owner_display, "remark": remark},
+ )
+ return JsonResponse({"status": 0, "msg": "", "data": []})
diff --git a/sql/models.py b/sql/models.py
index d998bacb09..84c0377369 100755
--- a/sql/models.py
+++ b/sql/models.py
@@ -10,15 +10,16 @@ class ResourceGroup(models.Model):
"""
资源组
"""
- group_id = models.AutoField('组ID', primary_key=True)
- group_name = models.CharField('组名称', max_length=100, unique=True)
- group_parent_id = models.BigIntegerField('父级id', default=0)
- group_sort = models.IntegerField('排序', default=1)
- group_level = models.IntegerField('层级', default=1)
- ding_webhook = models.CharField('钉钉webhook地址', max_length=255, blank=True)
- feishu_webhook = models.CharField('飞书webhook地址', max_length=255, blank=True)
- qywx_webhook = models.CharField('企业微信webhook地址', max_length=255, blank=True)
- is_deleted = models.IntegerField('是否删除', choices=((0, '否'), (1, '是')), default=0)
+
+ group_id = models.AutoField("组ID", primary_key=True)
+ group_name = models.CharField("组名称", max_length=100, unique=True)
+ group_parent_id = models.BigIntegerField("父级id", default=0)
+ group_sort = models.IntegerField("排序", default=1)
+ group_level = models.IntegerField("层级", default=1)
+ ding_webhook = models.CharField("钉钉webhook地址", max_length=255, blank=True)
+ feishu_webhook = models.CharField("飞书webhook地址", max_length=255, blank=True)
+ qywx_webhook = models.CharField("企业微信webhook地址", max_length=255, blank=True)
+ is_deleted = models.IntegerField("是否删除", choices=((0, "否"), (1, "是")), default=0)
create_time = models.DateTimeField(auto_now_add=True)
sys_time = models.DateTimeField(auto_now=True)
@@ -27,22 +28,25 @@ def __str__(self):
class Meta:
managed = True
- db_table = 'resource_group'
- verbose_name = u'资源组管理'
- verbose_name_plural = u'资源组管理'
+ db_table = "resource_group"
+ verbose_name = "资源组管理"
+ verbose_name_plural = "资源组管理"
class Users(AbstractUser):
"""
用户信息扩展
"""
- display = models.CharField('显示的中文名', max_length=50, default='')
- ding_user_id = models.CharField('钉钉UserID', max_length=64, blank=True)
- wx_user_id = models.CharField('企业微信UserID', max_length=64, blank=True)
- feishu_open_id = models.CharField('飞书OpenID', max_length=64, blank=True)
- failed_login_count = models.IntegerField('失败计数', default=0)
- last_login_failed_at = models.DateTimeField('上次失败登录时间', blank=True, null=True)
- resource_group = models.ManyToManyField(ResourceGroup, verbose_name='资源组', blank=True)
+
+ display = models.CharField("显示的中文名", max_length=50, default="")
+ ding_user_id = models.CharField("钉钉UserID", max_length=64, blank=True)
+ wx_user_id = models.CharField("企业微信UserID", max_length=64, blank=True)
+ feishu_open_id = models.CharField("飞书OpenID", max_length=64, blank=True)
+ failed_login_count = models.IntegerField("失败计数", default=0)
+ last_login_failed_at = models.DateTimeField("上次失败登录时间", blank=True, null=True)
+ resource_group = models.ManyToManyField(
+ ResourceGroup, verbose_name="资源组", blank=True
+ )
def save(self, *args, **kwargs):
self.failed_login_count = min(127, self.failed_login_count)
@@ -56,24 +60,31 @@ def __str__(self):
class Meta:
managed = True
- db_table = 'sql_users'
- verbose_name = u'用户管理'
- verbose_name_plural = u'用户管理'
+ db_table = "sql_users"
+ verbose_name = "用户管理"
+ verbose_name_plural = "用户管理"
class TwoFactorAuthConfig(models.Model):
"""
2fa配置信息
"""
+
auth_type_choice = (
- ('totp', 'Google身份验证器'),
- ('sms', '短信验证码'),
+ ("totp", "Google身份验证器"),
+ ("sms", "短信验证码"),
)
- username = fields.EncryptedCharField(verbose_name='用户名', max_length=200)
- auth_type = fields.EncryptedCharField(verbose_name='认证类型', max_length=128, choices=auth_type_choice)
- phone = fields.EncryptedCharField(verbose_name='手机号码', max_length=64, null=True, default='')
- secret_key = fields.EncryptedCharField(verbose_name='用户密钥', max_length=256, null=True)
+ username = fields.EncryptedCharField(verbose_name="用户名", max_length=200)
+ auth_type = fields.EncryptedCharField(
+ verbose_name="认证类型", max_length=128, choices=auth_type_choice
+ )
+ phone = fields.EncryptedCharField(
+ verbose_name="手机号码", max_length=64, null=True, default=""
+ )
+ secret_key = fields.EncryptedCharField(
+ verbose_name="用户密钥", max_length=256, null=True
+ )
user = models.ForeignKey(Users, on_delete=models.CASCADE)
def __int__(self):
@@ -81,147 +92,193 @@ def __int__(self):
class Meta:
managed = True
- db_table = '2fa_config'
- verbose_name = u'2FA配置'
- verbose_name_plural = u'2FA配置'
- unique_together = ('user', 'auth_type')
+ db_table = "2fa_config"
+ verbose_name = "2FA配置"
+ verbose_name_plural = "2FA配置"
+ unique_together = ("user", "auth_type")
class InstanceTag(models.Model):
"""实例标签配置"""
- tag_code = models.CharField('标签代码', max_length=20, unique=True)
- tag_name = models.CharField('标签名称', max_length=20, unique=True)
- active = models.BooleanField('激活状态', default=True)
- create_time = models.DateTimeField('创建时间', auto_now_add=True)
+
+ tag_code = models.CharField("标签代码", max_length=20, unique=True)
+ tag_name = models.CharField("标签名称", max_length=20, unique=True)
+ active = models.BooleanField("激活状态", default=True)
+ create_time = models.DateTimeField("创建时间", auto_now_add=True)
def __str__(self):
return self.tag_name
class Meta:
managed = True
- db_table = 'sql_instance_tag'
- verbose_name = u'实例标签'
- verbose_name_plural = u'实例标签'
+ db_table = "sql_instance_tag"
+ verbose_name = "实例标签"
+ verbose_name_plural = "实例标签"
DB_TYPE_CHOICES = (
- ('mysql', 'MySQL'),
- ('mssql', 'MsSQL'),
- ('redis', 'Redis'),
- ('pgsql', 'PgSQL'),
- ('oracle', 'Oracle'),
- ('mongo', 'Mongo'),
- ('phoenix', 'Phoenix'),
- ('odps', 'ODPS'),
- ('clickhouse', 'ClickHouse'),
- ('goinception', 'goInception'))
+ ("mysql", "MySQL"),
+ ("mssql", "MsSQL"),
+ ("redis", "Redis"),
+ ("pgsql", "PgSQL"),
+ ("oracle", "Oracle"),
+ ("mongo", "Mongo"),
+ ("phoenix", "Phoenix"),
+ ("odps", "ODPS"),
+ ("clickhouse", "ClickHouse"),
+ ("goinception", "goInception"),
+)
class Tunnel(models.Model):
"""
SSH隧道配置
"""
- tunnel_name = models.CharField('隧道名称', max_length=50, unique=True)
- host = models.CharField('隧道连接', max_length=200)
- port = models.IntegerField('端口', default=0)
- user = fields.EncryptedCharField(verbose_name='用户名', max_length=200, default='', blank=True, null=True)
- password = fields.EncryptedCharField(verbose_name='密码', max_length=300, default='', blank=True, null=True)
+
+ tunnel_name = models.CharField("隧道名称", max_length=50, unique=True)
+ host = models.CharField("隧道连接", max_length=200)
+ port = models.IntegerField("端口", default=0)
+ user = fields.EncryptedCharField(
+ verbose_name="用户名", max_length=200, default="", blank=True, null=True
+ )
+ password = fields.EncryptedCharField(
+ verbose_name="密码", max_length=300, default="", blank=True, null=True
+ )
pkey = fields.EncryptedTextField(verbose_name="密钥", blank=True, null=True)
- pkey_path = models.FileField(verbose_name="密钥地址", blank=True, null=True, upload_to='keys/')
- pkey_password = fields.EncryptedCharField(verbose_name='密钥密码', max_length=300, default='', blank=True, null=True)
- create_time = models.DateTimeField('创建时间', auto_now_add=True)
- update_time = models.DateTimeField('更新时间', auto_now=True)
+ pkey_path = models.FileField(
+ verbose_name="密钥地址", blank=True, null=True, upload_to="keys/"
+ )
+ pkey_password = fields.EncryptedCharField(
+ verbose_name="密钥密码", max_length=300, default="", blank=True, null=True
+ )
+ create_time = models.DateTimeField("创建时间", auto_now_add=True)
+ update_time = models.DateTimeField("更新时间", auto_now=True)
def __str__(self):
return self.tunnel_name
def short_pkey(self):
if len(str(self.pkey)) > 20:
- return '{}...'.format(str(self.pkey)[0:19])
+ return "{}...".format(str(self.pkey)[0:19])
else:
return str(self.pkey)
class Meta:
managed = True
- db_table = 'ssh_tunnel'
- verbose_name = u'隧道配置'
- verbose_name_plural = u'隧道配置'
+ db_table = "ssh_tunnel"
+ verbose_name = "隧道配置"
+ verbose_name_plural = "隧道配置"
class Instance(models.Model):
"""
各个线上实例配置
"""
- instance_name = models.CharField('实例名称', max_length=50, unique=True)
- type = models.CharField('实例类型', max_length=6, choices=(('master', '主库'), ('slave', '从库')))
- db_type = models.CharField('数据库类型', max_length=20, choices=DB_TYPE_CHOICES)
- mode = models.CharField('运行模式', max_length=10, default='', blank=True, choices=(('standalone', '单机'), ('cluster', '集群')))
- host = models.CharField('实例连接', max_length=200)
- port = models.IntegerField('端口', default=0)
- user = fields.EncryptedCharField(verbose_name='用户名', max_length=200, default='', blank=True)
- password = fields.EncryptedCharField(verbose_name='密码', max_length=300, default='', blank=True)
- db_name = models.CharField('数据库', max_length=64, default='', blank=True)
- charset = models.CharField('字符集', max_length=20, default='', blank=True)
- service_name = models.CharField('Oracle service name', max_length=50, null=True, blank=True)
- sid = models.CharField('Oracle sid', max_length=50, null=True, blank=True)
- resource_group = models.ManyToManyField(ResourceGroup, verbose_name='资源组', blank=True)
- instance_tag = models.ManyToManyField(InstanceTag, verbose_name='实例标签', blank=True)
- tunnel = models.ForeignKey(Tunnel, verbose_name='连接隧道', blank=True, null=True, on_delete=models.CASCADE, default=None)
- create_time = models.DateTimeField('创建时间', auto_now_add=True)
- update_time = models.DateTimeField('更新时间', auto_now=True)
+
+ instance_name = models.CharField("实例名称", max_length=50, unique=True)
+ type = models.CharField(
+ "实例类型", max_length=6, choices=(("master", "主库"), ("slave", "从库"))
+ )
+ db_type = models.CharField("数据库类型", max_length=20, choices=DB_TYPE_CHOICES)
+ mode = models.CharField(
+ "运行模式",
+ max_length=10,
+ default="",
+ blank=True,
+ choices=(("standalone", "单机"), ("cluster", "集群")),
+ )
+ host = models.CharField("实例连接", max_length=200)
+ port = models.IntegerField("端口", default=0)
+ user = fields.EncryptedCharField(
+ verbose_name="用户名", max_length=200, default="", blank=True
+ )
+ password = fields.EncryptedCharField(
+ verbose_name="密码", max_length=300, default="", blank=True
+ )
+ db_name = models.CharField("数据库", max_length=64, default="", blank=True)
+ charset = models.CharField("字符集", max_length=20, default="", blank=True)
+ service_name = models.CharField(
+ "Oracle service name", max_length=50, null=True, blank=True
+ )
+ sid = models.CharField("Oracle sid", max_length=50, null=True, blank=True)
+ resource_group = models.ManyToManyField(
+ ResourceGroup, verbose_name="资源组", blank=True
+ )
+ instance_tag = models.ManyToManyField(InstanceTag, verbose_name="实例标签", blank=True)
+ tunnel = models.ForeignKey(
+ Tunnel,
+ verbose_name="连接隧道",
+ blank=True,
+ null=True,
+ on_delete=models.CASCADE,
+ default=None,
+ )
+ create_time = models.DateTimeField("创建时间", auto_now_add=True)
+ update_time = models.DateTimeField("更新时间", auto_now=True)
def __str__(self):
return self.instance_name
class Meta:
managed = True
- db_table = 'sql_instance'
- verbose_name = u'实例配置'
- verbose_name_plural = u'实例配置'
+ db_table = "sql_instance"
+ verbose_name = "实例配置"
+ verbose_name_plural = "实例配置"
SQL_WORKFLOW_CHOICES = (
- ('workflow_finish', _('workflow_finish')),
- ('workflow_abort', _('workflow_abort')),
- ('workflow_manreviewing', _('workflow_manreviewing')),
- ('workflow_review_pass', _('workflow_review_pass')),
- ('workflow_timingtask', _('workflow_timingtask')),
- ('workflow_queuing', _('workflow_queuing')),
- ('workflow_executing', _('workflow_executing')),
- ('workflow_autoreviewwrong', _('workflow_autoreviewwrong')),
- ('workflow_exception', _('workflow_exception')))
+ ("workflow_finish", _("workflow_finish")),
+ ("workflow_abort", _("workflow_abort")),
+ ("workflow_manreviewing", _("workflow_manreviewing")),
+ ("workflow_review_pass", _("workflow_review_pass")),
+ ("workflow_timingtask", _("workflow_timingtask")),
+ ("workflow_queuing", _("workflow_queuing")),
+ ("workflow_executing", _("workflow_executing")),
+ ("workflow_autoreviewwrong", _("workflow_autoreviewwrong")),
+ ("workflow_exception", _("workflow_exception")),
+)
class SqlWorkflow(models.Model):
"""
存放各个SQL上线工单的基础内容
"""
- workflow_name = models.CharField('工单内容', max_length=50)
- demand_url = models.CharField('需求链接', max_length=500)
- group_id = models.IntegerField('组ID')
- group_name = models.CharField('组名称', max_length=100)
+
+ workflow_name = models.CharField("工单内容", max_length=50)
+ demand_url = models.CharField("需求链接", max_length=500)
+ group_id = models.IntegerField("组ID")
+ group_name = models.CharField("组名称", max_length=100)
instance = models.ForeignKey(Instance, on_delete=models.CASCADE)
- db_name = models.CharField('数据库', max_length=64)
- syntax_type = models.IntegerField('工单类型 0、未知,1、DDL,2、DML', choices=((0, '其他'), (1, 'DDL'), (2, 'DML')), default=0)
- is_backup = models.BooleanField('是否备份', choices=((False, '否'), (True, '是'),), default=True)
- engineer = models.CharField('发起人', max_length=30)
- engineer_display = models.CharField('发起人中文名', max_length=50, default='')
+ db_name = models.CharField("数据库", max_length=64)
+ syntax_type = models.IntegerField(
+ "工单类型 0、未知,1、DDL,2、DML", choices=((0, "其他"), (1, "DDL"), (2, "DML")), default=0
+ )
+ is_backup = models.BooleanField(
+ "是否备份",
+ choices=(
+ (False, "否"),
+ (True, "是"),
+ ),
+ default=True,
+ )
+ engineer = models.CharField("发起人", max_length=30)
+ engineer_display = models.CharField("发起人中文名", max_length=50, default="")
status = models.CharField(max_length=50, choices=SQL_WORKFLOW_CHOICES)
- audit_auth_groups = models.CharField('审批权限组列表', max_length=255)
- run_date_start = models.DateTimeField('可执行起始时间', null=True, blank=True)
- run_date_end = models.DateTimeField('可执行结束时间', null=True, blank=True)
- create_time = models.DateTimeField('创建时间', auto_now_add=True)
- finish_time = models.DateTimeField('结束时间', null=True, blank=True)
- is_manual = models.IntegerField('是否原生执行', choices=((0, '否'), (1, '是')), default=0)
+ audit_auth_groups = models.CharField("审批权限组列表", max_length=255)
+ run_date_start = models.DateTimeField("可执行起始时间", null=True, blank=True)
+ run_date_end = models.DateTimeField("可执行结束时间", null=True, blank=True)
+ create_time = models.DateTimeField("创建时间", auto_now_add=True)
+ finish_time = models.DateTimeField("结束时间", null=True, blank=True)
+ is_manual = models.IntegerField("是否原生执行", choices=((0, "否"), (1, "是")), default=0)
def __str__(self):
return self.workflow_name
class Meta:
managed = True
- db_table = 'sql_workflow'
- verbose_name = u'SQL工单'
- verbose_name_plural = u'SQL工单'
+ db_table = "sql_workflow"
+ verbose_name = "SQL工单"
+ verbose_name_plural = "SQL工单"
class SqlWorkflowContent(models.Model):
@@ -229,87 +286,91 @@ class SqlWorkflowContent(models.Model):
存放各个SQL上线工单的SQL|审核|执行内容
可定期归档或清理历史数据,也可通过``alter table sql_workflow_content row_format=compressed; ``来进行压缩
"""
+
workflow = models.OneToOneField(SqlWorkflow, on_delete=models.CASCADE)
- sql_content = models.TextField('具体sql内容')
- review_content = models.TextField('自动审核内容的JSON格式')
- execute_result = models.TextField('执行结果的JSON格式', blank=True)
+ sql_content = models.TextField("具体sql内容")
+ review_content = models.TextField("自动审核内容的JSON格式")
+ execute_result = models.TextField("执行结果的JSON格式", blank=True)
def __str__(self):
return self.workflow.workflow_name
class Meta:
managed = True
- db_table = 'sql_workflow_content'
- verbose_name = u'SQL工单内容'
- verbose_name_plural = u'SQL工单内容'
+ db_table = "sql_workflow_content"
+ verbose_name = "SQL工单内容"
+ verbose_name_plural = "SQL工单内容"
-workflow_type_choices = ((1, _('sql_query')), (2, _('sql_review')))
-workflow_status_choices = ((0, '待审核'), (1, '审核通过'), (2, '审核不通过'), (3, '审核取消'))
+workflow_type_choices = ((1, _("sql_query")), (2, _("sql_review")))
+workflow_status_choices = ((0, "待审核"), (1, "审核通过"), (2, "审核不通过"), (3, "审核取消"))
class WorkflowAudit(models.Model):
"""
工作流审核状态表
"""
+
audit_id = models.AutoField(primary_key=True)
- group_id = models.IntegerField('组ID')
- group_name = models.CharField('组名称', max_length=100)
- workflow_id = models.BigIntegerField('关联业务id')
- workflow_type = models.IntegerField('申请类型', choices=workflow_type_choices)
- workflow_title = models.CharField('申请标题', max_length=50)
- workflow_remark = models.CharField('申请备注', default='', max_length=140, blank=True)
- audit_auth_groups = models.CharField('审批权限组列表', max_length=255)
- current_audit = models.CharField('当前审批权限组', max_length=20)
- next_audit = models.CharField('下级审批权限组', max_length=20)
- current_status = models.IntegerField('审核状态', choices=workflow_status_choices)
- create_user = models.CharField('申请人', max_length=30)
- create_user_display = models.CharField('申请人中文名', max_length=50, default='')
- create_time = models.DateTimeField('申请时间', auto_now_add=True)
- sys_time = models.DateTimeField('系统时间', auto_now=True)
+ group_id = models.IntegerField("组ID")
+ group_name = models.CharField("组名称", max_length=100)
+ workflow_id = models.BigIntegerField("关联业务id")
+ workflow_type = models.IntegerField("申请类型", choices=workflow_type_choices)
+ workflow_title = models.CharField("申请标题", max_length=50)
+ workflow_remark = models.CharField("申请备注", default="", max_length=140, blank=True)
+ audit_auth_groups = models.CharField("审批权限组列表", max_length=255)
+ current_audit = models.CharField("当前审批权限组", max_length=20)
+ next_audit = models.CharField("下级审批权限组", max_length=20)
+ current_status = models.IntegerField("审核状态", choices=workflow_status_choices)
+ create_user = models.CharField("申请人", max_length=30)
+ create_user_display = models.CharField("申请人中文名", max_length=50, default="")
+ create_time = models.DateTimeField("申请时间", auto_now_add=True)
+ sys_time = models.DateTimeField("系统时间", auto_now=True)
def __int__(self):
return self.audit_id
class Meta:
managed = True
- db_table = 'workflow_audit'
- unique_together = ('workflow_id', 'workflow_type')
- verbose_name = u'工作流审批列表'
- verbose_name_plural = u'工作流审批列表'
+ db_table = "workflow_audit"
+ unique_together = ("workflow_id", "workflow_type")
+ verbose_name = "工作流审批列表"
+ verbose_name_plural = "工作流审批列表"
class WorkflowAuditDetail(models.Model):
"""
审批明细表
"""
+
audit_detail_id = models.AutoField(primary_key=True)
- audit_id = models.IntegerField('审核主表id')
- audit_user = models.CharField('审核人', max_length=30)
- audit_time = models.DateTimeField('审核时间')
- audit_status = models.IntegerField('审核状态', choices=workflow_status_choices)
- remark = models.CharField('审核备注', default='', max_length=1000)
- sys_time = models.DateTimeField('系统时间', auto_now=True)
+ audit_id = models.IntegerField("审核主表id")
+ audit_user = models.CharField("审核人", max_length=30)
+ audit_time = models.DateTimeField("审核时间")
+ audit_status = models.IntegerField("审核状态", choices=workflow_status_choices)
+ remark = models.CharField("审核备注", default="", max_length=1000)
+ sys_time = models.DateTimeField("系统时间", auto_now=True)
def __int__(self):
return self.audit_detail_id
class Meta:
managed = True
- db_table = 'workflow_audit_detail'
- verbose_name = u'工作流审批明细'
- verbose_name_plural = u'工作流审批明细'
+ db_table = "workflow_audit_detail"
+ verbose_name = "工作流审批明细"
+ verbose_name_plural = "工作流审批明细"
class WorkflowAuditSetting(models.Model):
"""
审批配置表
"""
+
audit_setting_id = models.AutoField(primary_key=True)
- group_id = models.IntegerField('组ID')
- group_name = models.CharField('组名称', max_length=100)
- workflow_type = models.IntegerField('审批类型', choices=workflow_type_choices)
- audit_auth_groups = models.CharField('审批权限组列表', max_length=255)
+ group_id = models.IntegerField("组ID")
+ group_name = models.CharField("组名称", max_length=100)
+ workflow_type = models.IntegerField("审批类型", choices=workflow_type_choices)
+ audit_auth_groups = models.CharField("审批权限组列表", max_length=255)
create_time = models.DateTimeField(auto_now_add=True)
sys_time = models.DateTimeField(auto_now=True)
@@ -318,33 +379,34 @@ def __int__(self):
class Meta:
managed = True
- db_table = 'workflow_audit_setting'
- unique_together = ('group_id', 'workflow_type')
- verbose_name = u'审批流程配置'
- verbose_name_plural = u'审批流程配置'
+ db_table = "workflow_audit_setting"
+ unique_together = ("group_id", "workflow_type")
+ verbose_name = "审批流程配置"
+ verbose_name_plural = "审批流程配置"
class WorkflowLog(models.Model):
"""
工作流日志表
"""
+
operation_type_choices = (
- (0, '提交/待审核'),
- (1, '审核通过'),
- (2, '审核不通过'),
- (3, '审核取消'),
- (4, '定时执行'),
- (5, '执行工单'),
- (6, '执行结束'),
+ (0, "提交/待审核"),
+ (1, "审核通过"),
+ (2, "审核不通过"),
+ (3, "审核取消"),
+ (4, "定时执行"),
+ (5, "执行工单"),
+ (6, "执行结束"),
)
id = models.AutoField(primary_key=True)
- audit_id = models.IntegerField('工单审批id', db_index=True)
- operation_type = models.SmallIntegerField('操作类型', choices=operation_type_choices)
- operation_type_desc = models.CharField('操作类型描述', max_length=10)
- operation_info = models.CharField('操作信息', max_length=1000)
- operator = models.CharField('操作人', max_length=30)
- operator_display = models.CharField('操作人中文名', max_length=50, default='')
+ audit_id = models.IntegerField("工单审批id", db_index=True)
+ operation_type = models.SmallIntegerField("操作类型", choices=operation_type_choices)
+ operation_type_desc = models.CharField("操作类型描述", max_length=10)
+ operation_info = models.CharField("操作信息", max_length=1000)
+ operator = models.CharField("操作人", max_length=30)
+ operator_display = models.CharField("操作人中文名", max_length=50, default="")
operation_time = models.DateTimeField(auto_now_add=True)
def __int__(self):
@@ -352,30 +414,38 @@ def __int__(self):
class Meta:
managed = True
- db_table = 'workflow_log'
- verbose_name = u'工作流日志'
- verbose_name_plural = u'工作流日志'
+ db_table = "workflow_log"
+ verbose_name = "工作流日志"
+ verbose_name_plural = "工作流日志"
class QueryPrivilegesApply(models.Model):
"""
查询权限申请记录表
"""
+
apply_id = models.AutoField(primary_key=True)
- group_id = models.IntegerField('组ID')
- group_name = models.CharField('组名称', max_length=100)
- title = models.CharField('申请标题', max_length=50)
+ group_id = models.IntegerField("组ID")
+ group_name = models.CharField("组名称", max_length=100)
+ title = models.CharField("申请标题", max_length=50)
# TODO user_name display 改为外键
- user_name = models.CharField('申请人', max_length=30)
- user_display = models.CharField('申请人中文名', max_length=50, default='')
+ user_name = models.CharField("申请人", max_length=30)
+ user_display = models.CharField("申请人中文名", max_length=50, default="")
instance = models.ForeignKey(Instance, on_delete=models.CASCADE)
- db_list = models.TextField('数据库', default='') # 逗号分隔的数据库列表
- table_list = models.TextField('表', default='') # 逗号分隔的表列表
- valid_date = models.DateField('有效时间')
- limit_num = models.IntegerField('行数限制', default=100)
- priv_type = models.IntegerField('权限类型', choices=((1, 'DATABASE'), (2, 'TABLE'),), default=0)
- status = models.IntegerField('审核状态', choices=workflow_status_choices)
- audit_auth_groups = models.CharField('审批权限组列表', max_length=255)
+ db_list = models.TextField("数据库", default="") # 逗号分隔的数据库列表
+ table_list = models.TextField("表", default="") # 逗号分隔的表列表
+ valid_date = models.DateField("有效时间")
+ limit_num = models.IntegerField("行数限制", default=100)
+ priv_type = models.IntegerField(
+ "权限类型",
+ choices=(
+ (1, "DATABASE"),
+ (2, "TABLE"),
+ ),
+ default=0,
+ )
+ status = models.IntegerField("审核状态", choices=workflow_status_choices)
+ audit_auth_groups = models.CharField("审批权限组列表", max_length=255)
create_time = models.DateTimeField(auto_now_add=True)
sys_time = models.DateTimeField(auto_now=True)
@@ -384,25 +454,33 @@ def __int__(self):
class Meta:
managed = True
- db_table = 'query_privileges_apply'
- verbose_name = u'查询权限申请记录表'
- verbose_name_plural = u'查询权限申请记录表'
+ db_table = "query_privileges_apply"
+ verbose_name = "查询权限申请记录表"
+ verbose_name_plural = "查询权限申请记录表"
class QueryPrivileges(models.Model):
"""
用户权限关系表
"""
+
privilege_id = models.AutoField(primary_key=True)
- user_name = models.CharField('用户名', max_length=30)
- user_display = models.CharField('申请人中文名', max_length=50, default='')
+ user_name = models.CharField("用户名", max_length=30)
+ user_display = models.CharField("申请人中文名", max_length=50, default="")
instance = models.ForeignKey(Instance, on_delete=models.CASCADE)
- db_name = models.CharField('数据库', max_length=64, default='')
- table_name = models.CharField('表', max_length=64, default='')
- valid_date = models.DateField('有效时间')
- limit_num = models.IntegerField('行数限制', default=100)
- priv_type = models.IntegerField('权限类型', choices=((1, 'DATABASE'), (2, 'TABLE'),), default=0)
- is_deleted = models.IntegerField('是否删除', default=0)
+ db_name = models.CharField("数据库", max_length=64, default="")
+ table_name = models.CharField("表", max_length=64, default="")
+ valid_date = models.DateField("有效时间")
+ limit_num = models.IntegerField("行数限制", default=100)
+ priv_type = models.IntegerField(
+ "权限类型",
+ choices=(
+ (1, "DATABASE"),
+ (2, "TABLE"),
+ ),
+ default=0,
+ )
+ is_deleted = models.IntegerField("是否删除", default=0)
create_time = models.DateTimeField(auto_now_add=True)
sys_time = models.DateTimeField(auto_now=True)
@@ -411,245 +489,306 @@ def __int__(self):
class Meta:
managed = True
- db_table = 'query_privileges'
+ db_table = "query_privileges"
index_together = ["user_name", "instance", "db_name", "valid_date"]
- verbose_name = u'查询权限记录'
- verbose_name_plural = u'查询权限记录'
+ verbose_name = "查询权限记录"
+ verbose_name_plural = "查询权限记录"
class QueryLog(models.Model):
"""
记录在线查询sql的日志
"""
+
# TODO 改为实例外键
- instance_name = models.CharField('实例名称', max_length=50)
- db_name = models.CharField('数据库名称', max_length=64)
- sqllog = models.TextField('执行的查询语句')
- effect_row = models.BigIntegerField('返回行数')
- cost_time = models.CharField('执行耗时', max_length=10, default='')
+ instance_name = models.CharField("实例名称", max_length=50)
+ db_name = models.CharField("数据库名称", max_length=64)
+ sqllog = models.TextField("执行的查询语句")
+ effect_row = models.BigIntegerField("返回行数")
+ cost_time = models.CharField("执行耗时", max_length=10, default="")
# TODO 改为user 外键
- username = models.CharField('操作人', max_length=30)
- user_display = models.CharField('操作人中文名', max_length=50, default='')
- priv_check = models.BooleanField('查询权限是否正常校验', choices=((False, '跳过'), (True, '正常'),), default=False)
- hit_rule = models.BooleanField('查询是否命中脱敏规则', choices=((False, '未命中/未知'), (True, '命中')), default=False)
- masking = models.BooleanField('查询结果是否正常脱敏', choices=((False, '否'), (True, '是'),), default=False)
- favorite = models.BooleanField('是否收藏', choices=((False, '否'), (True, '是'),), default=False)
- alias = models.CharField('语句标识', max_length=64, default='', blank=True)
- create_time = models.DateTimeField('操作时间', auto_now_add=True)
+ username = models.CharField("操作人", max_length=30)
+ user_display = models.CharField("操作人中文名", max_length=50, default="")
+ priv_check = models.BooleanField(
+ "查询权限是否正常校验",
+ choices=(
+ (False, "跳过"),
+ (True, "正常"),
+ ),
+ default=False,
+ )
+ hit_rule = models.BooleanField(
+ "查询是否命中脱敏规则", choices=((False, "未命中/未知"), (True, "命中")), default=False
+ )
+ masking = models.BooleanField(
+ "查询结果是否正常脱敏",
+ choices=(
+ (False, "否"),
+ (True, "是"),
+ ),
+ default=False,
+ )
+ favorite = models.BooleanField(
+ "是否收藏",
+ choices=(
+ (False, "否"),
+ (True, "是"),
+ ),
+ default=False,
+ )
+ alias = models.CharField("语句标识", max_length=64, default="", blank=True)
+ create_time = models.DateTimeField("操作时间", auto_now_add=True)
sys_time = models.DateTimeField(auto_now=True)
class Meta:
managed = True
- db_table = 'query_log'
- verbose_name = u'查询日志'
- verbose_name_plural = u'查询日志'
+ db_table = "query_log"
+ verbose_name = "查询日志"
+ verbose_name_plural = "查询日志"
-rule_type_choices = ((1, '手机号'), (2, '证件号码'), (3, '银行卡'), (4, '邮箱'), (5, '金额'), (6, '其他'))
+rule_type_choices = (
+ (1, "手机号"),
+ (2, "证件号码"),
+ (3, "银行卡"),
+ (4, "邮箱"),
+ (5, "金额"),
+ (6, "其他"),
+)
class DataMaskingColumns(models.Model):
"""
脱敏字段配置
"""
- column_id = models.AutoField('字段id', primary_key=True)
- rule_type = models.IntegerField('规则类型', choices=rule_type_choices)
- active = models.BooleanField('激活状态', choices=((False, '未激活'), (True, '激活')))
+
+ column_id = models.AutoField("字段id", primary_key=True)
+ rule_type = models.IntegerField("规则类型", choices=rule_type_choices)
+ active = models.BooleanField("激活状态", choices=((False, "未激活"), (True, "激活")))
instance = models.ForeignKey(Instance, on_delete=models.CASCADE)
- table_schema = models.CharField('字段所在库名', max_length=64)
- table_name = models.CharField('字段所在表名', max_length=64)
- column_name = models.CharField('字段名', max_length=64)
- column_comment = models.CharField('字段描述', max_length=1024, default='', blank=True)
+ table_schema = models.CharField("字段所在库名", max_length=64)
+ table_name = models.CharField("字段所在表名", max_length=64)
+ column_name = models.CharField("字段名", max_length=64)
+ column_comment = models.CharField("字段描述", max_length=1024, default="", blank=True)
create_time = models.DateTimeField(auto_now_add=True)
sys_time = models.DateTimeField(auto_now=True)
class Meta:
managed = True
- db_table = 'data_masking_columns'
- verbose_name = u'脱敏字段配置'
- verbose_name_plural = u'脱敏字段配置'
+ db_table = "data_masking_columns"
+ verbose_name = "脱敏字段配置"
+ verbose_name_plural = "脱敏字段配置"
class DataMaskingRules(models.Model):
"""
脱敏规则配置
"""
- rule_type = models.IntegerField('规则类型', choices=rule_type_choices, unique=True)
- rule_regex = models.CharField('规则脱敏所用的正则表达式,表达式必须分组,隐藏的组会使用****代替', max_length=255)
- hide_group = models.IntegerField('需要隐藏的组')
- rule_desc = models.CharField('规则描述', max_length=100, default='', blank=True)
+
+ rule_type = models.IntegerField("规则类型", choices=rule_type_choices, unique=True)
+ rule_regex = models.CharField("规则脱敏所用的正则表达式,表达式必须分组,隐藏的组会使用****代替", max_length=255)
+ hide_group = models.IntegerField("需要隐藏的组")
+ rule_desc = models.CharField("规则描述", max_length=100, default="", blank=True)
sys_time = models.DateTimeField(auto_now=True)
class Meta:
managed = True
- db_table = 'data_masking_rules'
- verbose_name = u'脱敏规则配置'
- verbose_name_plural = u'脱敏规则配置'
+ db_table = "data_masking_rules"
+ verbose_name = "脱敏规则配置"
+ verbose_name_plural = "脱敏规则配置"
class InstanceAccount(models.Model):
"""
实例账号列表
"""
+
instance = models.ForeignKey(Instance, on_delete=models.CASCADE)
- user = fields.EncryptedCharField(verbose_name='账号', max_length=128)
- host = models.CharField(verbose_name='主机', max_length=64)
- password = fields.EncryptedCharField(verbose_name='密码', max_length=128, default='', blank=True)
- remark = models.CharField('备注', max_length=255)
- sys_time = models.DateTimeField('系统修改时间', auto_now=True)
+ user = fields.EncryptedCharField(verbose_name="账号", max_length=128)
+ host = models.CharField(verbose_name="主机", max_length=64)
+ password = fields.EncryptedCharField(
+ verbose_name="密码", max_length=128, default="", blank=True
+ )
+ remark = models.CharField("备注", max_length=255)
+ sys_time = models.DateTimeField("系统修改时间", auto_now=True)
class Meta:
managed = True
- db_table = 'instance_account'
- unique_together = ('instance', 'user', 'host')
- verbose_name = '实例账号列表'
- verbose_name_plural = '实例账号列表'
+ db_table = "instance_account"
+ unique_together = ("instance", "user", "host")
+ verbose_name = "实例账号列表"
+ verbose_name_plural = "实例账号列表"
class InstanceDatabase(models.Model):
"""
实例数据库列表
"""
+
instance = models.ForeignKey(Instance, on_delete=models.CASCADE)
- db_name = models.CharField('数据库名', max_length=128)
- owner = models.CharField('负责人', max_length=50, default='', blank=True)
- owner_display = models.CharField('负责人中文名', max_length=50, default='', blank=True)
- remark = models.CharField('备注', max_length=255, default='', blank=True)
- sys_time = models.DateTimeField('系统修改时间', auto_now=True)
+ db_name = models.CharField("数据库名", max_length=128)
+ owner = models.CharField("负责人", max_length=50, default="", blank=True)
+ owner_display = models.CharField("负责人中文名", max_length=50, default="", blank=True)
+ remark = models.CharField("备注", max_length=255, default="", blank=True)
+ sys_time = models.DateTimeField("系统修改时间", auto_now=True)
class Meta:
managed = True
- db_table = 'instance_database'
- unique_together = ('instance', 'db_name')
- verbose_name = '实例数据库'
- verbose_name_plural = '实例数据库列表'
+ db_table = "instance_database"
+ unique_together = ("instance", "db_name")
+ verbose_name = "实例数据库"
+ verbose_name_plural = "实例数据库列表"
class ParamTemplate(models.Model):
"""
实例参数模板配置
"""
- db_type = models.CharField('数据库类型', max_length=20, choices=DB_TYPE_CHOICES)
- variable_name = models.CharField('参数名', max_length=64)
- default_value = models.CharField('默认参数值', max_length=1024)
- editable = models.BooleanField('是否支持修改', default=False)
- valid_values = models.CharField('有效参数值,范围参数[1-65535],值参数[ON|OFF]', max_length=1024, blank=True)
- description = models.CharField('参数描述', max_length=1024, blank=True)
- create_time = models.DateTimeField('创建时间', auto_now_add=True)
- sys_time = models.DateTimeField('系统时间修改', auto_now=True)
+
+ db_type = models.CharField("数据库类型", max_length=20, choices=DB_TYPE_CHOICES)
+ variable_name = models.CharField("参数名", max_length=64)
+ default_value = models.CharField("默认参数值", max_length=1024)
+ editable = models.BooleanField("是否支持修改", default=False)
+ valid_values = models.CharField(
+ "有效参数值,范围参数[1-65535],值参数[ON|OFF]", max_length=1024, blank=True
+ )
+ description = models.CharField("参数描述", max_length=1024, blank=True)
+ create_time = models.DateTimeField("创建时间", auto_now_add=True)
+ sys_time = models.DateTimeField("系统时间修改", auto_now=True)
class Meta:
managed = True
- db_table = 'param_template'
- unique_together = ('db_type', 'variable_name')
- verbose_name = u'实例参数模板配置'
- verbose_name_plural = u'实例参数模板配置'
+ db_table = "param_template"
+ unique_together = ("db_type", "variable_name")
+ verbose_name = "实例参数模板配置"
+ verbose_name_plural = "实例参数模板配置"
class ParamHistory(models.Model):
"""
可在线修改的动态参数配置
"""
+
instance = models.ForeignKey(Instance, on_delete=models.CASCADE)
- variable_name = models.CharField('参数名', max_length=64)
- old_var = models.CharField('修改前参数值', max_length=1024)
- new_var = models.CharField('修改后参数值', max_length=1024)
- set_sql = models.CharField('在线变更配置执行的SQL语句', max_length=1024)
- user_name = models.CharField('修改人', max_length=30)
- user_display = models.CharField('修改人中文名', max_length=50)
- create_time = models.DateTimeField('参数被修改时间点', auto_now_add=True)
+ variable_name = models.CharField("参数名", max_length=64)
+ old_var = models.CharField("修改前参数值", max_length=1024)
+ new_var = models.CharField("修改后参数值", max_length=1024)
+ set_sql = models.CharField("在线变更配置执行的SQL语句", max_length=1024)
+ user_name = models.CharField("修改人", max_length=30)
+ user_display = models.CharField("修改人中文名", max_length=50)
+ create_time = models.DateTimeField("参数被修改时间点", auto_now_add=True)
class Meta:
managed = True
- ordering = ['-create_time']
- db_table = 'param_history'
- verbose_name = u'实例参数修改历史'
- verbose_name_plural = u'实例参数修改历史'
+ ordering = ["-create_time"]
+ db_table = "param_history"
+ verbose_name = "实例参数修改历史"
+ verbose_name_plural = "实例参数修改历史"
class ArchiveConfig(models.Model):
"""
归档配置表
"""
- title = models.CharField('归档配置说明', max_length=50)
+
+ title = models.CharField("归档配置说明", max_length=50)
resource_group = models.ForeignKey(ResourceGroup, on_delete=models.CASCADE)
- audit_auth_groups = models.CharField('审批权限组列表', max_length=255, blank=True)
- src_instance = models.ForeignKey(Instance, related_name='src_instance', on_delete=models.CASCADE)
- src_db_name = models.CharField('源数据库', max_length=64)
- src_table_name = models.CharField('源表', max_length=64)
- dest_instance = models.ForeignKey(Instance, related_name='dest_instance', on_delete=models.CASCADE,
- blank=True, null=True)
- dest_db_name = models.CharField('目标数据库', max_length=64, blank=True, null=True)
- dest_table_name = models.CharField('目标表', max_length=64, blank=True, null=True)
- condition = models.CharField('归档条件,where条件', max_length=1000)
- mode = models.CharField('归档模式', max_length=10, choices=(('file', '文件'), ('dest', '其他实例'), ('purge', '直接删除')))
- no_delete = models.BooleanField('是否保留源数据')
- sleep = models.IntegerField('归档limit行后的休眠秒数', default=1)
- status = models.IntegerField('审核状态', choices=workflow_status_choices, blank=True, default=1)
- state = models.BooleanField('是否启用归档', default=True)
- user_name = models.CharField('申请人', max_length=30, blank=True, default='')
- user_display = models.CharField('申请人中文名', max_length=50, blank=True, default='')
- create_time = models.DateTimeField('创建时间', auto_now_add=True)
- last_archive_time = models.DateTimeField('最近归档时间', blank=True, null=True)
- sys_time = models.DateTimeField('系统时间修改', auto_now=True)
+ audit_auth_groups = models.CharField("审批权限组列表", max_length=255, blank=True)
+ src_instance = models.ForeignKey(
+ Instance, related_name="src_instance", on_delete=models.CASCADE
+ )
+ src_db_name = models.CharField("源数据库", max_length=64)
+ src_table_name = models.CharField("源表", max_length=64)
+ dest_instance = models.ForeignKey(
+ Instance,
+ related_name="dest_instance",
+ on_delete=models.CASCADE,
+ blank=True,
+ null=True,
+ )
+ dest_db_name = models.CharField("目标数据库", max_length=64, blank=True, null=True)
+ dest_table_name = models.CharField("目标表", max_length=64, blank=True, null=True)
+ condition = models.CharField("归档条件,where条件", max_length=1000)
+ mode = models.CharField(
+ "归档模式",
+ max_length=10,
+ choices=(("file", "文件"), ("dest", "其他实例"), ("purge", "直接删除")),
+ )
+ no_delete = models.BooleanField("是否保留源数据")
+ sleep = models.IntegerField("归档limit行后的休眠秒数", default=1)
+ status = models.IntegerField(
+ "审核状态", choices=workflow_status_choices, blank=True, default=1
+ )
+ state = models.BooleanField("是否启用归档", default=True)
+ user_name = models.CharField("申请人", max_length=30, blank=True, default="")
+ user_display = models.CharField("申请人中文名", max_length=50, blank=True, default="")
+ create_time = models.DateTimeField("创建时间", auto_now_add=True)
+ last_archive_time = models.DateTimeField("最近归档时间", blank=True, null=True)
+ sys_time = models.DateTimeField("系统时间修改", auto_now=True)
class Meta:
managed = True
- db_table = 'archive_config'
- verbose_name = u'归档配置表'
- verbose_name_plural = u'归档配置表'
+ db_table = "archive_config"
+ verbose_name = "归档配置表"
+ verbose_name_plural = "归档配置表"
class ArchiveLog(models.Model):
"""
归档日志表
"""
+
archive = models.ForeignKey(ArchiveConfig, on_delete=models.CASCADE)
- cmd = models.CharField('归档命令', max_length=2000)
- condition = models.CharField('归档条件,where条件', max_length=1000)
- mode = models.CharField('归档模式', max_length=10, choices=(('file', '文件'), ('dest', '其他实例'), ('purge', '直接删除')))
- no_delete = models.BooleanField('是否保留源数据')
- sleep = models.IntegerField('归档limit行记录后的休眠秒数', default=0)
- select_cnt = models.IntegerField('查询数量')
- insert_cnt = models.IntegerField('插入数量')
- delete_cnt = models.IntegerField('删除数量')
- statistics = models.TextField('归档统计日志')
- success = models.BooleanField('是否归档成功')
- error_info = models.TextField('错误信息')
- start_time = models.DateTimeField('开始时间')
- end_time = models.DateTimeField('结束时间')
- sys_time = models.DateTimeField('系统时间修改', auto_now=True)
+ cmd = models.CharField("归档命令", max_length=2000)
+ condition = models.CharField("归档条件,where条件", max_length=1000)
+ mode = models.CharField(
+ "归档模式",
+ max_length=10,
+ choices=(("file", "文件"), ("dest", "其他实例"), ("purge", "直接删除")),
+ )
+ no_delete = models.BooleanField("是否保留源数据")
+ sleep = models.IntegerField("归档limit行记录后的休眠秒数", default=0)
+ select_cnt = models.IntegerField("查询数量")
+ insert_cnt = models.IntegerField("插入数量")
+ delete_cnt = models.IntegerField("删除数量")
+ statistics = models.TextField("归档统计日志")
+ success = models.BooleanField("是否归档成功")
+ error_info = models.TextField("错误信息")
+ start_time = models.DateTimeField("开始时间")
+ end_time = models.DateTimeField("结束时间")
+ sys_time = models.DateTimeField("系统时间修改", auto_now=True)
class Meta:
managed = True
- db_table = 'archive_log'
- verbose_name = u'归档日志表'
- verbose_name_plural = u'归档日志表'
+ db_table = "archive_log"
+ verbose_name = "归档日志表"
+ verbose_name_plural = "归档日志表"
class Config(models.Model):
"""
配置信息表
"""
- item = models.CharField('配置项', max_length=100, unique=True)
- value = fields.EncryptedCharField(verbose_name='配置项值', max_length=500)
- description = models.CharField('描述', max_length=200, default='', blank=True)
+
+ item = models.CharField("配置项", max_length=100, unique=True)
+ value = fields.EncryptedCharField(verbose_name="配置项值", max_length=500)
+ description = models.CharField("描述", max_length=200, default="", blank=True)
class Meta:
managed = True
- db_table = 'sql_config'
- verbose_name = u'系统配置'
- verbose_name_plural = u'系统配置'
+ db_table = "sql_config"
+ verbose_name = "系统配置"
+ verbose_name_plural = "系统配置"
# 云服务认证信息配置
class CloudAccessKey(models.Model):
- cloud_type_choices = (('aliyun', 'aliyun'),)
+ cloud_type_choices = (("aliyun", "aliyun"),)
- type = models.CharField(max_length=20, default='', choices=cloud_type_choices)
+ type = models.CharField(max_length=20, default="", choices=cloud_type_choices)
key_id = models.CharField(max_length=200)
key_secret = models.CharField(max_length=200)
- remark = models.CharField(max_length=50, default='', blank=True)
+ remark = models.CharField(max_length=50, default="", blank=True)
def __init__(self, *args, **kwargs):
self.c = Crypto()
@@ -657,12 +796,12 @@ def __init__(self, *args, **kwargs):
@property
def raw_key_id(self):
- """ 返回明文信息"""
+ """返回明文信息"""
return self.c.decrypt(self.key_id)
@property
def raw_key_secret(self):
- """ 返回明文信息"""
+ """返回明文信息"""
return self.c.decrypt(self.key_secret)
def save(self, *args, **kwargs):
@@ -671,32 +810,35 @@ def save(self, *args, **kwargs):
super(CloudAccessKey, self).save(*args, **kwargs)
def __str__(self):
- return f'{self.type}({self.remark})'
+ return f"{self.type}({self.remark})"
class Meta:
managed = True
- db_table = 'cloud_access_key'
- verbose_name = u'云服务认证信息配置'
- verbose_name_plural = u'云服务认证信息配置'
+ db_table = "cloud_access_key"
+ verbose_name = "云服务认证信息配置"
+ verbose_name_plural = "云服务认证信息配置"
class AliyunRdsConfig(models.Model):
"""
阿里云rds配置信息
"""
+
instance = models.OneToOneField(Instance, on_delete=models.CASCADE)
- rds_dbinstanceid = models.CharField('对应阿里云RDS实例ID', max_length=100)
- ak = models.ForeignKey(CloudAccessKey, verbose_name='RDS实例对应的AK配置', on_delete=models.CASCADE)
- is_enable = models.BooleanField('是否启用', default=False)
+ rds_dbinstanceid = models.CharField("对应阿里云RDS实例ID", max_length=100)
+ ak = models.ForeignKey(
+ CloudAccessKey, verbose_name="RDS实例对应的AK配置", on_delete=models.CASCADE
+ )
+ is_enable = models.BooleanField("是否启用", default=False)
def __int__(self):
return self.rds_dbinstanceid
class Meta:
managed = True
- db_table = 'aliyun_rds_config'
- verbose_name = u'阿里云rds配置'
- verbose_name_plural = u'阿里云rds配置'
+ db_table = "aliyun_rds_config"
+ verbose_name = "阿里云rds配置"
+ verbose_name_plural = "阿里云rds配置"
class Permission(models.Model):
@@ -707,58 +849,58 @@ class Permission(models.Model):
class Meta:
managed = True
permissions = (
- ('menu_dashboard', '菜单 Dashboard'),
- ('menu_sqlcheck', '菜单 SQL审核'),
- ('menu_sqlworkflow', '菜单 SQL上线'),
- ('menu_sqlanalyze', '菜单 SQL分析'),
- ('menu_query', '菜单 SQL查询'),
- ('menu_sqlquery', '菜单 在线查询'),
- ('menu_queryapplylist', '菜单 权限管理'),
- ('menu_sqloptimize', '菜单 SQL优化'),
- ('menu_sqladvisor', '菜单 优化工具'),
- ('menu_slowquery', '菜单 慢查日志'),
- ('menu_instance', '菜单 实例管理'),
- ('menu_instance_list', '菜单 实例列表'),
- ('menu_dbdiagnostic', '菜单 会话管理'),
- ('menu_database', '菜单 数据库管理'),
- ('menu_instance_account', '菜单 实例账号管理'),
- ('menu_param', '菜单 参数配置'),
- ('menu_data_dictionary', '菜单 数据字典'),
- ('menu_tools', '菜单 工具插件'),
- ('menu_archive', '菜单 数据归档'),
- ('menu_my2sql', '菜单 My2SQL'),
- ('menu_schemasync', '菜单 SchemaSync'),
- ('menu_system', '菜单 系统管理'),
- ('menu_document', '菜单 相关文档'),
- ('menu_openapi', '菜单 OpenAPI'),
- ('sql_submit', '提交SQL上线工单'),
- ('sql_review', '审核SQL上线工单'),
- ('sql_execute_for_resource_group', '执行SQL上线工单(资源组粒度)'),
- ('sql_execute', '执行SQL上线工单(仅自己提交的)'),
- ('sql_analyze', '执行SQL分析'),
- ('optimize_sqladvisor', '执行SQLAdvisor'),
- ('optimize_sqltuning', '执行SQLTuning'),
- ('optimize_soar', '执行SOAR'),
- ('query_applypriv', '申请查询权限'),
- ('query_mgtpriv', '管理查询权限'),
- ('query_review', '审核查询权限'),
- ('query_submit', '提交SQL查询'),
- ('query_all_instances', '可查询所有实例'),
- ('query_resource_group_instance', '可查询所在资源组内的所有实例'),
- ('process_view', '查看会话'),
- ('process_kill', '终止会话'),
- ('tablespace_view', '查看表空间'),
- ('trx_view', '查看事务信息'),
- ('trxandlocks_view', '查看锁信息'),
- ('instance_account_manage', '管理实例账号'),
- ('param_view', '查看实例参数列表'),
- ('param_edit', '修改实例参数'),
- ('data_dictionary_export', '导出数据字典'),
- ('archive_apply', '提交归档申请'),
- ('archive_review', '审核归档申请'),
- ('archive_mgt', '管理归档申请'),
- ('audit_user','审计权限'),
- ('query_download', '在线查询下载权限'),
+ ("menu_dashboard", "菜单 Dashboard"),
+ ("menu_sqlcheck", "菜单 SQL审核"),
+ ("menu_sqlworkflow", "菜单 SQL上线"),
+ ("menu_sqlanalyze", "菜单 SQL分析"),
+ ("menu_query", "菜单 SQL查询"),
+ ("menu_sqlquery", "菜单 在线查询"),
+ ("menu_queryapplylist", "菜单 权限管理"),
+ ("menu_sqloptimize", "菜单 SQL优化"),
+ ("menu_sqladvisor", "菜单 优化工具"),
+ ("menu_slowquery", "菜单 慢查日志"),
+ ("menu_instance", "菜单 实例管理"),
+ ("menu_instance_list", "菜单 实例列表"),
+ ("menu_dbdiagnostic", "菜单 会话管理"),
+ ("menu_database", "菜单 数据库管理"),
+ ("menu_instance_account", "菜单 实例账号管理"),
+ ("menu_param", "菜单 参数配置"),
+ ("menu_data_dictionary", "菜单 数据字典"),
+ ("menu_tools", "菜单 工具插件"),
+ ("menu_archive", "菜单 数据归档"),
+ ("menu_my2sql", "菜单 My2SQL"),
+ ("menu_schemasync", "菜单 SchemaSync"),
+ ("menu_system", "菜单 系统管理"),
+ ("menu_document", "菜单 相关文档"),
+ ("menu_openapi", "菜单 OpenAPI"),
+ ("sql_submit", "提交SQL上线工单"),
+ ("sql_review", "审核SQL上线工单"),
+ ("sql_execute_for_resource_group", "执行SQL上线工单(资源组粒度)"),
+ ("sql_execute", "执行SQL上线工单(仅自己提交的)"),
+ ("sql_analyze", "执行SQL分析"),
+ ("optimize_sqladvisor", "执行SQLAdvisor"),
+ ("optimize_sqltuning", "执行SQLTuning"),
+ ("optimize_soar", "执行SOAR"),
+ ("query_applypriv", "申请查询权限"),
+ ("query_mgtpriv", "管理查询权限"),
+ ("query_review", "审核查询权限"),
+ ("query_submit", "提交SQL查询"),
+ ("query_all_instances", "可查询所有实例"),
+ ("query_resource_group_instance", "可查询所在资源组内的所有实例"),
+ ("process_view", "查看会话"),
+ ("process_kill", "终止会话"),
+ ("tablespace_view", "查看表空间"),
+ ("trx_view", "查看事务信息"),
+ ("trxandlocks_view", "查看锁信息"),
+ ("instance_account_manage", "管理实例账号"),
+ ("param_view", "查看实例参数列表"),
+ ("param_edit", "修改实例参数"),
+ ("data_dictionary_export", "导出数据字典"),
+ ("archive_apply", "提交归档申请"),
+ ("archive_review", "审核归档申请"),
+ ("archive_mgt", "管理归档申请"),
+ ("audit_user", "审计权限"),
+ ("query_download", "在线查询下载权限"),
)
@@ -766,6 +908,7 @@ class SlowQuery(models.Model):
"""
SlowQuery
"""
+
checksum = models.CharField(max_length=32, primary_key=True)
fingerprint = models.TextField()
sample = models.TextField()
@@ -777,143 +920,286 @@ class SlowQuery(models.Model):
class Meta:
managed = False
- db_table = 'mysql_slow_query_review'
- verbose_name = u'慢日志统计'
- verbose_name_plural = u'慢日志统计'
+ db_table = "mysql_slow_query_review"
+ verbose_name = "慢日志统计"
+ verbose_name_plural = "慢日志统计"
class SlowQueryHistory(models.Model):
"""
SlowQueryHistory
"""
+
hostname_max = models.CharField(max_length=64, null=False)
client_max = models.CharField(max_length=64, null=True)
user_max = models.CharField(max_length=64, null=False)
db_max = models.CharField(max_length=64, null=True, default=None)
bytes_max = models.CharField(max_length=64, null=True)
- checksum = models.ForeignKey(SlowQuery, db_constraint=False, to_field='checksum', db_column='checksum',
- on_delete=models.CASCADE)
+ checksum = models.ForeignKey(
+ SlowQuery,
+ db_constraint=False,
+ to_field="checksum",
+ db_column="checksum",
+ on_delete=models.CASCADE,
+ )
sample = models.TextField()
ts_min = models.DateTimeField(db_index=True)
ts_max = models.DateTimeField()
ts_cnt = models.FloatField(blank=True, null=True)
- query_time_sum = models.FloatField(db_column='Query_time_sum', blank=True, null=True)
- query_time_min = models.FloatField(db_column='Query_time_min', blank=True, null=True)
- query_time_max = models.FloatField(db_column='Query_time_max', blank=True, null=True)
- query_time_pct_95 = models.FloatField(db_column='Query_time_pct_95', blank=True, null=True)
- query_time_stddev = models.FloatField(db_column='Query_time_stddev', blank=True, null=True)
- query_time_median = models.FloatField(db_column='Query_time_median', blank=True, null=True)
- lock_time_sum = models.FloatField(db_column='Lock_time_sum', blank=True, null=True)
- lock_time_min = models.FloatField(db_column='Lock_time_min', blank=True, null=True)
- lock_time_max = models.FloatField(db_column='Lock_time_max', blank=True, null=True)
- lock_time_pct_95 = models.FloatField(db_column='Lock_time_pct_95', blank=True, null=True)
- lock_time_stddev = models.FloatField(db_column='Lock_time_stddev', blank=True, null=True)
- lock_time_median = models.FloatField(db_column='Lock_time_median', blank=True, null=True)
- rows_sent_sum = models.FloatField(db_column='Rows_sent_sum', blank=True, null=True)
- rows_sent_min = models.FloatField(db_column='Rows_sent_min', blank=True, null=True)
- rows_sent_max = models.FloatField(db_column='Rows_sent_max', blank=True, null=True)
- rows_sent_pct_95 = models.FloatField(db_column='Rows_sent_pct_95', blank=True, null=True)
- rows_sent_stddev = models.FloatField(db_column='Rows_sent_stddev', blank=True, null=True)
- rows_sent_median = models.FloatField(db_column='Rows_sent_median', blank=True, null=True)
- rows_examined_sum = models.FloatField(db_column='Rows_examined_sum', blank=True, null=True)
- rows_examined_min = models.FloatField(db_column='Rows_examined_min', blank=True, null=True)
- rows_examined_max = models.FloatField(db_column='Rows_examined_max', blank=True, null=True)
- rows_examined_pct_95 = models.FloatField(db_column='Rows_examined_pct_95', blank=True, null=True)
- rows_examined_stddev = models.FloatField(db_column='Rows_examined_stddev', blank=True, null=True)
- rows_examined_median = models.FloatField(db_column='Rows_examined_median', blank=True, null=True)
- rows_affected_sum = models.FloatField(db_column='Rows_affected_sum', blank=True, null=True)
- rows_affected_min = models.FloatField(db_column='Rows_affected_min', blank=True, null=True)
- rows_affected_max = models.FloatField(db_column='Rows_affected_max', blank=True, null=True)
- rows_affected_pct_95 = models.FloatField(db_column='Rows_affected_pct_95', blank=True, null=True)
- rows_affected_stddev = models.FloatField(db_column='Rows_affected_stddev', blank=True, null=True)
- rows_affected_median = models.FloatField(db_column='Rows_affected_median', blank=True, null=True)
- rows_read_sum = models.FloatField(db_column='Rows_read_sum', blank=True, null=True)
- rows_read_min = models.FloatField(db_column='Rows_read_min', blank=True, null=True)
- rows_read_max = models.FloatField(db_column='Rows_read_max', blank=True, null=True)
- rows_read_pct_95 = models.FloatField(db_column='Rows_read_pct_95', blank=True, null=True)
- rows_read_stddev = models.FloatField(db_column='Rows_read_stddev', blank=True, null=True)
- rows_read_median = models.FloatField(db_column='Rows_read_median', blank=True, null=True)
- merge_passes_sum = models.FloatField(db_column='Merge_passes_sum', blank=True, null=True)
- merge_passes_min = models.FloatField(db_column='Merge_passes_min', blank=True, null=True)
- merge_passes_max = models.FloatField(db_column='Merge_passes_max', blank=True, null=True)
- merge_passes_pct_95 = models.FloatField(db_column='Merge_passes_pct_95', blank=True, null=True)
- merge_passes_stddev = models.FloatField(db_column='Merge_passes_stddev', blank=True, null=True)
- merge_passes_median = models.FloatField(db_column='Merge_passes_median', blank=True, null=True)
- innodb_io_r_ops_min = models.FloatField(db_column='InnoDB_IO_r_ops_min', blank=True, null=True)
- innodb_io_r_ops_max = models.FloatField(db_column='InnoDB_IO_r_ops_max', blank=True, null=True)
- innodb_io_r_ops_pct_95 = models.FloatField(db_column='InnoDB_IO_r_ops_pct_95', blank=True, null=True)
- innodb_io_r_ops_stddev = models.FloatField(db_column='InnoDB_IO_r_ops_stddev', blank=True, null=True)
- innodb_io_r_ops_median = models.FloatField(db_column='InnoDB_IO_r_ops_median', blank=True, null=True)
- innodb_io_r_bytes_min = models.FloatField(db_column='InnoDB_IO_r_bytes_min', blank=True, null=True)
- innodb_io_r_bytes_max = models.FloatField(db_column='InnoDB_IO_r_bytes_max', blank=True, null=True)
- innodb_io_r_bytes_pct_95 = models.FloatField(db_column='InnoDB_IO_r_bytes_pct_95', blank=True, null=True)
- innodb_io_r_bytes_stddev = models.FloatField(db_column='InnoDB_IO_r_bytes_stddev', blank=True, null=True)
- innodb_io_r_bytes_median = models.FloatField(db_column='InnoDB_IO_r_bytes_median', blank=True, null=True)
- innodb_io_r_wait_min = models.FloatField(db_column='InnoDB_IO_r_wait_min', blank=True, null=True)
- innodb_io_r_wait_max = models.FloatField(db_column='InnoDB_IO_r_wait_max', blank=True, null=True)
- innodb_io_r_wait_pct_95 = models.FloatField(db_column='InnoDB_IO_r_wait_pct_95', blank=True, null=True)
- innodb_io_r_wait_stddev = models.FloatField(db_column='InnoDB_IO_r_wait_stddev', blank=True, null=True)
- innodb_io_r_wait_median = models.FloatField(db_column='InnoDB_IO_r_wait_median', blank=True, null=True)
- innodb_rec_lock_wait_min = models.FloatField(db_column='InnoDB_rec_lock_wait_min', blank=True, null=True)
- innodb_rec_lock_wait_max = models.FloatField(db_column='InnoDB_rec_lock_wait_max', blank=True, null=True)
- innodb_rec_lock_wait_pct_95 = models.FloatField(db_column='InnoDB_rec_lock_wait_pct_95', blank=True, null=True)
- innodb_rec_lock_wait_stddev = models.FloatField(db_column='InnoDB_rec_lock_wait_stddev', blank=True, null=True)
- innodb_rec_lock_wait_median = models.FloatField(db_column='InnoDB_rec_lock_wait_median', blank=True, null=True)
- innodb_queue_wait_min = models.FloatField(db_column='InnoDB_queue_wait_min', blank=True, null=True)
- innodb_queue_wait_max = models.FloatField(db_column='InnoDB_queue_wait_max', blank=True, null=True)
- innodb_queue_wait_pct_95 = models.FloatField(db_column='InnoDB_queue_wait_pct_95', blank=True, null=True)
- innodb_queue_wait_stddev = models.FloatField(db_column='InnoDB_queue_wait_stddev', blank=True, null=True)
- innodb_queue_wait_median = models.FloatField(db_column='InnoDB_queue_wait_median', blank=True, null=True)
- innodb_pages_distinct_min = models.FloatField(db_column='InnoDB_pages_distinct_min', blank=True, null=True)
- innodb_pages_distinct_max = models.FloatField(db_column='InnoDB_pages_distinct_max', blank=True, null=True)
- innodb_pages_distinct_pct_95 = models.FloatField(db_column='InnoDB_pages_distinct_pct_95', blank=True, null=True)
- innodb_pages_distinct_stddev = models.FloatField(db_column='InnoDB_pages_distinct_stddev', blank=True, null=True)
- innodb_pages_distinct_median = models.FloatField(db_column='InnoDB_pages_distinct_median', blank=True, null=True)
- qc_hit_cnt = models.FloatField(db_column='QC_Hit_cnt', blank=True, null=True)
- qc_hit_sum = models.FloatField(db_column='QC_Hit_sum', blank=True, null=True)
- full_scan_cnt = models.FloatField(db_column='Full_scan_cnt', blank=True, null=True)
- full_scan_sum = models.FloatField(db_column='Full_scan_sum', blank=True, null=True)
- full_join_cnt = models.FloatField(db_column='Full_join_cnt', blank=True, null=True)
- full_join_sum = models.FloatField(db_column='Full_join_sum', blank=True, null=True)
- tmp_table_cnt = models.FloatField(db_column='Tmp_table_cnt', blank=True, null=True)
- tmp_table_sum = models.FloatField(db_column='Tmp_table_sum', blank=True, null=True)
- tmp_table_on_disk_cnt = models.FloatField(db_column='Tmp_table_on_disk_cnt', blank=True, null=True)
- tmp_table_on_disk_sum = models.FloatField(db_column='Tmp_table_on_disk_sum', blank=True, null=True)
- filesort_cnt = models.FloatField(db_column='Filesort_cnt', blank=True, null=True)
- filesort_sum = models.FloatField(db_column='Filesort_sum', blank=True, null=True)
- filesort_on_disk_cnt = models.FloatField(db_column='Filesort_on_disk_cnt', blank=True, null=True)
- filesort_on_disk_sum = models.FloatField(db_column='Filesort_on_disk_sum', blank=True, null=True)
+ query_time_sum = models.FloatField(
+ db_column="Query_time_sum", blank=True, null=True
+ )
+ query_time_min = models.FloatField(
+ db_column="Query_time_min", blank=True, null=True
+ )
+ query_time_max = models.FloatField(
+ db_column="Query_time_max", blank=True, null=True
+ )
+ query_time_pct_95 = models.FloatField(
+ db_column="Query_time_pct_95", blank=True, null=True
+ )
+ query_time_stddev = models.FloatField(
+ db_column="Query_time_stddev", blank=True, null=True
+ )
+ query_time_median = models.FloatField(
+ db_column="Query_time_median", blank=True, null=True
+ )
+ lock_time_sum = models.FloatField(db_column="Lock_time_sum", blank=True, null=True)
+ lock_time_min = models.FloatField(db_column="Lock_time_min", blank=True, null=True)
+ lock_time_max = models.FloatField(db_column="Lock_time_max", blank=True, null=True)
+ lock_time_pct_95 = models.FloatField(
+ db_column="Lock_time_pct_95", blank=True, null=True
+ )
+ lock_time_stddev = models.FloatField(
+ db_column="Lock_time_stddev", blank=True, null=True
+ )
+ lock_time_median = models.FloatField(
+ db_column="Lock_time_median", blank=True, null=True
+ )
+ rows_sent_sum = models.FloatField(db_column="Rows_sent_sum", blank=True, null=True)
+ rows_sent_min = models.FloatField(db_column="Rows_sent_min", blank=True, null=True)
+ rows_sent_max = models.FloatField(db_column="Rows_sent_max", blank=True, null=True)
+ rows_sent_pct_95 = models.FloatField(
+ db_column="Rows_sent_pct_95", blank=True, null=True
+ )
+ rows_sent_stddev = models.FloatField(
+ db_column="Rows_sent_stddev", blank=True, null=True
+ )
+ rows_sent_median = models.FloatField(
+ db_column="Rows_sent_median", blank=True, null=True
+ )
+ rows_examined_sum = models.FloatField(
+ db_column="Rows_examined_sum", blank=True, null=True
+ )
+ rows_examined_min = models.FloatField(
+ db_column="Rows_examined_min", blank=True, null=True
+ )
+ rows_examined_max = models.FloatField(
+ db_column="Rows_examined_max", blank=True, null=True
+ )
+ rows_examined_pct_95 = models.FloatField(
+ db_column="Rows_examined_pct_95", blank=True, null=True
+ )
+ rows_examined_stddev = models.FloatField(
+ db_column="Rows_examined_stddev", blank=True, null=True
+ )
+ rows_examined_median = models.FloatField(
+ db_column="Rows_examined_median", blank=True, null=True
+ )
+ rows_affected_sum = models.FloatField(
+ db_column="Rows_affected_sum", blank=True, null=True
+ )
+ rows_affected_min = models.FloatField(
+ db_column="Rows_affected_min", blank=True, null=True
+ )
+ rows_affected_max = models.FloatField(
+ db_column="Rows_affected_max", blank=True, null=True
+ )
+ rows_affected_pct_95 = models.FloatField(
+ db_column="Rows_affected_pct_95", blank=True, null=True
+ )
+ rows_affected_stddev = models.FloatField(
+ db_column="Rows_affected_stddev", blank=True, null=True
+ )
+ rows_affected_median = models.FloatField(
+ db_column="Rows_affected_median", blank=True, null=True
+ )
+ rows_read_sum = models.FloatField(db_column="Rows_read_sum", blank=True, null=True)
+ rows_read_min = models.FloatField(db_column="Rows_read_min", blank=True, null=True)
+ rows_read_max = models.FloatField(db_column="Rows_read_max", blank=True, null=True)
+ rows_read_pct_95 = models.FloatField(
+ db_column="Rows_read_pct_95", blank=True, null=True
+ )
+ rows_read_stddev = models.FloatField(
+ db_column="Rows_read_stddev", blank=True, null=True
+ )
+ rows_read_median = models.FloatField(
+ db_column="Rows_read_median", blank=True, null=True
+ )
+ merge_passes_sum = models.FloatField(
+ db_column="Merge_passes_sum", blank=True, null=True
+ )
+ merge_passes_min = models.FloatField(
+ db_column="Merge_passes_min", blank=True, null=True
+ )
+ merge_passes_max = models.FloatField(
+ db_column="Merge_passes_max", blank=True, null=True
+ )
+ merge_passes_pct_95 = models.FloatField(
+ db_column="Merge_passes_pct_95", blank=True, null=True
+ )
+ merge_passes_stddev = models.FloatField(
+ db_column="Merge_passes_stddev", blank=True, null=True
+ )
+ merge_passes_median = models.FloatField(
+ db_column="Merge_passes_median", blank=True, null=True
+ )
+ innodb_io_r_ops_min = models.FloatField(
+ db_column="InnoDB_IO_r_ops_min", blank=True, null=True
+ )
+ innodb_io_r_ops_max = models.FloatField(
+ db_column="InnoDB_IO_r_ops_max", blank=True, null=True
+ )
+ innodb_io_r_ops_pct_95 = models.FloatField(
+ db_column="InnoDB_IO_r_ops_pct_95", blank=True, null=True
+ )
+ innodb_io_r_ops_stddev = models.FloatField(
+ db_column="InnoDB_IO_r_ops_stddev", blank=True, null=True
+ )
+ innodb_io_r_ops_median = models.FloatField(
+ db_column="InnoDB_IO_r_ops_median", blank=True, null=True
+ )
+ innodb_io_r_bytes_min = models.FloatField(
+ db_column="InnoDB_IO_r_bytes_min", blank=True, null=True
+ )
+ innodb_io_r_bytes_max = models.FloatField(
+ db_column="InnoDB_IO_r_bytes_max", blank=True, null=True
+ )
+ innodb_io_r_bytes_pct_95 = models.FloatField(
+ db_column="InnoDB_IO_r_bytes_pct_95", blank=True, null=True
+ )
+ innodb_io_r_bytes_stddev = models.FloatField(
+ db_column="InnoDB_IO_r_bytes_stddev", blank=True, null=True
+ )
+ innodb_io_r_bytes_median = models.FloatField(
+ db_column="InnoDB_IO_r_bytes_median", blank=True, null=True
+ )
+ innodb_io_r_wait_min = models.FloatField(
+ db_column="InnoDB_IO_r_wait_min", blank=True, null=True
+ )
+ innodb_io_r_wait_max = models.FloatField(
+ db_column="InnoDB_IO_r_wait_max", blank=True, null=True
+ )
+ innodb_io_r_wait_pct_95 = models.FloatField(
+ db_column="InnoDB_IO_r_wait_pct_95", blank=True, null=True
+ )
+ innodb_io_r_wait_stddev = models.FloatField(
+ db_column="InnoDB_IO_r_wait_stddev", blank=True, null=True
+ )
+ innodb_io_r_wait_median = models.FloatField(
+ db_column="InnoDB_IO_r_wait_median", blank=True, null=True
+ )
+ innodb_rec_lock_wait_min = models.FloatField(
+ db_column="InnoDB_rec_lock_wait_min", blank=True, null=True
+ )
+ innodb_rec_lock_wait_max = models.FloatField(
+ db_column="InnoDB_rec_lock_wait_max", blank=True, null=True
+ )
+ innodb_rec_lock_wait_pct_95 = models.FloatField(
+ db_column="InnoDB_rec_lock_wait_pct_95", blank=True, null=True
+ )
+ innodb_rec_lock_wait_stddev = models.FloatField(
+ db_column="InnoDB_rec_lock_wait_stddev", blank=True, null=True
+ )
+ innodb_rec_lock_wait_median = models.FloatField(
+ db_column="InnoDB_rec_lock_wait_median", blank=True, null=True
+ )
+ innodb_queue_wait_min = models.FloatField(
+ db_column="InnoDB_queue_wait_min", blank=True, null=True
+ )
+ innodb_queue_wait_max = models.FloatField(
+ db_column="InnoDB_queue_wait_max", blank=True, null=True
+ )
+ innodb_queue_wait_pct_95 = models.FloatField(
+ db_column="InnoDB_queue_wait_pct_95", blank=True, null=True
+ )
+ innodb_queue_wait_stddev = models.FloatField(
+ db_column="InnoDB_queue_wait_stddev", blank=True, null=True
+ )
+ innodb_queue_wait_median = models.FloatField(
+ db_column="InnoDB_queue_wait_median", blank=True, null=True
+ )
+ innodb_pages_distinct_min = models.FloatField(
+ db_column="InnoDB_pages_distinct_min", blank=True, null=True
+ )
+ innodb_pages_distinct_max = models.FloatField(
+ db_column="InnoDB_pages_distinct_max", blank=True, null=True
+ )
+ innodb_pages_distinct_pct_95 = models.FloatField(
+ db_column="InnoDB_pages_distinct_pct_95", blank=True, null=True
+ )
+ innodb_pages_distinct_stddev = models.FloatField(
+ db_column="InnoDB_pages_distinct_stddev", blank=True, null=True
+ )
+ innodb_pages_distinct_median = models.FloatField(
+ db_column="InnoDB_pages_distinct_median", blank=True, null=True
+ )
+ qc_hit_cnt = models.FloatField(db_column="QC_Hit_cnt", blank=True, null=True)
+ qc_hit_sum = models.FloatField(db_column="QC_Hit_sum", blank=True, null=True)
+ full_scan_cnt = models.FloatField(db_column="Full_scan_cnt", blank=True, null=True)
+ full_scan_sum = models.FloatField(db_column="Full_scan_sum", blank=True, null=True)
+ full_join_cnt = models.FloatField(db_column="Full_join_cnt", blank=True, null=True)
+ full_join_sum = models.FloatField(db_column="Full_join_sum", blank=True, null=True)
+ tmp_table_cnt = models.FloatField(db_column="Tmp_table_cnt", blank=True, null=True)
+ tmp_table_sum = models.FloatField(db_column="Tmp_table_sum", blank=True, null=True)
+ tmp_table_on_disk_cnt = models.FloatField(
+ db_column="Tmp_table_on_disk_cnt", blank=True, null=True
+ )
+ tmp_table_on_disk_sum = models.FloatField(
+ db_column="Tmp_table_on_disk_sum", blank=True, null=True
+ )
+ filesort_cnt = models.FloatField(db_column="Filesort_cnt", blank=True, null=True)
+ filesort_sum = models.FloatField(db_column="Filesort_sum", blank=True, null=True)
+ filesort_on_disk_cnt = models.FloatField(
+ db_column="Filesort_on_disk_cnt", blank=True, null=True
+ )
+ filesort_on_disk_sum = models.FloatField(
+ db_column="Filesort_on_disk_sum", blank=True, null=True
+ )
class Meta:
managed = False
- db_table = 'mysql_slow_query_review_history'
- unique_together = ('checksum', 'ts_min', 'ts_max')
- index_together = ('hostname_max', 'ts_min')
- verbose_name = u'慢日志明细'
- verbose_name_plural = u'慢日志明细'
+ db_table = "mysql_slow_query_review_history"
+ unique_together = ("checksum", "ts_min", "ts_max")
+ index_together = ("hostname_max", "ts_min")
+ verbose_name = "慢日志明细"
+ verbose_name_plural = "慢日志明细"
class AuditEntry(models.Model):
"""
登录审计日志
"""
- user_id = models.IntegerField('用户ID')
- user_name = models.CharField('用户名称', max_length=30, null=True)
- user_display = models.CharField('用户中文名',max_length=50, null=True)
- action = models.CharField('动作', max_length=255)
- extra_info = models.TextField('额外的信息', null=True)
- action_time = models.DateTimeField('操作时间', auto_now_add=True)
+
+ user_id = models.IntegerField("用户ID")
+ user_name = models.CharField("用户名称", max_length=30, null=True)
+ user_display = models.CharField("用户中文名", max_length=50, null=True)
+ action = models.CharField("动作", max_length=255)
+ extra_info = models.TextField("额外的信息", null=True)
+ action_time = models.DateTimeField("操作时间", auto_now_add=True)
class Meta:
managed = True
- db_table = 'audit_log'
- verbose_name = u'审计日志'
- verbose_name_plural = u'审计日志'
+ db_table = "audit_log"
+ verbose_name = "审计日志"
+ verbose_name_plural = "审计日志"
def __unicode__(self):
- return '{0} - {1} - {2} - {3} - {4}'.format(self.user_id, self.user_name, self.extra_info
- , self.action, self.action_time)
+ return "{0} - {1} - {2} - {3} - {4}".format(
+ self.user_id, self.user_name, self.extra_info, self.action, self.action_time
+ )
def __str__(self):
- return '{0} - {1} - {2} - {3} - {4}'.format(self.user_id, self.user_name, self.extra_info
- , self.action, self.action_time)
+ return "{0} - {1} - {2} - {3} - {4}".format(
+ self.user_id, self.user_name, self.extra_info, self.action, self.action_time
+ )
diff --git a/sql/notify.py b/sql/notify.py
index ca70c5fd0e..5728b4a431 100755
--- a/sql/notify.py
+++ b/sql/notify.py
@@ -5,7 +5,13 @@
from django.contrib.auth.models import Group
from common.config import SysConfig
-from sql.models import QueryPrivilegesApply, Users, SqlWorkflow, ResourceGroup, ArchiveConfig
+from sql.models import (
+ QueryPrivilegesApply,
+ Users,
+ SqlWorkflow,
+ ResourceGroup,
+ ArchiveConfig,
+)
from sql.utils.resource_group import auth_group_users
from common.utils.sendmsg import MsgSender
from common.utils.const import WorkflowDict
@@ -13,22 +19,31 @@
import logging
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
def __notify_cnf_status():
"""返回消息通知开关"""
sys_config = SysConfig()
- mail_status = sys_config.get('mail')
- ding_status = sys_config.get('ding_to_person')
- ding_webhook_status = sys_config.get('ding')
- wx_status = sys_config.get('wx')
+ mail_status = sys_config.get("mail")
+ ding_status = sys_config.get("ding_to_person")
+ ding_webhook_status = sys_config.get("ding")
+ wx_status = sys_config.get("wx")
qywx_webhook_status = sys_config.get("qywx_webhook")
feishu_webhook_status = sys_config.get("feishu_webhook")
feishu_status = sys_config.get("feishu")
- if not any([mail_status, ding_status, ding_webhook_status, wx_status, feishu_status, feishu_webhook_status,
- qywx_webhook_status]):
- logger.info('未开启任何消息通知,可在系统设置中开启')
+ if not any(
+ [
+ mail_status,
+ ding_status,
+ ding_webhook_status,
+ wx_status,
+ feishu_status,
+ feishu_webhook_status,
+ qywx_webhook_status,
+ ]
+ ):
+ logger.info("未开启任何消息通知,可在系统设置中开启")
return False
else:
return True
@@ -45,30 +60,41 @@ def __send(msg_title, msg_content, msg_to, msg_cc=None, **kwargs):
sys_config = SysConfig()
msg_sender = MsgSender()
msg_cc = msg_cc if msg_cc else []
- dingding_webhook = kwargs.get('dingding_webhook')
- feishu_webhook = kwargs.get('feishu_webhook')
- qywx_webhook = kwargs.get('qywx_webhook')
+ dingding_webhook = kwargs.get("dingding_webhook")
+ feishu_webhook = kwargs.get("feishu_webhook")
+ qywx_webhook = kwargs.get("qywx_webhook")
msg_to_email = [user.email for user in msg_to if user.email]
msg_cc_email = [user.email for user in msg_cc if user.email]
- msg_to_ding_user = [user.ding_user_id for user in chain(msg_to, msg_cc) if user.ding_user_id]
- msg_to_wx_user = [user.wx_user_id if user.wx_user_id else user.username for user in chain(msg_to, msg_cc)]
- logger.info(f'{msg_to_email}{msg_cc_email}{msg_to_wx_user}{chain(msg_to, msg_cc)}')
- if sys_config.get('mail'):
- msg_sender.send_email(msg_title, msg_content, msg_to_email, list_cc_addr=msg_cc_email)
- if sys_config.get('ding') and dingding_webhook:
- msg_sender.send_ding(dingding_webhook, msg_title + '\n' + msg_content)
- if sys_config.get('ding_to_person'):
- msg_sender.send_ding2user(msg_to_ding_user, msg_title + '\n' + msg_content)
- if sys_config.get('wx'):
- msg_sender.send_wx2user(msg_title + '\n' + msg_content, msg_to_wx_user)
+ msg_to_ding_user = [
+ user.ding_user_id for user in chain(msg_to, msg_cc) if user.ding_user_id
+ ]
+ msg_to_wx_user = [
+ user.wx_user_id if user.wx_user_id else user.username
+ for user in chain(msg_to, msg_cc)
+ ]
+ logger.info(f"{msg_to_email}{msg_cc_email}{msg_to_wx_user}{chain(msg_to, msg_cc)}")
+ if sys_config.get("mail"):
+ msg_sender.send_email(
+ msg_title, msg_content, msg_to_email, list_cc_addr=msg_cc_email
+ )
+ if sys_config.get("ding") and dingding_webhook:
+ msg_sender.send_ding(dingding_webhook, msg_title + "\n" + msg_content)
+ if sys_config.get("ding_to_person"):
+ msg_sender.send_ding2user(msg_to_ding_user, msg_title + "\n" + msg_content)
+ if sys_config.get("wx"):
+ msg_sender.send_wx2user(msg_title + "\n" + msg_content, msg_to_wx_user)
if sys_config.get("feishu_webhook") and feishu_webhook:
msg_sender.send_feishu_webhook(feishu_webhook, msg_title, msg_content)
if sys_config.get("feishu"):
- open_id = [user.feishu_open_id for user in chain(msg_to, msg_cc) if user.feishu_open_id]
- user_mail = [user.email for user in chain(msg_to, msg_cc) if not user.feishu_open_id]
+ open_id = [
+ user.feishu_open_id for user in chain(msg_to, msg_cc) if user.feishu_open_id
+ ]
+ user_mail = [
+ user.email for user in chain(msg_to, msg_cc) if not user.feishu_open_id
+ ]
msg_sender.send_feishu_user(msg_title, msg_content, open_id, user_mail)
- if sys_config.get('qywx_webhook') and qywx_webhook:
- msg_sender.send_qywx_webhook(qywx_webhook, msg_title + '\n' + msg_content)
+ if sys_config.get("qywx_webhook") and qywx_webhook:
+ msg_sender.send_qywx_webhook(qywx_webhook, msg_title + "\n" + msg_content)
def notify_for_audit(audit_id, **kwargs):
@@ -86,71 +112,90 @@ def notify_for_audit(audit_id, **kwargs):
# 获取审核信息
audit_detail = Audit.detail(audit_id=audit_id)
audit_id = audit_detail.audit_id
- workflow_audit_remark = kwargs.get('audit_remark', '')
- base_url = sys_config.get('archery_base_url', 'http://127.0.0.1:8000').rstrip('/')
- workflow_url = "{base_url}/workflow/{audit_id}".format(base_url=base_url, audit_id=audit_detail.audit_id)
+ workflow_audit_remark = kwargs.get("audit_remark", "")
+ base_url = sys_config.get("archery_base_url", "http://127.0.0.1:8000").rstrip("/")
+ workflow_url = "{base_url}/workflow/{audit_id}".format(
+ base_url=base_url, audit_id=audit_detail.audit_id
+ )
workflow_id = audit_detail.workflow_id
workflow_type = audit_detail.workflow_type
status = audit_detail.current_status
workflow_title = audit_detail.workflow_title
workflow_from = audit_detail.create_user_display
group_name = audit_detail.group_name
- dingding_webhook = ResourceGroup.objects.get(group_id=audit_detail.group_id).ding_webhook
- feishu_webhook = ResourceGroup.objects.get(group_id=audit_detail.group_id).feishu_webhook
- qywx_webhook = ResourceGroup.objects.get(group_id=audit_detail.group_id).qywx_webhook
+ dingding_webhook = ResourceGroup.objects.get(
+ group_id=audit_detail.group_id
+ ).ding_webhook
+ feishu_webhook = ResourceGroup.objects.get(
+ group_id=audit_detail.group_id
+ ).feishu_webhook
+ qywx_webhook = ResourceGroup.objects.get(
+ group_id=audit_detail.group_id
+ ).qywx_webhook
# 获取当前审批和审批流程
- workflow_auditors, current_workflow_auditors = Audit.review_info(audit_detail.workflow_id,
- audit_detail.workflow_type)
+ workflow_auditors, current_workflow_auditors = Audit.review_info(
+ audit_detail.workflow_id, audit_detail.workflow_type
+ )
# 准备消息内容
- if workflow_type == WorkflowDict.workflow_type['query']:
- workflow_type_display = WorkflowDict.workflow_type['query_display']
+ if workflow_type == WorkflowDict.workflow_type["query"]:
+ workflow_type_display = WorkflowDict.workflow_type["query_display"]
workflow_detail = QueryPrivilegesApply.objects.get(apply_id=workflow_id)
instance = workflow_detail.instance.instance_name
- db_name = ' '
+ db_name = " "
if workflow_detail.priv_type == 1:
- workflow_content = '''数据库清单:{}\n授权截止时间:{}\n结果集:{}\n'''.format(
+ workflow_content = """数据库清单:{}\n授权截止时间:{}\n结果集:{}\n""".format(
workflow_detail.db_list,
- datetime.datetime.strftime(workflow_detail.valid_date, '%Y-%m-%d %H:%M:%S'),
- workflow_detail.limit_num)
+ datetime.datetime.strftime(
+ workflow_detail.valid_date, "%Y-%m-%d %H:%M:%S"
+ ),
+ workflow_detail.limit_num,
+ )
elif workflow_detail.priv_type == 2:
db_name = workflow_detail.db_list
- workflow_content = '''数据库:{}\n表清单:{}\n授权截止时间:{}\n结果集:{}\n'''.format(
+ workflow_content = """数据库:{}\n表清单:{}\n授权截止时间:{}\n结果集:{}\n""".format(
workflow_detail.db_list,
workflow_detail.table_list,
- datetime.datetime.strftime(workflow_detail.valid_date, '%Y-%m-%d %H:%M:%S'),
- workflow_detail.limit_num)
+ datetime.datetime.strftime(
+ workflow_detail.valid_date, "%Y-%m-%d %H:%M:%S"
+ ),
+ workflow_detail.limit_num,
+ )
else:
- workflow_content = ''
- elif workflow_type == WorkflowDict.workflow_type['sqlreview']:
- workflow_type_display = WorkflowDict.workflow_type['sqlreview_display']
+ workflow_content = ""
+ elif workflow_type == WorkflowDict.workflow_type["sqlreview"]:
+ workflow_type_display = WorkflowDict.workflow_type["sqlreview_display"]
workflow_detail = SqlWorkflow.objects.get(pk=workflow_id)
instance = workflow_detail.instance.instance_name
db_name = workflow_detail.db_name
- workflow_content = re.sub('[\r\n\f]{2,}', '\n',
- workflow_detail.sqlworkflowcontent.sql_content[0:500].replace('\r', ''))
- elif workflow_type == WorkflowDict.workflow_type['archive']:
- workflow_type_display = WorkflowDict.workflow_type['archive_display']
+ workflow_content = re.sub(
+ "[\r\n\f]{2,}",
+ "\n",
+ workflow_detail.sqlworkflowcontent.sql_content[0:500].replace("\r", ""),
+ )
+ elif workflow_type == WorkflowDict.workflow_type["archive"]:
+ workflow_type_display = WorkflowDict.workflow_type["archive_display"]
workflow_detail = ArchiveConfig.objects.get(pk=workflow_id)
instance = workflow_detail.src_instance.instance_name
db_name = workflow_detail.src_db_name
- workflow_content = '''归档表:{}\n归档模式:{}\n归档条件:{}\n'''.format(
+ workflow_content = """归档表:{}\n归档模式:{}\n归档条件:{}\n""".format(
workflow_detail.src_table_name,
workflow_detail.mode,
- workflow_detail.condition)
+ workflow_detail.condition,
+ )
else:
- raise Exception('工单类型不正确')
+ raise Exception("工单类型不正确")
# 准备消息格式
- if status == WorkflowDict.workflow_status['audit_wait']: # 申请阶段
+ if status == WorkflowDict.workflow_status["audit_wait"]: # 申请阶段
msg_title = "[{}]新的工单申请#{}".format(workflow_type_display, audit_id)
# 接收人,发送给该资源组内对应权限组所有的用户
auth_group_names = Group.objects.get(id=audit_detail.current_audit).name
msg_to = auth_group_users([auth_group_names], audit_detail.group_id)
- msg_cc = Users.objects.filter(username__in=kwargs.get('cc_users', []))
+ msg_cc = Users.objects.filter(username__in=kwargs.get("cc_users", []))
# 消息内容
- msg_content = '''发起时间:{}\n发起人:{}\n组:{}\n目标实例:{}\n数据库:{}\n审批流程:{}\n当前审批:{}\n工单名称:{}\n工单地址:{}\n工单详情预览:{}\n'''.format(
- workflow_detail.create_time.strftime('%Y-%m-%d %H:%M:%S'),
+ msg_content = """发起时间:{}\n发起人:{}\n组:{}\n目标实例:{}\n数据库:{}\n审批流程:{}\n当前审批:{}\n工单名称:{}\n工单地址:{}\n工单详情预览:{}\n""".format(
+ workflow_detail.create_time.strftime("%Y-%m-%d %H:%M:%S"),
workflow_from,
group_name,
instance,
@@ -159,15 +204,16 @@ def notify_for_audit(audit_id, **kwargs):
current_workflow_auditors,
workflow_title,
workflow_url,
- workflow_content)
- elif status == WorkflowDict.workflow_status['audit_success']: # 审核通过
+ workflow_content,
+ )
+ elif status == WorkflowDict.workflow_status["audit_success"]: # 审核通过
msg_title = "[{}]工单审核通过#{}".format(workflow_type_display, audit_id)
# 接收人,仅发送给申请人
msg_to = [Users.objects.get(username=audit_detail.create_user)]
- msg_cc = Users.objects.filter(username__in=kwargs.get('cc_users', []))
+ msg_cc = Users.objects.filter(username__in=kwargs.get("cc_users", []))
# 消息内容
- msg_content = '''发起时间:{}\n发起人:{}\n组:{}\n目标实例:{}\n数据库:{}\n审批流程:{}\n工单名称:{}\n工单地址:{}\n工单详情预览:{}\n'''.format(
- workflow_detail.create_time.strftime('%Y-%m-%d %H:%M:%S'),
+ msg_content = """发起时间:{}\n发起人:{}\n组:{}\n目标实例:{}\n数据库:{}\n审批流程:{}\n工单名称:{}\n工单地址:{}\n工单详情预览:{}\n""".format(
+ workflow_detail.create_time.strftime("%Y-%m-%d %H:%M:%S"),
workflow_from,
group_name,
instance,
@@ -175,43 +221,55 @@ def notify_for_audit(audit_id, **kwargs):
workflow_auditors,
workflow_title,
workflow_url,
- workflow_content)
- elif status == WorkflowDict.workflow_status['audit_reject']: # 审核驳回
+ workflow_content,
+ )
+ elif status == WorkflowDict.workflow_status["audit_reject"]: # 审核驳回
msg_title = "[{}]工单被驳回#{}".format(workflow_type_display, audit_id)
# 接收人,仅发送给申请人
msg_to = [Users.objects.get(username=audit_detail.create_user)]
- msg_cc = Users.objects.filter(username__in=kwargs.get('cc_users', []))
+ msg_cc = Users.objects.filter(username__in=kwargs.get("cc_users", []))
# 消息内容
- msg_content = '''发起时间:{}\n目标实例:{}\n数据库:{}\n工单名称:{}\n工单地址:{}\n驳回原因:{}\n提醒:此工单被审核不通过,请按照驳回原因进行修改!'''.format(
- workflow_detail.create_time.strftime('%Y-%m-%d %H:%M:%S'),
+ msg_content = """发起时间:{}\n目标实例:{}\n数据库:{}\n工单名称:{}\n工单地址:{}\n驳回原因:{}\n提醒:此工单被审核不通过,请按照驳回原因进行修改!""".format(
+ workflow_detail.create_time.strftime("%Y-%m-%d %H:%M:%S"),
instance,
db_name,
workflow_title,
workflow_url,
- re.sub('[\r\n\f]{2,}', '\n', workflow_audit_remark))
- elif status == WorkflowDict.workflow_status['audit_abort']: # 审核取消,通知所有审核人
+ re.sub("[\r\n\f]{2,}", "\n", workflow_audit_remark),
+ )
+ elif status == WorkflowDict.workflow_status["audit_abort"]: # 审核取消,通知所有审核人
msg_title = "[{}]提交人主动终止工单#{}".format(workflow_type_display, audit_id)
# 接收人,发送给该资源组内对应权限组所有的用户
- auth_group_names = [Group.objects.get(id=auth_group_id).name for auth_group_id in
- audit_detail.audit_auth_groups.split(',')]
+ auth_group_names = [
+ Group.objects.get(id=auth_group_id).name
+ for auth_group_id in audit_detail.audit_auth_groups.split(",")
+ ]
msg_to = auth_group_users(auth_group_names, audit_detail.group_id)
- msg_cc = Users.objects.filter(username__in=kwargs.get('cc_users', []))
+ msg_cc = Users.objects.filter(username__in=kwargs.get("cc_users", []))
# 消息内容
- msg_content = '''发起时间:{}\n发起人:{}\n组:{}\n目标实例:{}\n数据库:{}\n工单名称:{}\n工单地址:{}\n终止原因:{}'''.format(
- workflow_detail.create_time.strftime('%Y-%m-%d %H:%M:%S'),
+ msg_content = """发起时间:{}\n发起人:{}\n组:{}\n目标实例:{}\n数据库:{}\n工单名称:{}\n工单地址:{}\n终止原因:{}""".format(
+ workflow_detail.create_time.strftime("%Y-%m-%d %H:%M:%S"),
workflow_from,
group_name,
instance,
db_name,
workflow_title,
workflow_url,
- re.sub('[\r\n\f]{2,}', '\n', workflow_audit_remark))
+ re.sub("[\r\n\f]{2,}", "\n", workflow_audit_remark),
+ )
else:
- raise Exception('工单状态不正确')
+ raise Exception("工单状态不正确")
logger.info(f"通知Debug{msg_to}{msg_cc}")
# 发送通知
- __send(msg_title, msg_content, msg_to, msg_cc, feishu_webhook=feishu_webhook, dingding_webhook=dingding_webhook,
- qywx_webhook=qywx_webhook)
+ __send(
+ msg_title,
+ msg_content,
+ msg_to,
+ msg_cc,
+ feishu_webhook=feishu_webhook,
+ dingding_webhook=dingding_webhook,
+ qywx_webhook=qywx_webhook,
+ )
def notify_for_execute(workflow):
@@ -226,14 +284,17 @@ def notify_for_execute(workflow):
sys_config = SysConfig()
# 获取当前审批和审批流程
- base_url = sys_config.get('archery_base_url', 'http://127.0.0.1:8000').rstrip('/')
+ base_url = sys_config.get("archery_base_url", "http://127.0.0.1:8000").rstrip("/")
audit_auth_group, current_audit_auth_group = Audit.review_info(workflow.id, 2)
audit_id = Audit.detail_by_workflow_id(workflow.id, 2).audit_id
url = "{base_url}/workflow/{audit_id}".format(base_url=base_url, audit_id=audit_id)
- msg_title = "[{}]工单{}#{}".format(WorkflowDict.workflow_type['sqlreview_display'],
- workflow.get_status_display(), audit_id)
- msg_content = '''发起时间:{}\n发起人:{}\n组:{}\n目标实例:{}\n数据库:{}\n审批流程:{}\n工单名称:{}\n工单地址:{}\n工单详情预览:{}\n'''.format(
- workflow.create_time.strftime('%Y-%m-%d %H:%M:%S'),
+ msg_title = "[{}]工单{}#{}".format(
+ WorkflowDict.workflow_type["sqlreview_display"],
+ workflow.get_status_display(),
+ audit_id,
+ )
+ msg_content = """发起时间:{}\n发起人:{}\n组:{}\n目标实例:{}\n数据库:{}\n审批流程:{}\n工单名称:{}\n工单地址:{}\n工单详情预览:{}\n""".format(
+ workflow.create_time.strftime("%Y-%m-%d %H:%M:%S"),
workflow.engineer_display,
workflow.group_name,
workflow.instance.instance_name,
@@ -241,35 +302,54 @@ def notify_for_execute(workflow):
audit_auth_group,
workflow.workflow_name,
url,
- re.sub('[\r\n\f]{2,}', '\n', workflow.sqlworkflowcontent.sql_content[0:500].replace('\r', '')))
+ re.sub(
+ "[\r\n\f]{2,}",
+ "\n",
+ workflow.sqlworkflowcontent.sql_content[0:500].replace("\r", ""),
+ ),
+ )
# 邮件通知申请人,抄送DBA
msg_to = Users.objects.filter(username=workflow.engineer)
- msg_cc = auth_group_users(auth_group_names=['DBA'], group_id=workflow.group_id)
+ msg_cc = auth_group_users(auth_group_names=["DBA"], group_id=workflow.group_id)
# 处理接收人
- dingding_webhook = ResourceGroup.objects.get(group_id=workflow.group_id).ding_webhook
- feishu_webhook = ResourceGroup.objects.get(group_id=workflow.group_id).feishu_webhook
+ dingding_webhook = ResourceGroup.objects.get(
+ group_id=workflow.group_id
+ ).ding_webhook
+ feishu_webhook = ResourceGroup.objects.get(
+ group_id=workflow.group_id
+ ).feishu_webhook
qywx_webhook = ResourceGroup.objects.get(group_id=workflow.group_id).qywx_webhook
# 发送通知
- __send(msg_title, msg_content, msg_to, msg_cc, dingding_webhook=dingding_webhook, feishu_webhook=feishu_webhook,
- qywx_webhook=qywx_webhook)
+ __send(
+ msg_title,
+ msg_content,
+ msg_to,
+ msg_cc,
+ dingding_webhook=dingding_webhook,
+ feishu_webhook=feishu_webhook,
+ qywx_webhook=qywx_webhook,
+ )
# DDL通知
- if sys_config.get('ddl_notify_auth_group') and workflow.status == 'workflow_finish':
+ if sys_config.get("ddl_notify_auth_group") and workflow.status == "workflow_finish":
# 判断上线语句是否存在DDL,存在则通知相关人员
if workflow.syntax_type == 1:
# 消息内容通知
- msg_title = '[Archery]有新的DDL语句执行完成#{}'.format(audit_id)
- msg_content = '''发起人:{}\n变更组:{}\n变更实例:{}\n变更数据库:{}\n工单名称:{}\n工单地址:{}\n工单预览:{}\n'''.format(
+ msg_title = "[Archery]有新的DDL语句执行完成#{}".format(audit_id)
+ msg_content = """发起人:{}\n变更组:{}\n变更实例:{}\n变更数据库:{}\n工单名称:{}\n工单地址:{}\n工单预览:{}\n""".format(
Users.objects.get(username=workflow.engineer).display,
workflow.group_name,
workflow.instance.instance_name,
workflow.db_name,
workflow.workflow_name,
url,
- workflow.sqlworkflowcontent.sql_content[0:500])
+ workflow.sqlworkflowcontent.sql_content[0:500],
+ )
# 获取通知成员ddl_notify_auth_group
- ddl_notify_auth_group = sys_config.get('ddl_notify_auth_group', '').split(',')
+ ddl_notify_auth_group = sys_config.get("ddl_notify_auth_group", "").split(
+ ","
+ )
msg_to = Users.objects.filter(groups__name__in=ddl_notify_auth_group)
# 发送通知
__send(msg_title, msg_content, msg_to, msg_cc)
@@ -285,11 +365,11 @@ def notify_for_my2sql(task):
if not __notify_cnf_status():
return None
if task.success:
- msg_title = '[Archery 通知]My2SQL执行结束'
- msg_content = f'解析的SQL文件在{task.result[1]}目录下,请前往查看'
+ msg_title = "[Archery 通知]My2SQL执行结束"
+ msg_content = f"解析的SQL文件在{task.result[1]}目录下,请前往查看"
else:
- msg_title = '[Archery 通知]My2SQL执行失败'
- msg_content = f'{task.result}'
+ msg_title = "[Archery 通知]My2SQL执行失败"
+ msg_content = f"{task.result}"
# 发送
- msg_to = [task.kwargs['user']]
+ msg_to = [task.kwargs["user"]]
__send(msg_title, msg_content, msg_to)
diff --git a/sql/plugins/my2sql.py b/sql/plugins/my2sql.py
index 7b3457ef23..5546cc5f99 100644
--- a/sql/plugins/my2sql.py
+++ b/sql/plugins/my2sql.py
@@ -5,9 +5,8 @@
class My2SQL(Plugin):
-
def __init__(self):
- self.path = SysConfig().get('my2sql')
+ self.path = SysConfig().get("my2sql")
self.required_args = []
self.disable_args = []
super(Plugin, self).__init__()
@@ -19,34 +18,50 @@ def generate_args2cmd(self, args, shell):
:param shell:
:return:
"""
- conn_options = ['conn_options']
- args_options = ['work-type', 'threads', 'start-file', 'stop-file', 'start-pos',
- 'stop-pos', 'databases', 'tables', 'sql', 'output-dir']
- no_args_options = ['output-toScreen', 'add-extraInfo', 'ignore-primaryKey-forInsert',
- 'full-columns', 'do-not-add-prifixDb', 'file-per-table']
- datetime_options = ['start-datetime', 'stop-datetime']
+ conn_options = ["conn_options"]
+ args_options = [
+ "work-type",
+ "threads",
+ "start-file",
+ "stop-file",
+ "start-pos",
+ "stop-pos",
+ "databases",
+ "tables",
+ "sql",
+ "output-dir",
+ ]
+ no_args_options = [
+ "output-toScreen",
+ "add-extraInfo",
+ "ignore-primaryKey-forInsert",
+ "full-columns",
+ "do-not-add-prifixDb",
+ "file-per-table",
+ ]
+ datetime_options = ["start-datetime", "stop-datetime"]
if shell:
- cmd_args = f'{shlex.quote(str(self.path))}' if self.path else ''
+ cmd_args = f"{shlex.quote(str(self.path))}" if self.path else ""
for name, value in args.items():
if name in conn_options:
- cmd_args += f' {value}'
+ cmd_args += f" {value}"
elif name in args_options and value:
- cmd_args += f' -{name} {shlex.quote(str(value))}'
+ cmd_args += f" -{name} {shlex.quote(str(value))}"
elif name in datetime_options and value:
cmd_args += f" -{name} '{shlex.quote(str(value))}'"
elif name in no_args_options and value:
- cmd_args += f' -{name}'
+ cmd_args += f" -{name}"
else:
cmd_args = [self.path]
for name, value in args.items():
if name in conn_options:
- cmd_args.append(f'{value}')
+ cmd_args.append(f"{value}")
elif name in args_options:
- cmd_args.append(f'-{name}')
- cmd_args.append(f'{value}')
+ cmd_args.append(f"-{name}")
+ cmd_args.append(f"{value}")
elif name in datetime_options:
- cmd_args.append(f'-{name}')
+ cmd_args.append(f"-{name}")
cmd_args.append(f"'{value}'")
elif name in no_args_options:
- cmd_args.append(f'-{name}')
+ cmd_args.append(f"-{name}")
return cmd_args
diff --git a/sql/plugins/plugin.py b/sql/plugins/plugin.py
index 3cea862226..102b4c46d8 100644
--- a/sql/plugins/plugin.py
+++ b/sql/plugins/plugin.py
@@ -5,13 +5,13 @@
@file: plugin.py
@time: 2019/03/04
"""
-__author__ = 'hhyo'
+__author__ = "hhyo"
import logging
import subprocess
import traceback
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
class Plugin:
@@ -25,20 +25,28 @@ def check_args(self, args):
检查请求参数列表
:return: {'status': 0, 'msg': 'ok', 'data': {}}
"""
- args_check_result = {'status': 0, 'msg': 'ok', 'data': {}}
+ args_check_result = {"status": 0, "msg": "ok", "data": {}}
# 检查路径
if self.path is None:
- return {'status': 1, 'msg': '可执行文件路径不能为空!', 'data': {}}
+ return {"status": 1, "msg": "可执行文件路径不能为空!", "data": {}}
# 检查禁用参数
for arg in args.keys():
if arg in self.disable_args:
- return {'status': 1, 'msg': '{arg}参数已被禁用'.format(arg=arg), 'data': {}}
+ return {"status": 1, "msg": "{arg}参数已被禁用".format(arg=arg), "data": {}}
# 检查必须参数
for req_arg in self.required_args:
if req_arg not in args.keys():
- return {'status': 1, 'msg': '必须指定{arg}参数'.format(arg=req_arg), 'data': {}}
- elif args[req_arg] is None or args[req_arg] == '':
- return {'status': 1, 'msg': '{arg}参数值不能为空'.format(arg=req_arg), 'data': {}}
+ return {
+ "status": 1,
+ "msg": "必须指定{arg}参数".format(arg=req_arg),
+ "data": {},
+ }
+ elif args[req_arg] is None or args[req_arg] == "":
+ return {
+ "status": 1,
+ "msg": "{arg}参数值不能为空".format(arg=req_arg),
+ "data": {},
+ }
return args_check_result
def generate_args2cmd(self, args, shell):
@@ -54,12 +62,14 @@ def execute_cmd(cmd_args, shell):
:return:
"""
try:
- p = subprocess.Popen(cmd_args,
- shell=shell,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- universal_newlines=True)
+ p = subprocess.Popen(
+ cmd_args,
+ shell=shell,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ universal_newlines=True,
+ )
return p
except Exception as e:
logger.error("命令执行失败\n{}".format(traceback.format_exc()))
- raise RuntimeError('命令执行失败,失败原因:%s' % str(e))
+ raise RuntimeError("命令执行失败,失败原因:%s" % str(e))
diff --git a/sql/plugins/pt_archiver.py b/sql/plugins/pt_archiver.py
index c5785c69b5..d84ed2254f 100644
--- a/sql/plugins/pt_archiver.py
+++ b/sql/plugins/pt_archiver.py
@@ -8,7 +8,7 @@
from common.config import SysConfig
from sql.plugins.plugin import Plugin
-__author__ = 'hhyo'
+__author__ = "hhyo"
class PtArchiver(Plugin):
@@ -17,9 +17,9 @@ class PtArchiver(Plugin):
"""
def __init__(self):
- self.path = 'pt-archiver'
+ self.path = "pt-archiver"
self.required_args = []
- self.disable_args = ['analyze']
+ self.disable_args = ["analyze"]
super(Plugin, self).__init__()
def generate_args2cmd(self, args, shell):
@@ -29,24 +29,41 @@ def generate_args2cmd(self, args, shell):
:param shell:
:return:
"""
- k_options = ['no-version-check', 'statistics', 'bulk-insert', 'bulk-delete', 'purge', 'no-delete']
- kv_options = ['source', 'dest', 'file', 'where', 'progress', 'charset', 'limit', 'txn-size', 'sleep']
+ k_options = [
+ "no-version-check",
+ "statistics",
+ "bulk-insert",
+ "bulk-delete",
+ "purge",
+ "no-delete",
+ ]
+ kv_options = [
+ "source",
+ "dest",
+ "file",
+ "where",
+ "progress",
+ "charset",
+ "limit",
+ "txn-size",
+ "sleep",
+ ]
if shell:
- cmd_args = self.path if self.path else ''
+ cmd_args = self.path if self.path else ""
for name, value in args.items():
if name in k_options and value:
- cmd_args += f' --{name}'
+ cmd_args += f" --{name}"
elif name in kv_options:
- if name == 'where':
+ if name == "where":
cmd_args += f' --{name} "{value}"'
else:
- cmd_args += f' --{name} {value}'
+ cmd_args += f" --{name} {value}"
else:
cmd_args = [self.path]
for name, value in args.items():
if name in k_options and value:
- cmd_args.append(f'--{name}')
+ cmd_args.append(f"--{name}")
elif name in kv_options:
- cmd_args.append(f'--{name}')
- cmd_args.append(f'{value}')
+ cmd_args.append(f"--{name}")
+ cmd_args.append(f"{value}")
return cmd_args
diff --git a/sql/plugins/schemasync.py b/sql/plugins/schemasync.py
index c4dbf50a55..826299b9de 100644
--- a/sql/plugins/schemasync.py
+++ b/sql/plugins/schemasync.py
@@ -5,16 +5,15 @@
@file: schemasync.py
@time: 2019/03/05
"""
-__author__ = 'hhyo'
+__author__ = "hhyo"
import shlex
from sql.plugins.plugin import Plugin
class SchemaSync(Plugin):
-
def __init__(self):
- self.path = 'schemasync'
+ self.path = "schemasync"
self.required_args = []
self.disable_args = []
super(Plugin, self).__init__()
@@ -26,26 +25,26 @@ def generate_args2cmd(self, args, shell):
:param shell:
:return:
"""
- k_options = ['sync-auto-inc', 'sync-comments']
- kv_options = ['tag', 'output-directory', 'log-directory']
- v_options = ['source', 'target']
+ k_options = ["sync-auto-inc", "sync-comments"]
+ kv_options = ["tag", "output-directory", "log-directory"]
+ v_options = ["source", "target"]
if shell:
- cmd_args = self.path if self.path else ''
+ cmd_args = self.path if self.path else ""
for name, value in args.items():
if name in k_options and value:
- cmd_args += f' --{name}'
+ cmd_args += f" --{name}"
elif name in kv_options:
- cmd_args += f' --{name}={shlex.quote(str(value))}'
+ cmd_args += f" --{name}={shlex.quote(str(value))}"
elif name in v_options:
- cmd_args += f' {value}'
+ cmd_args += f" {value}"
else:
cmd_args = [self.path]
for name, value in args.items():
if name in k_options and value:
- cmd_args.append(f'--{name}')
+ cmd_args.append(f"--{name}")
elif name in kv_options:
- cmd_args.append(f'--{name}')
- cmd_args.append(f'{value}')
- elif name in ['source', 'target']:
- cmd_args.append(f'{value}')
+ cmd_args.append(f"--{name}")
+ cmd_args.append(f"{value}")
+ elif name in ["source", "target"]:
+ cmd_args.append(f"{value}")
return cmd_args
diff --git a/sql/plugins/soar.py b/sql/plugins/soar.py
index 6220c59b33..18844a8188 100644
--- a/sql/plugins/soar.py
+++ b/sql/plugins/soar.py
@@ -5,7 +5,7 @@
@file: soar.py
@time: 2019/03/04
"""
-__author__ = 'hhyo'
+__author__ = "hhyo"
import shlex
from common.config import SysConfig
@@ -13,10 +13,9 @@
class Soar(Plugin):
-
def __init__(self):
- self.path = SysConfig().get('soar')
- self.required_args = ['query']
+ self.path = SysConfig().get("soar")
+ self.required_args = ["query"]
self.disable_args = []
super(Plugin, self).__init__()
@@ -28,14 +27,14 @@ def generate_args2cmd(self, args, shell):
:return:
"""
if shell:
- cmd_args = shlex.quote(str(self.path)) if self.path else ''
+ cmd_args = shlex.quote(str(self.path)) if self.path else ""
for name, value in args.items():
cmd_args += f" -{name}={shlex.quote(str(value))}"
else:
cmd_args = [self.path]
for name, value in args.items():
- cmd_args.append(f'-{name}')
- cmd_args.append(f'{value}')
+ cmd_args.append(f"-{name}")
+ cmd_args.append(f"{value}")
return cmd_args
def fingerprint(self, sql):
@@ -44,10 +43,7 @@ def fingerprint(self, sql):
:param sql:
:return:
"""
- args = {
- "query": sql,
- "report-type": "fingerprint"
- }
+ args = {"query": sql, "report-type": "fingerprint"}
cmd_args = self.generate_args2cmd(args, shell=True)
return self.execute_cmd(cmd_args=cmd_args, shell=True)
@@ -57,10 +53,7 @@ def compress(self, sql):
:param sql:
:return:
"""
- args = {
- "query": sql,
- "report-type": "compress"
- }
+ args = {"query": sql, "report-type": "compress"}
cmd_args = self.generate_args2cmd(args, shell=True)
return self.execute_cmd(cmd_args=cmd_args, shell=True)
@@ -73,7 +66,7 @@ def pretty(self, sql):
args = {
"query": sql,
"max-pretty-sql-length": 100000, # 超出该长度的SQL会转换成指纹输出 (default 1024)
- "report-type": "pretty"
+ "report-type": "pretty",
}
cmd_args = self.generate_args2cmd(args, shell=True)
return self.execute_cmd(cmd_args=cmd_args, shell=True)
@@ -84,10 +77,7 @@ def remove_comment(self, sql):
:param sql:
:return:
"""
- args = {
- "query": sql,
- "report-type": "remove-comment"
- }
+ args = {"query": sql, "report-type": "remove-comment"}
cmd_args = self.generate_args2cmd(args, shell=True)
return self.execute_cmd(cmd_args=cmd_args, shell=True)
@@ -98,17 +88,34 @@ def rewrite(self, sql, rewrite_rules=None):
:param rewrite_rules:
:return:
"""
- rewrite_type_list = ['dml2select', 'star2columns', 'insertcolumns', 'having', 'orderbynull', 'unionall',
- 'or2in', 'dmlorderby', 'distinctstar', 'standard', 'mergealter', 'alwaystrue',
- 'countstar', 'innodb', 'autoincrement', 'intwidth', 'truncate', 'rmparenthesis',
- 'delimiter']
- rewrite_rules = rewrite_rules if rewrite_rules else ['dml2select']
+ rewrite_type_list = [
+ "dml2select",
+ "star2columns",
+ "insertcolumns",
+ "having",
+ "orderbynull",
+ "unionall",
+ "or2in",
+ "dmlorderby",
+ "distinctstar",
+ "standard",
+ "mergealter",
+ "alwaystrue",
+ "countstar",
+ "innodb",
+ "autoincrement",
+ "intwidth",
+ "truncate",
+ "rmparenthesis",
+ "delimiter",
+ ]
+ rewrite_rules = rewrite_rules if rewrite_rules else ["dml2select"]
if set(rewrite_rules).issubset(set(rewrite_type_list)) is False:
- raise RuntimeError(f'不支持的改写规则,仅支持{rewrite_type_list}')
+ raise RuntimeError(f"不支持的改写规则,仅支持{rewrite_type_list}")
args = {
"query": sql,
"report-type": "rewrite",
- "rewrite-rules": ','.join(rewrite_type_list)
+ "rewrite-rules": ",".join(rewrite_type_list),
}
cmd_args = self.generate_args2cmd(args, shell=True)
return self.execute_cmd(cmd_args=cmd_args, shell=True)
@@ -119,9 +126,6 @@ def query_tree(self, sql):
:param sql:
:return:
"""
- args = {
- "query": sql,
- "report-type": "ast-json"
- }
+ args = {"query": sql, "report-type": "ast-json"}
cmd_args = self.generate_args2cmd(args, shell=True)
return self.execute_cmd(cmd_args=cmd_args, shell=True)
diff --git a/sql/plugins/sqladvisor.py b/sql/plugins/sqladvisor.py
index 744a11f2b4..6130661eb7 100644
--- a/sql/plugins/sqladvisor.py
+++ b/sql/plugins/sqladvisor.py
@@ -5,7 +5,7 @@
@file: sqladvisor.py
@time: 2019/03/04
"""
-__author__ = 'hhyo'
+__author__ = "hhyo"
import shlex
from common.config import SysConfig
@@ -14,8 +14,8 @@
class SQLAdvisor(Plugin):
def __init__(self):
- self.path = SysConfig().get('sqladvisor')
- self.required_args = ['q']
+ self.path = SysConfig().get("sqladvisor")
+ self.required_args = ["q"]
self.disable_args = []
super(Plugin, self).__init__()
@@ -27,12 +27,12 @@ def generate_args2cmd(self, args, shell):
:return:
"""
if shell:
- cmd_args = shlex.quote(str(self.path)) if self.path else ''
+ cmd_args = shlex.quote(str(self.path)) if self.path else ""
for name, value in args.items():
cmd_args += f" -{name} {shlex.quote(str(value))}"
else:
cmd_args = [self.path]
for name, value in args.items():
- cmd_args.append(f'-{name}')
- cmd_args.append(f'{value}')
+ cmd_args.append(f"-{name}")
+ cmd_args.append(f"{value}")
return cmd_args
diff --git a/sql/plugins/tests.py b/sql/plugins/tests.py
index 59d18db220..241a92403d 100644
--- a/sql/plugins/tests.py
+++ b/sql/plugins/tests.py
@@ -20,7 +20,7 @@
User = get_user_model()
-__author__ = 'hhyo'
+__author__ = "hhyo"
class TestPlugin(TestCase):
@@ -30,7 +30,7 @@ class TestPlugin(TestCase):
@classmethod
def setUpClass(cls):
- cls.superuser = User(username='super', is_superuser=True)
+ cls.superuser = User(username="super", is_superuser=True)
cls.superuser.save()
cls.sys_config = SysConfig()
cls.client = Client()
@@ -46,74 +46,87 @@ def test_check_args_path(self):
测试路径
:return:
"""
- args = {"online-dsn": '',
- "test-dsn": '',
- "allow-online-as-test": "false",
- "report-type": "markdown",
- "query": "select 1;"
- }
- self.sys_config.set('soar', '')
+ args = {
+ "online-dsn": "",
+ "test-dsn": "",
+ "allow-online-as-test": "false",
+ "report-type": "markdown",
+ "query": "select 1;",
+ }
+ self.sys_config.set("soar", "")
self.sys_config.get_all_config()
soar = Soar()
args_check_result = soar.check_args(args)
- self.assertDictEqual(args_check_result, {'status': 1, 'msg': '可执行文件路径不能为空!', 'data': {}})
+ self.assertDictEqual(
+ args_check_result, {"status": 1, "msg": "可执行文件路径不能为空!", "data": {}}
+ )
# 路径不为空
- self.sys_config.set('soar', '/opt/archery/src/plugins/soar')
+ self.sys_config.set("soar", "/opt/archery/src/plugins/soar")
self.sys_config.get_all_config()
soar = Soar()
args_check_result = soar.check_args(args)
- self.assertDictEqual(args_check_result, {'status': 0, 'msg': 'ok', 'data': {}})
+ self.assertDictEqual(args_check_result, {"status": 0, "msg": "ok", "data": {}})
def test_check_args_disable(self):
"""
测试禁用参数
:return:
"""
- args = {"online-dsn": '',
- "test-dsn": '',
- "allow-online-as-test": "false",
- "report-type": "markdown",
- "query": "select 1;"
- }
- self.sys_config.set('soar', '/opt/archery/src/plugins/soar')
+ args = {
+ "online-dsn": "",
+ "test-dsn": "",
+ "allow-online-as-test": "false",
+ "report-type": "markdown",
+ "query": "select 1;",
+ }
+ self.sys_config.set("soar", "/opt/archery/src/plugins/soar")
self.sys_config.get_all_config()
soar = Soar()
- soar.disable_args = ['allow-online-as-test']
+ soar.disable_args = ["allow-online-as-test"]
args_check_result = soar.check_args(args)
- self.assertDictEqual(args_check_result, {'status': 1, 'msg': 'allow-online-as-test参数已被禁用', 'data': {}})
+ self.assertDictEqual(
+ args_check_result,
+ {"status": 1, "msg": "allow-online-as-test参数已被禁用", "data": {}},
+ )
def test_check_args_required(self):
"""
测试必选参数
:return:
"""
- args = {"online-dsn": '',
- "test-dsn": '',
- "allow-online-as-test": "false",
- "report-type": "markdown",
- }
- self.sys_config.set('soar', '/opt/archery/src/plugins/soar')
+ args = {
+ "online-dsn": "",
+ "test-dsn": "",
+ "allow-online-as-test": "false",
+ "report-type": "markdown",
+ }
+ self.sys_config.set("soar", "/opt/archery/src/plugins/soar")
self.sys_config.get_all_config()
soar = Soar()
- soar.required_args = ['query']
+ soar.required_args = ["query"]
args_check_result = soar.check_args(args)
- self.assertDictEqual(args_check_result, {'status': 1, 'msg': '必须指定query参数', 'data': {}})
- args['query'] = ""
+ self.assertDictEqual(
+ args_check_result, {"status": 1, "msg": "必须指定query参数", "data": {}}
+ )
+ args["query"] = ""
args_check_result = soar.check_args(args)
- self.assertDictEqual(args_check_result, {'status': 1, 'msg': 'query参数值不能为空', 'data': {}})
+ self.assertDictEqual(
+ args_check_result, {"status": 1, "msg": "query参数值不能为空", "data": {}}
+ )
def test_soar_generate_args2cmd(self):
"""
测试SOAR参数转换
:return:
"""
- args = {"online-dsn": '',
- "test-dsn": '',
- "allow-online-as-test": "false",
- "report-type": "markdown",
- "query": "select 1;"
- }
- self.sys_config.set('soar', '/opt/archery/src/plugins/soar')
+ args = {
+ "online-dsn": "",
+ "test-dsn": "",
+ "allow-online-as-test": "false",
+ "report-type": "markdown",
+ "query": "select 1;",
+ }
+ self.sys_config.set("soar", "/opt/archery/src/plugins/soar")
self.sys_config.get_all_config()
soar = Soar()
cmd_args = soar.generate_args2cmd(args, False)
@@ -126,15 +139,16 @@ def test_sql_advisor_generate_args2cmd(self):
测试sql_advisor参数转换
:return:
"""
- args = {"h": 'mysql',
- "P": 3306,
- "u": 'root',
- "p": '',
- "d": 'archery',
- "v": 1,
- "q": 'select 1;'
- }
- self.sys_config.set('sqladvisor', '/opt/archery/src/plugins/SQLAdvisor')
+ args = {
+ "h": "mysql",
+ "P": 3306,
+ "u": "root",
+ "p": "",
+ "d": "archery",
+ "v": 1,
+ "q": "select 1;",
+ }
+ self.sys_config.set("sqladvisor", "/opt/archery/src/plugins/SQLAdvisor")
self.sys_config.get_all_config()
sql_advisor = SQLAdvisor()
cmd_args = sql_advisor.generate_args2cmd(args, False)
@@ -150,20 +164,16 @@ def test_schema_sync_generate_args2cmd(self):
args = {
"sync-auto-inc": True,
"sync-comments": True,
- "tag": 'tag_v',
- "output-directory": '',
- "source": r"mysql://{user}:{pwd}@{host}:{port}/{database}".format(user='root',
- pwd='123456',
- host='127.0.0.1',
- port=3306,
- database='*'),
- "target": r"mysql://{user}:{pwd}@{host}:{port}/{database}".format(user='root',
- pwd='123456',
- host='127.0.0.1',
- port=3306,
- database='*')
+ "tag": "tag_v",
+ "output-directory": "",
+ "source": r"mysql://{user}:{pwd}@{host}:{port}/{database}".format(
+ user="root", pwd="123456", host="127.0.0.1", port=3306, database="*"
+ ),
+ "target": r"mysql://{user}:{pwd}@{host}:{port}/{database}".format(
+ user="root", pwd="123456", host="127.0.0.1", port=3306, database="*"
+ ),
}
- self.sys_config.set('schemasync', '/opt/venv4schemasync/bin/schemasync')
+ self.sys_config.set("schemasync", "/opt/venv4schemasync/bin/schemasync")
self.sys_config.get_all_config()
schema_sync = SchemaSync()
cmd_args = schema_sync.generate_args2cmd(args, False)
@@ -176,24 +186,26 @@ def test_my2sql_generate_args2cmd(self):
测试my2sql参数转换
:return:
"""
- args = {'conn_options': "-host mysql -user root -password '123456' -port 3306 ",
- 'work-type': '2sql',
- 'start-file': 'mysql-bin.000043',
- 'start-pos': 111,
- 'stop-file': '',
- 'stop-pos': '',
- 'start-datetime': '',
- 'stop-datetime': '',
- 'databases': 'account_center',
- 'tables': 'ac_apps',
- 'sql': 'update',
- "threads": 1,
- "add-extraInfo": "false",
- "ignore-primaryKey-forInsert": "false",
- "full-columns": "false",
- "do-not-add-prifixDb": "false",
- "file-per-table": "false"}
- self.sys_config.set('my2sql', '/opt/archery/src/plugins/my2sql')
+ args = {
+ "conn_options": "-host mysql -user root -password '123456' -port 3306 ",
+ "work-type": "2sql",
+ "start-file": "mysql-bin.000043",
+ "start-pos": 111,
+ "stop-file": "",
+ "stop-pos": "",
+ "start-datetime": "",
+ "stop-datetime": "",
+ "databases": "account_center",
+ "tables": "ac_apps",
+ "sql": "update",
+ "threads": 1,
+ "add-extraInfo": "false",
+ "ignore-primaryKey-forInsert": "false",
+ "full-columns": "false",
+ "do-not-add-prifixDb": "false",
+ "file-per-table": "false",
+ }
+ self.sys_config.set("my2sql", "/opt/archery/src/plugins/my2sql")
self.sys_config.get_all_config()
my2sql = My2SQL()
cmd_args = my2sql.generate_args2cmd(args, False)
@@ -208,14 +220,14 @@ def test_pt_archiver_generate_args2cmd(self):
"""
args = {
"no-version-check": True,
- "source": '',
- "where": '',
+ "source": "",
+ "where": "",
"progress": 5000,
"statistics": True,
- "charset": 'UTF8',
+ "charset": "UTF8",
"limit": 10000,
"txn-size": 1000,
- "sleep": 1
+ "sleep": 1,
}
pt_archiver = PtArchiver()
cmd_args = pt_archiver.generate_args2cmd(args, False)
@@ -223,32 +235,32 @@ def test_pt_archiver_generate_args2cmd(self):
cmd_args = pt_archiver.generate_args2cmd(args, True)
self.assertIsInstance(cmd_args, str)
- @patch('sql.plugins.plugin.subprocess')
+ @patch("sql.plugins.plugin.subprocess")
def test_execute_cmd(self, mock_subprocess):
- args = {"online-dsn": '',
- "test-dsn": '',
- "allow-online-as-test": "false",
- "report-type": "markdown",
- "query": "select 1;"
- }
- self.sys_config.set('soar', '/opt/archery/src/plugins/soar')
+ args = {
+ "online-dsn": "",
+ "test-dsn": "",
+ "allow-online-as-test": "false",
+ "report-type": "markdown",
+ "query": "select 1;",
+ }
+ self.sys_config.set("soar", "/opt/archery/src/plugins/soar")
self.sys_config.get_all_config()
soar = Soar()
cmd_args = soar.generate_args2cmd(args, True)
- mock_subprocess.Popen.return_value.communicate.return_value = ('some_stdout', 'some_stderr')
+ mock_subprocess.Popen.return_value.communicate.return_value = (
+ "some_stdout",
+ "some_stderr",
+ )
stdout, stderr = soar.execute_cmd(cmd_args, True).communicate()
mock_subprocess.Popen.assert_called_once_with(
- cmd_args,
- shell=True,
- stdout=ANY,
- stderr=ANY,
- universal_newlines=ANY
+ cmd_args, shell=True, stdout=ANY, stderr=ANY, universal_newlines=ANY
)
- self.assertIn('some_stdout', stdout)
+ self.assertIn("some_stdout", stdout)
# 异常
- mock_subprocess.Popen.side_effect = Exception('Boom! some exception!')
+ mock_subprocess.Popen.side_effect = Exception("Boom! some exception!")
with self.assertRaises(RuntimeError):
soar.execute_cmd(cmd_args, False)
@@ -260,13 +272,13 @@ class TestSoar(TestCase):
@classmethod
def setUpClass(cls):
- soar_path = '/opt/archery/src/plugins/soar' # 修改为本机的soar路径
- cls.superuser = User(username='super', is_superuser=True)
+ soar_path = "/opt/archery/src/plugins/soar" # 修改为本机的soar路径
+ cls.superuser = User(username="super", is_superuser=True)
cls.superuser.save()
cls.client = Client()
cls.client.force_login(cls.superuser)
cls.sys_config = SysConfig()
- cls.sys_config.set('soar', soar_path)
+ cls.sys_config.set("soar", soar_path)
cls.sys_config.get_all_config()
cls.soar = Soar()
@@ -327,4 +339,4 @@ def test_rewrite(self):
self.soar.rewrite(sql)
# 异常测试
with self.assertRaises(RuntimeError):
- self.soar.rewrite(sql, 'unknown')
+ self.soar.rewrite(sql, "unknown")
diff --git a/sql/query.py b/sql/query.py
index 8cd67af541..83583d8c0e 100644
--- a/sql/query.py
+++ b/sql/query.py
@@ -19,64 +19,66 @@
from .models import QueryLog, Instance
from sql.engines import get_engine
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
-@permission_required('sql.query_submit', raise_exception=True)
+@permission_required("sql.query_submit", raise_exception=True)
def query(request):
"""
获取SQL查询结果
:param request:
:return:
"""
- instance_name = request.POST.get('instance_name')
- sql_content = request.POST.get('sql_content')
- db_name = request.POST.get('db_name')
- tb_name = request.POST.get('tb_name')
- limit_num = int(request.POST.get('limit_num', 0))
- schema_name = request.POST.get('schema_name', None)
+ instance_name = request.POST.get("instance_name")
+ sql_content = request.POST.get("sql_content")
+ db_name = request.POST.get("db_name")
+ tb_name = request.POST.get("tb_name")
+ limit_num = int(request.POST.get("limit_num", 0))
+ schema_name = request.POST.get("schema_name", None)
user = request.user
- result = {'status': 0, 'msg': 'ok', 'data': {}}
+ result = {"status": 0, "msg": "ok", "data": {}}
try:
instance = user_instances(request.user).get(instance_name=instance_name)
except Instance.DoesNotExist:
- result['status'] = 1
- result['msg'] = '你所在组未关联该实例'
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["status"] = 1
+ result["msg"] = "你所在组未关联该实例"
+ return HttpResponse(json.dumps(result), content_type="application/json")
# 服务器端参数验证
if None in [sql_content, db_name, instance_name, limit_num]:
- result['status'] = 1
- result['msg'] = '页面提交参数可能为空'
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["status"] = 1
+ result["msg"] = "页面提交参数可能为空"
+ return HttpResponse(json.dumps(result), content_type="application/json")
try:
config = SysConfig()
# 查询前的检查,禁用语句检查,语句切分
query_engine = get_engine(instance=instance)
query_check_info = query_engine.query_check(db_name=db_name, sql=sql_content)
- if query_check_info.get('bad_query'):
+ if query_check_info.get("bad_query"):
# 引擎内部判断为 bad_query
- result['status'] = 1
- result['msg'] = query_check_info.get('msg')
- return HttpResponse(json.dumps(result), content_type='application/json')
- if query_check_info.get('has_star') and config.get('disable_star') is True:
+ result["status"] = 1
+ result["msg"] = query_check_info.get("msg")
+ return HttpResponse(json.dumps(result), content_type="application/json")
+ if query_check_info.get("has_star") and config.get("disable_star") is True:
# 引擎内部判断为有 * 且禁止 * 选项打开
- result['status'] = 1
- result['msg'] = query_check_info.get('msg')
- return HttpResponse(json.dumps(result), content_type='application/json')
- sql_content = query_check_info['filtered_sql']
+ result["status"] = 1
+ result["msg"] = query_check_info.get("msg")
+ return HttpResponse(json.dumps(result), content_type="application/json")
+ sql_content = query_check_info["filtered_sql"]
# 查询权限校验,并且获取limit_num
- priv_check_info = query_priv_check(user, instance, db_name, sql_content, limit_num)
- if priv_check_info['status'] == 0:
- limit_num = priv_check_info['data']['limit_num']
- priv_check = priv_check_info['data']['priv_check']
+ priv_check_info = query_priv_check(
+ user, instance, db_name, sql_content, limit_num
+ )
+ if priv_check_info["status"] == 0:
+ limit_num = priv_check_info["data"]["limit_num"]
+ priv_check = priv_check_info["data"]["priv_check"]
else:
- result['status'] = priv_check_info['status']
- result['msg'] = priv_check_info['msg']
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["status"] = priv_check_info["status"]
+ result["msg"] = priv_check_info["msg"]
+ return HttpResponse(json.dumps(result), content_type="application/json")
# explain的limit_num设置为0
limit_num = 0 if re.match(r"^explain", sql_content.lower()) else limit_num
@@ -86,19 +88,25 @@ def query(request):
# 先获取查询连接,用于后面查询复用连接以及终止会话
query_engine.get_connection(db_name=db_name)
thread_id = query_engine.thread_id
- max_execution_time = int(config.get('max_execution_time', 60))
+ max_execution_time = int(config.get("max_execution_time", 60))
# 执行查询语句,并增加一个定时终止语句的schedule,timeout=max_execution_time
if thread_id:
- schedule_name = f'query-{time.time()}'
- run_date = (datetime.datetime.now() + datetime.timedelta(seconds=max_execution_time))
+ schedule_name = f"query-{time.time()}"
+ run_date = datetime.datetime.now() + datetime.timedelta(
+ seconds=max_execution_time
+ )
add_kill_conn_schedule(schedule_name, run_date, instance.id, thread_id)
with FuncTimer() as t:
# 获取主从延迟信息
seconds_behind_master = query_engine.seconds_behind_master
- query_result = query_engine.query(db_name, sql_content, limit_num,
- schema_name=schema_name,
- tb_name=tb_name,
- max_execution_time=max_execution_time * 1000)
+ query_result = query_engine.query(
+ db_name,
+ sql_content,
+ limit_num,
+ schema_name=schema_name,
+ tb_name=tb_name,
+ max_execution_time=max_execution_time * 1000,
+ )
query_result.query_time = t.cost
# 返回查询结果后删除schedule
if thread_id:
@@ -106,46 +114,50 @@ def query(request):
# 查询异常
if query_result.error:
- result['status'] = 1
- result['msg'] = query_result.error
+ result["status"] = 1
+ result["msg"] = query_result.error
# 数据脱敏,仅对查询无错误的结果集进行脱敏,并且按照query_check配置是否返回
- elif config.get('data_masking'):
+ elif config.get("data_masking"):
try:
with FuncTimer() as t:
- masking_result = query_engine.query_masking(db_name, sql_content, query_result)
+ masking_result = query_engine.query_masking(
+ db_name, sql_content, query_result
+ )
masking_result.mask_time = t.cost
# 脱敏出错
if masking_result.error:
# 开启query_check,直接返回异常,禁止执行
- if config.get('query_check'):
- result['status'] = 1
- result['msg'] = f'数据脱敏异常:{masking_result.error}'
+ if config.get("query_check"):
+ result["status"] = 1
+ result["msg"] = f"数据脱敏异常:{masking_result.error}"
# 关闭query_check,忽略错误信息,返回未脱敏数据,权限校验标记为跳过
else:
- logger.warning(f'数据脱敏异常,按照配置放行,查询语句:{sql_content},错误信息:{masking_result.error}')
+ logger.warning(
+ f"数据脱敏异常,按照配置放行,查询语句:{sql_content},错误信息:{masking_result.error}"
+ )
query_result.error = None
- result['data'] = query_result.__dict__
+ result["data"] = query_result.__dict__
# 正常脱敏
else:
- result['data'] = masking_result.__dict__
+ result["data"] = masking_result.__dict__
except Exception as msg:
logger.error(traceback.format_exc())
# 抛出未定义异常,并且开启query_check,直接返回异常,禁止执行
- if config.get('query_check'):
- result['status'] = 1
- result['msg'] = f'数据脱敏异常,请联系管理员,错误信息:{msg}'
+ if config.get("query_check"):
+ result["status"] = 1
+ result["msg"] = f"数据脱敏异常,请联系管理员,错误信息:{msg}"
# 关闭query_check,忽略错误信息,返回未脱敏数据,权限校验标记为跳过
else:
- logger.warning(f'数据脱敏异常,按照配置放行,查询语句:{sql_content},错误信息:{msg}')
+ logger.warning(f"数据脱敏异常,按照配置放行,查询语句:{sql_content},错误信息:{msg}")
query_result.error = None
- result['data'] = query_result.__dict__
+ result["data"] = query_result.__dict__
# 无需脱敏的语句
else:
- result['data'] = query_result.__dict__
+ result["data"] = query_result.__dict__
# 仅将成功的查询语句记录存入数据库
if not query_result.error:
- result['data']['seconds_behind_master'] = seconds_behind_master
+ result["data"]["seconds_behind_master"] = seconds_behind_master
if int(limit_num) == 0:
limit_num = int(query_result.affected_rows)
else:
@@ -160,35 +172,46 @@ def query(request):
cost_time=query_result.query_time,
priv_check=priv_check,
hit_rule=query_result.mask_rule_hit,
- masking=query_result.is_masked
+ masking=query_result.is_masked,
)
# 防止查询超时
if connection.connection and not connection.is_usable():
close_old_connections()
query_log.save()
except Exception as e:
- logger.error(f'查询异常报错,查询语句:{sql_content}\n,错误信息:{traceback.format_exc()}')
- result['status'] = 1
- result['msg'] = f'查询异常报错,错误信息:{e}'
- return HttpResponse(json.dumps(result), content_type='application/json')
+ logger.error(f"查询异常报错,查询语句:{sql_content}\n,错误信息:{traceback.format_exc()}")
+ result["status"] = 1
+ result["msg"] = f"查询异常报错,错误信息:{e}"
+ return HttpResponse(json.dumps(result), content_type="application/json")
# 返回查询结果
try:
- return HttpResponse(json.dumps(result, use_decimal=False, cls=ExtendJSONEncoderFTime, bigint_as_string=True),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(
+ result,
+ use_decimal=False,
+ cls=ExtendJSONEncoderFTime,
+ bigint_as_string=True,
+ ),
+ content_type="application/json",
+ )
# 虽然能正常返回,但是依然会乱码
except UnicodeDecodeError:
- return HttpResponse(json.dumps(result, default=str, bigint_as_string=True, encoding='latin1'),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, default=str, bigint_as_string=True, encoding="latin1"),
+ content_type="application/json",
+ )
-@permission_required('sql.menu_sqlquery', raise_exception=True)
+@permission_required("sql.menu_sqlquery", raise_exception=True)
def querylog(request):
return _querylog(request)
-@permission_required('sql.audit_user', raise_exception=True)
+
+@permission_required("sql.audit_user", raise_exception=True)
def querylog_audit(request):
return _querylog(request)
+
def _querylog(request):
"""
获取sql查询记录
@@ -198,67 +221,85 @@ def _querylog(request):
# 获取用户信息
user = request.user
- limit = int(request.GET.get('limit',0))
- offset = int(request.GET.get('offset',0))
+ limit = int(request.GET.get("limit", 0))
+ offset = int(request.GET.get("offset", 0))
limit = offset + limit
limit = limit if limit else None
- star = True if request.GET.get('star') == 'true' else False
- query_log_id = request.GET.get('query_log_id')
- search = request.GET.get('search', '')
- start_date = request.GET.get('start_date','')
- end_date = request.GET.get('end_date','')
+ star = True if request.GET.get("star") == "true" else False
+ query_log_id = request.GET.get("query_log_id")
+ search = request.GET.get("search", "")
+ start_date = request.GET.get("start_date", "")
+ end_date = request.GET.get("end_date", "")
# 组合筛选项
filter_dict = dict()
# 是否收藏
if star:
- filter_dict['favorite'] = star
+ filter_dict["favorite"] = star
# 语句别名
if query_log_id:
- filter_dict['id'] = query_log_id
+ filter_dict["id"] = query_log_id
# 管理员、审计员查看全部数据,普通用户查看自己的数据
- if not (user.is_superuser or user.has_perm('sql.audit_user')):
- filter_dict['username'] = user.username
-
+ if not (user.is_superuser or user.has_perm("sql.audit_user")):
+ filter_dict["username"] = user.username
+
if start_date and end_date:
- end_date = datetime.datetime.strptime(end_date, '%Y-%m-%d') + datetime.timedelta(days=1)
- filter_dict['create_time__range'] = (start_date, end_date)
+ end_date = datetime.datetime.strptime(
+ end_date, "%Y-%m-%d"
+ ) + datetime.timedelta(days=1)
+ filter_dict["create_time__range"] = (start_date, end_date)
# 过滤组合筛选项
sql_log = QueryLog.objects.filter(**filter_dict)
# 过滤搜索信息
- sql_log = sql_log.filter(Q(sqllog__icontains=search) |
- Q(user_display__icontains=search) |
- Q(alias__icontains=search))
+ sql_log = sql_log.filter(
+ Q(sqllog__icontains=search)
+ | Q(user_display__icontains=search)
+ | Q(alias__icontains=search)
+ )
sql_log_count = sql_log.count()
- sql_log_list = sql_log.order_by('-id')[offset:limit].values(
- "id", "instance_name", "db_name", "sqllog",
- "effect_row", "cost_time", "user_display", "favorite", "alias",
- "create_time")
+ sql_log_list = sql_log.order_by("-id")[offset:limit].values(
+ "id",
+ "instance_name",
+ "db_name",
+ "sqllog",
+ "effect_row",
+ "cost_time",
+ "user_display",
+ "favorite",
+ "alias",
+ "create_time",
+ )
# QuerySet 序列化
rows = [row for row in sql_log_list]
result = {"total": sql_log_count, "rows": rows}
# 返回查询结果
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
-@permission_required('sql.menu_sqlquery', raise_exception=True)
+@permission_required("sql.menu_sqlquery", raise_exception=True)
def favorite(request):
"""
收藏查询记录,并且设置别名
:param request:
:return:
"""
- query_log_id = request.POST.get('query_log_id')
- star = True if request.POST.get('star') == 'true' else False
- alias = request.POST.get('alias')
- QueryLog(id=query_log_id, favorite=star, alias=alias).save(update_fields=['favorite', 'alias'])
+ query_log_id = request.POST.get("query_log_id")
+ star = True if request.POST.get("star") == "true" else False
+ alias = request.POST.get("alias")
+ QueryLog(id=query_log_id, favorite=star, alias=alias).save(
+ update_fields=["favorite", "alias"]
+ )
# 返回查询结果
- return HttpResponse(json.dumps({'status': 0, 'msg': 'ok'}), content_type='application/json')
+ return HttpResponse(
+ json.dumps({"status": 0, "msg": "ok"}), content_type="application/json"
+ )
def kill_query_conn(instance_id, thread_id):
@@ -266,4 +307,3 @@ def kill_query_conn(instance_id, thread_id):
instance = Instance.objects.get(pk=instance_id)
query_engine = get_engine(instance)
query_engine.kill_connection(thread_id)
-
diff --git a/sql/query_privileges.py b/sql/query_privileges.py
index 9257e23071..4c88980574 100644
--- a/sql/query_privileges.py
+++ b/sql/query_privileges.py
@@ -29,9 +29,9 @@
from sql.utils.workflow_audit import Audit
from sql.utils.sql_utils import extract_tables
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
-__author__ = 'hhyo'
+__author__ = "hhyo"
# TODO 权限校验内的语法解析和判断独立到每个engine内
@@ -45,22 +45,26 @@ def query_priv_check(user, instance, db_name, sql_content, limit_num):
:param limit_num:
:return:
"""
- result = {'status': 0, 'msg': 'ok', 'data': {'priv_check': True, 'limit_num': 0}}
+ result = {"status": 0, "msg": "ok", "data": {"priv_check": True, "limit_num": 0}}
# 如果有can_query_all_instance, 视为管理员, 仅获取limit值信息
# superuser 拥有全部权限, 不需做特别修改
- if user.has_perm('sql.query_all_instances'):
- priv_limit = int(SysConfig().get('admin_query_limit', 5000))
- result['data']['limit_num'] = min(priv_limit, limit_num) if limit_num else priv_limit
+ if user.has_perm("sql.query_all_instances"):
+ priv_limit = int(SysConfig().get("admin_query_limit", 5000))
+ result["data"]["limit_num"] = (
+ min(priv_limit, limit_num) if limit_num else priv_limit
+ )
return result
# 如果有can_query_resource_group_instance, 视为资源组管理员, 可查询资源组内所有实例数据
- if user.has_perm('sql.query_resource_group_instance'):
- if user_instances(user, tag_codes=['can_read']).filter(pk=instance.pk).exists():
- priv_limit = int(SysConfig().get('admin_query_limit', 5000))
- result['data']['limit_num'] = min(priv_limit, limit_num) if limit_num else priv_limit
+ if user.has_perm("sql.query_resource_group_instance"):
+ if user_instances(user, tag_codes=["can_read"]).filter(pk=instance.pk).exists():
+ priv_limit = int(SysConfig().get("admin_query_limit", 5000))
+ result["data"]["limit_num"] = (
+ min(priv_limit, limit_num) if limit_num else priv_limit
+ )
return result
# 仅MySQL做表权限校验
- if instance.db_type == 'mysql':
+ if instance.db_type == "mysql":
try:
# explain和show create跳过权限校验
if re.match(r"^explain|^show\s+create", sql_content, re.I):
@@ -70,28 +74,39 @@ def query_priv_check(user, instance, db_name, sql_content, limit_num):
# 循环验证权限,可能存在性能问题,但一次查询涉及的库表数量有限
for table in table_ref:
# 既无库权限也无表权限则鉴权失败
- if not _db_priv(user, instance, table['schema']) and \
- not _tb_priv(user, instance, table['schema'], table['name']):
+ if not _db_priv(user, instance, table["schema"]) and not _tb_priv(
+ user, instance, table["schema"], table["name"]
+ ):
# 没有库表查询权限时的staus为2
- result['status'] = 2
- result['msg'] = f"你无{table['schema']}.{table['name']}表的查询权限!请先到查询权限管理进行申请"
+ result["status"] = 2
+ result[
+ "msg"
+ ] = f"你无{table['schema']}.{table['name']}表的查询权限!请先到查询权限管理进行申请"
return result
# 获取查询涉及库/表权限的最小limit限制,和前端传参作对比,取最小值
for table in table_ref:
- priv_limit = _priv_limit(user, instance, db_name=table['schema'], tb_name=table['name'])
+ priv_limit = _priv_limit(
+ user, instance, db_name=table["schema"], tb_name=table["name"]
+ )
limit_num = min(priv_limit, limit_num) if limit_num else priv_limit
- result['data']['limit_num'] = limit_num
+ result["data"]["limit_num"] = limit_num
except Exception as msg:
- logger.error(f"无法校验查询语句权限,{instance.instance_name},{sql_content},{traceback.format_exc()}")
- result['status'] = 1
- result['msg'] = f"无法校验查询语句权限,请联系管理员,错误信息:{msg}"
+ logger.error(
+ f"无法校验查询语句权限,{instance.instance_name},{sql_content},{traceback.format_exc()}"
+ )
+ result["status"] = 1
+ result["msg"] = f"无法校验查询语句权限,请联系管理员,错误信息:{msg}"
# 其他类型实例仅校验库权限
else:
# 先获取查询语句涉及的库,redis、mssql特殊处理,仅校验当前选择的库
- if instance.db_type in ['redis', 'mssql']:
+ if instance.db_type in ["redis", "mssql"]:
dbs = [db_name]
else:
- dbs = [i['schema'].strip('`') for i in extract_tables(sql_content) if i['schema'] is not None]
+ dbs = [
+ i["schema"].strip("`")
+ for i in extract_tables(sql_content)
+ if i["schema"] is not None
+ ]
dbs.append(db_name)
# 库去重
dbs = list(set(dbs))
@@ -101,18 +116,18 @@ def query_priv_check(user, instance, db_name, sql_content, limit_num):
for db_name in dbs:
if not _db_priv(user, instance, db_name):
# 没有库表查询权限时的staus为2
- result['status'] = 2
- result['msg'] = f"你无{db_name}数据库的查询权限!请先到查询权限管理进行申请"
+ result["status"] = 2
+ result["msg"] = f"你无{db_name}数据库的查询权限!请先到查询权限管理进行申请"
return result
# 有所有库权限则获取最小limit值
for db_name in dbs:
priv_limit = _priv_limit(user, instance, db_name=db_name)
limit_num = min(priv_limit, limit_num) if limit_num else priv_limit
- result['data']['limit_num'] = limit_num
+ result["data"]["limit_num"] = limit_num
return result
-@permission_required('sql.menu_queryapplylist', raise_exception=True)
+@permission_required("sql.menu_queryapplylist", raise_exception=True)
def query_priv_apply_list(request):
"""
获取查询权限申请列表
@@ -120,20 +135,22 @@ def query_priv_apply_list(request):
:return:
"""
user = request.user
- limit = int(request.POST.get('limit', 0))
- offset = int(request.POST.get('offset', 0))
+ limit = int(request.POST.get("limit", 0))
+ offset = int(request.POST.get("offset", 0))
limit = offset + limit
- search = request.POST.get('search', '')
+ search = request.POST.get("search", "")
query_privs = QueryPrivilegesApply.objects.all()
# 过滤搜索项,支持模糊搜索标题、用户
if search:
- query_privs = query_privs.filter(Q(title__icontains=search) | Q(user_display__icontains=search))
+ query_privs = query_privs.filter(
+ Q(title__icontains=search) | Q(user_display__icontains=search)
+ )
# 管理员可以看到全部数据
if user.is_superuser:
query_privs = query_privs
# 拥有审核权限、可以查看组内所有工单
- elif user.has_perm('sql.query_review'):
+ elif user.has_perm("sql.query_review"):
# 先获取用户所在资源组列表
group_list = user_groups(user)
group_ids = [group.group_id for group in group_list]
@@ -143,9 +160,19 @@ def query_priv_apply_list(request):
query_privs = query_privs.filter(user_name=user.username)
count = query_privs.count()
- lists = query_privs.order_by('-apply_id')[offset:limit].values(
- 'apply_id', 'title', 'instance__instance_name', 'db_list', 'priv_type', 'table_list', 'limit_num', 'valid_date',
- 'user_display', 'status', 'create_time', 'group_name'
+ lists = query_privs.order_by("-apply_id")[offset:limit].values(
+ "apply_id",
+ "title",
+ "instance__instance_name",
+ "db_list",
+ "priv_type",
+ "table_list",
+ "limit_num",
+ "valid_date",
+ "user_display",
+ "status",
+ "create_time",
+ "group_name",
)
# QuerySet 序列化
@@ -153,49 +180,60 @@ def query_priv_apply_list(request):
result = {"total": count, "rows": rows}
# 返回查询结果
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
-@permission_required('sql.query_applypriv', raise_exception=True)
+@permission_required("sql.query_applypriv", raise_exception=True)
def query_priv_apply(request):
"""
申请查询权限
:param request:
:return:
"""
- title = request.POST['title']
- instance_name = request.POST.get('instance_name')
- group_name = request.POST.get('group_name')
+ title = request.POST["title"]
+ instance_name = request.POST.get("instance_name")
+ group_name = request.POST.get("group_name")
group_id = ResourceGroup.objects.get(group_name=group_name).group_id
- priv_type = request.POST.get('priv_type')
- db_name = request.POST.get('db_name')
- db_list = request.POST.getlist('db_list[]')
- table_list = request.POST.getlist('table_list[]')
- valid_date = request.POST.get('valid_date')
- limit_num = request.POST.get('limit_num')
+ priv_type = request.POST.get("priv_type")
+ db_name = request.POST.get("db_name")
+ db_list = request.POST.getlist("db_list[]")
+ table_list = request.POST.getlist("table_list[]")
+ valid_date = request.POST.get("valid_date")
+ limit_num = request.POST.get("limit_num")
# 获取用户信息
user = request.user
# 服务端参数校验
- result = {'status': 0, 'msg': 'ok', 'data': []}
+ result = {"status": 0, "msg": "ok", "data": []}
if int(priv_type) == 1:
if not (title and instance_name and db_list and valid_date and limit_num):
- result['status'] = 1
- result['msg'] = '请填写完整'
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["status"] = 1
+ result["msg"] = "请填写完整"
+ return HttpResponse(json.dumps(result), content_type="application/json")
elif int(priv_type) == 2:
- if not (title and instance_name and db_name and valid_date and table_list and limit_num):
- result['status'] = 1
- result['msg'] = '请填写完整'
- return HttpResponse(json.dumps(result), content_type='application/json')
+ if not (
+ title
+ and instance_name
+ and db_name
+ and valid_date
+ and table_list
+ and limit_num
+ ):
+ result["status"] = 1
+ result["msg"] = "请填写完整"
+ return HttpResponse(json.dumps(result), content_type="application/json")
try:
- user_instances(request.user, tag_codes=['can_read']).get(instance_name=instance_name)
+ user_instances(request.user, tag_codes=["can_read"]).get(
+ instance_name=instance_name
+ )
except Instance.DoesNotExist:
- result['status'] = 1
- result['msg'] = '你所在组未关联该实例!'
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["status"] = 1
+ result["msg"] = "你所在组未关联该实例!"
+ return HttpResponse(json.dumps(result), content_type="application/json")
# 库权限
ins = Instance.objects.get(instance_name=instance_name)
@@ -203,23 +241,23 @@ def query_priv_apply(request):
# 检查申请账号是否已拥库查询权限
for db_name in db_list:
if _db_priv(user, ins, db_name):
- result['status'] = 1
- result['msg'] = f'你已拥有{instance_name}实例{db_name}库权限,不能重复申请'
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["status"] = 1
+ result["msg"] = f"你已拥有{instance_name}实例{db_name}库权限,不能重复申请"
+ return HttpResponse(json.dumps(result), content_type="application/json")
# 表权限
elif int(priv_type) == 2:
# 先检查是否拥有库权限
if _db_priv(user, ins, db_name):
- result['status'] = 1
- result['msg'] = f'你已拥有{instance_name}实例{db_name}库的全部权限,不能重复申请'
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["status"] = 1
+ result["msg"] = f"你已拥有{instance_name}实例{db_name}库的全部权限,不能重复申请"
+ return HttpResponse(json.dumps(result), content_type="application/json")
# 检查申请账号是否已拥有该表的查询权限
for tb_name in table_list:
if _tb_priv(user, ins, db_name, tb_name):
- result['status'] = 1
- result['msg'] = f'你已拥有{instance_name}实例{db_name}.{tb_name}表的查询权限,不能重复申请'
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["status"] = 1
+ result["msg"] = f"你已拥有{instance_name}实例{db_name}.{tb_name}表的查询权限,不能重复申请"
+ return HttpResponse(json.dumps(result), content_type="application/json")
# 使用事务保持数据一致性
try:
@@ -229,43 +267,53 @@ def query_priv_apply(request):
title=title,
group_id=group_id,
group_name=group_name,
- audit_auth_groups=Audit.settings(group_id, WorkflowDict.workflow_type['query']),
+ audit_auth_groups=Audit.settings(
+ group_id, WorkflowDict.workflow_type["query"]
+ ),
user_name=user.username,
user_display=user.display,
instance=ins,
priv_type=int(priv_type),
valid_date=valid_date,
- status=WorkflowDict.workflow_status['audit_wait'],
- limit_num=limit_num
+ status=WorkflowDict.workflow_status["audit_wait"],
+ limit_num=limit_num,
)
if int(priv_type) == 1:
- applyinfo.db_list = ','.join(db_list)
- applyinfo.table_list = ''
+ applyinfo.db_list = ",".join(db_list)
+ applyinfo.table_list = ""
elif int(priv_type) == 2:
applyinfo.db_list = db_name
- applyinfo.table_list = ','.join(table_list)
+ applyinfo.table_list = ",".join(table_list)
applyinfo.save()
apply_id = applyinfo.apply_id
# 调用工作流插入审核信息,查询权限申请workflow_type=1
- audit_result = Audit.add(WorkflowDict.workflow_type['query'], apply_id)
- if audit_result['status'] == 0:
+ audit_result = Audit.add(WorkflowDict.workflow_type["query"], apply_id)
+ if audit_result["status"] == 0:
# 更新业务表审核状态,判断是否插入权限信息
- _query_apply_audit_call_back(apply_id, audit_result['data']['workflow_status'])
+ _query_apply_audit_call_back(
+ apply_id, audit_result["data"]["workflow_status"]
+ )
except Exception as msg:
logger.error(traceback.format_exc())
- result['status'] = 1
- result['msg'] = str(msg)
+ result["status"] = 1
+ result["msg"] = str(msg)
else:
result = audit_result
# 消息通知
- audit_id = Audit.detail_by_workflow_id(workflow_id=apply_id,
- workflow_type=WorkflowDict.workflow_type['query']).audit_id
- async_task(notify_for_audit, audit_id=audit_id, timeout=60, task_name=f'query-priv-apply-{apply_id}')
- return HttpResponse(json.dumps(result), content_type='application/json')
-
-
-@permission_required('sql.menu_queryapplylist', raise_exception=True)
+ audit_id = Audit.detail_by_workflow_id(
+ workflow_id=apply_id, workflow_type=WorkflowDict.workflow_type["query"]
+ ).audit_id
+ async_task(
+ notify_for_audit,
+ audit_id=audit_id,
+ timeout=60,
+ task_name=f"query-priv-apply-{apply_id}",
+ )
+ return HttpResponse(json.dumps(result), content_type="application/json")
+
+
+@permission_required("sql.menu_queryapplylist", raise_exception=True)
def user_query_priv(request):
"""
用户的查询权限管理
@@ -273,38 +321,54 @@ def user_query_priv(request):
:return:
"""
user = request.user
- user_display = request.POST.get('user_display', 'all')
- limit = int(request.POST.get('limit'))
- offset = int(request.POST.get('offset'))
+ user_display = request.POST.get("user_display", "all")
+ limit = int(request.POST.get("limit"))
+ offset = int(request.POST.get("offset"))
limit = offset + limit
- search = request.POST.get('search', '')
+ search = request.POST.get("search", "")
- user_query_privs = QueryPrivileges.objects.filter(is_deleted=0, valid_date__gte=datetime.datetime.now())
+ user_query_privs = QueryPrivileges.objects.filter(
+ is_deleted=0, valid_date__gte=datetime.datetime.now()
+ )
# 过滤搜索项,支持模糊搜索用户、数据库、表
if search:
- user_query_privs = user_query_privs.filter(Q(user_display__icontains=search) |
- Q(db_name__icontains=search) |
- Q(table_name__icontains=search))
+ user_query_privs = user_query_privs.filter(
+ Q(user_display__icontains=search)
+ | Q(db_name__icontains=search)
+ | Q(table_name__icontains=search)
+ )
# 过滤用户
- if user_display != 'all':
+ if user_display != "all":
user_query_privs = user_query_privs.filter(user_display=user_display)
# 管理员可以看到全部数据
if user.is_superuser:
user_query_privs = user_query_privs
# 拥有管理权限、可以查看组内所有工单
- elif user.has_perm('sql.query_mgtpriv'):
+ elif user.has_perm("sql.query_mgtpriv"):
# 先获取用户所在资源组列表
group_list = user_groups(user)
group_ids = [group.group_id for group in group_list]
- user_query_privs = user_query_privs.filter(instance__queryprivilegesapply__group_id__in=group_ids)
+ user_query_privs = user_query_privs.filter(
+ instance__queryprivilegesapply__group_id__in=group_ids
+ )
# 其他人只能看到自己提交的工单
else:
user_query_privs = user_query_privs.filter(user_name=user.username)
privileges_count = user_query_privs.distinct().count()
- privileges_list = user_query_privs.distinct().order_by('-privilege_id')[offset:limit].values(
- 'privilege_id', 'user_display', 'instance__instance_name', 'db_name', 'priv_type',
- 'table_name', 'limit_num', 'valid_date'
+ privileges_list = (
+ user_query_privs.distinct()
+ .order_by("-privilege_id")[offset:limit]
+ .values(
+ "privilege_id",
+ "user_display",
+ "instance__instance_name",
+ "db_name",
+ "priv_type",
+ "table_name",
+ "limit_num",
+ "valid_date",
+ )
)
# QuerySet 序列化
@@ -312,45 +376,47 @@ def user_query_priv(request):
result = {"total": privileges_count, "rows": rows}
# 返回查询结果
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
-@permission_required('sql.query_mgtpriv', raise_exception=True)
+@permission_required("sql.query_mgtpriv", raise_exception=True)
def query_priv_modify(request):
"""
变更权限信息
:param request:
:return:
"""
- privilege_id = request.POST.get('privilege_id')
- type = request.POST.get('type')
- result = {'status': 0, 'msg': 'ok', 'data': []}
+ privilege_id = request.POST.get("privilege_id")
+ type = request.POST.get("type")
+ result = {"status": 0, "msg": "ok", "data": []}
# type=1删除权限,type=2变更权限
try:
privilege = QueryPrivileges.objects.get(privilege_id=int(privilege_id))
except QueryPrivileges.DoesNotExist:
- result['msg'] = '待操作权限不存在'
- result['status'] = 1
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["msg"] = "待操作权限不存在"
+ result["status"] = 1
+ return HttpResponse(json.dumps(result), content_type="application/json")
if int(type) == 1:
# 删除权限
privilege.is_deleted = 1
- privilege.save(update_fields=['is_deleted'])
- return HttpResponse(json.dumps(result), content_type='application/json')
+ privilege.save(update_fields=["is_deleted"])
+ return HttpResponse(json.dumps(result), content_type="application/json")
elif int(type) == 2:
# 变更权限
- valid_date = request.POST.get('valid_date')
- limit_num = request.POST.get('limit_num')
+ valid_date = request.POST.get("valid_date")
+ limit_num = request.POST.get("limit_num")
privilege.valid_date = valid_date
privilege.limit_num = limit_num
- privilege.save(update_fields=['valid_date', 'limit_num'])
- return HttpResponse(json.dumps(result), content_type='application/json')
+ privilege.save(update_fields=["valid_date", "limit_num"])
+ return HttpResponse(json.dumps(result), content_type="application/json")
-@permission_required('sql.query_review', raise_exception=True)
+@permission_required("sql.query_review", raise_exception=True)
def query_priv_audit(request):
"""
查询权限审核
@@ -359,42 +425,52 @@ def query_priv_audit(request):
"""
# 获取用户信息
user = request.user
- apply_id = int(request.POST['apply_id'])
- audit_status = int(request.POST['audit_status'])
- audit_remark = request.POST.get('audit_remark')
+ apply_id = int(request.POST["apply_id"])
+ audit_status = int(request.POST["audit_status"])
+ audit_remark = request.POST.get("audit_remark")
if audit_remark is None:
- audit_remark = ''
+ audit_remark = ""
if Audit.can_review(request.user, apply_id, 1) is False:
- context = {'errMsg': '你无权操作当前工单!'}
- return render(request, 'error.html', context)
+ context = {"errMsg": "你无权操作当前工单!"}
+ return render(request, "error.html", context)
# 使用事务保持数据一致性
try:
with transaction.atomic():
- audit_id = Audit.detail_by_workflow_id(workflow_id=apply_id,
- workflow_type=WorkflowDict.workflow_type['query']).audit_id
+ audit_id = Audit.detail_by_workflow_id(
+ workflow_id=apply_id, workflow_type=WorkflowDict.workflow_type["query"]
+ ).audit_id
# 调用工作流接口审核
- audit_result = Audit.audit(audit_id, audit_status, user.username, audit_remark)
+ audit_result = Audit.audit(
+ audit_id, audit_status, user.username, audit_remark
+ )
# 按照审核结果更新业务表审核状态
audit_detail = Audit.detail(audit_id)
- if audit_detail.workflow_type == WorkflowDict.workflow_type['query']:
+ if audit_detail.workflow_type == WorkflowDict.workflow_type["query"]:
# 更新业务表审核状态,插入权限信息
- _query_apply_audit_call_back(audit_detail.workflow_id, audit_result['data']['workflow_status'])
+ _query_apply_audit_call_back(
+ audit_detail.workflow_id, audit_result["data"]["workflow_status"]
+ )
except Exception as msg:
logger.error(traceback.format_exc())
- context = {'errMsg': msg}
- return render(request, 'error.html', context)
+ context = {"errMsg": msg}
+ return render(request, "error.html", context)
else:
# 消息通知
- async_task(notify_for_audit, audit_id=audit_id, audit_remark=audit_remark, timeout=60,
- task_name=f'query-priv-audit-{apply_id}')
+ async_task(
+ notify_for_audit,
+ audit_id=audit_id,
+ audit_remark=audit_remark,
+ timeout=60,
+ task_name=f"query-priv-audit-{apply_id}",
+ )
- return HttpResponseRedirect(reverse('sql:queryapplydetail', args=(apply_id,)))
+ return HttpResponseRedirect(reverse("sql:queryapplydetail", args=(apply_id,)))
def _table_ref(sql_content, instance, db_name):
@@ -406,7 +482,9 @@ def _table_ref(sql_content, instance, db_name):
:return:
"""
engine = GoInceptionEngine()
- query_tree = engine.query_print(instance=instance, db_name=db_name, sql=sql_content).get('query_tree')
+ query_tree = engine.query_print(
+ instance=instance, db_name=db_name, sql=sql_content
+ ).get("query_tree")
return engine.get_table_ref(json.loads(query_tree), db_name=db_name)
@@ -420,11 +498,16 @@ def _db_priv(user, instance, db_name):
TODO 返回统一为 int 类型, 不存在返回0 (虽然其实在python中 0==False)
"""
# 获取用户库权限
- user_privileges = QueryPrivileges.objects.filter(user_name=user.username, instance=instance, db_name=str(db_name),
- valid_date__gte=datetime.datetime.now(), is_deleted=0,
- priv_type=1)
+ user_privileges = QueryPrivileges.objects.filter(
+ user_name=user.username,
+ instance=instance,
+ db_name=str(db_name),
+ valid_date__gte=datetime.datetime.now(),
+ is_deleted=0,
+ priv_type=1,
+ )
if user.is_superuser:
- return int(SysConfig().get('admin_query_limit', 5000))
+ return int(SysConfig().get("admin_query_limit", 5000))
else:
if user_privileges.exists():
return user_privileges.first().limit_num
@@ -441,11 +524,17 @@ def _tb_priv(user, instance, db_name, tb_name):
:return: 权限存在则返回对应权限的limit_num,否则返回False
"""
# 获取用户表权限
- user_privileges = QueryPrivileges.objects.filter(user_name=user.username, instance=instance, db_name=str(db_name),
- table_name=str(tb_name), valid_date__gte=datetime.datetime.now(),
- is_deleted=0, priv_type=2)
+ user_privileges = QueryPrivileges.objects.filter(
+ user_name=user.username,
+ instance=instance,
+ db_name=str(db_name),
+ table_name=str(tb_name),
+ valid_date__gte=datetime.datetime.now(),
+ is_deleted=0,
+ priv_type=2,
+ )
if user.is_superuser:
- return int(SysConfig().get('admin_query_limit', 5000))
+ return int(SysConfig().get("admin_query_limit", 5000))
else:
if user_privileges.exists():
return user_privileges.first().limit_num
@@ -473,7 +562,7 @@ def _priv_limit(user, instance, db_name, tb_name=None):
elif tb_limit_num:
return tb_limit_num
else:
- raise RuntimeError('用户无任何有效权限!')
+ raise RuntimeError("用户无任何有效权限!")
def _query_apply_audit_call_back(apply_id, workflow_status):
@@ -488,27 +577,37 @@ def _query_apply_audit_call_back(apply_id, workflow_status):
apply_info.status = workflow_status
apply_info.save()
# 审核通过插入权限信息,批量插入,减少性能消耗
- if workflow_status == WorkflowDict.workflow_status['audit_success']:
+ if workflow_status == WorkflowDict.workflow_status["audit_success"]:
apply_queryset = QueryPrivilegesApply.objects.get(apply_id=apply_id)
# 库权限
if apply_queryset.priv_type == 1:
- insert_list = [QueryPrivileges(
- user_name=apply_queryset.user_name,
- user_display=apply_queryset.user_display,
- instance=apply_queryset.instance,
- db_name=db_name,
- table_name=apply_queryset.table_list, valid_date=apply_queryset.valid_date,
- limit_num=apply_queryset.limit_num, priv_type=apply_queryset.priv_type) for db_name in
- apply_queryset.db_list.split(',')]
+ insert_list = [
+ QueryPrivileges(
+ user_name=apply_queryset.user_name,
+ user_display=apply_queryset.user_display,
+ instance=apply_queryset.instance,
+ db_name=db_name,
+ table_name=apply_queryset.table_list,
+ valid_date=apply_queryset.valid_date,
+ limit_num=apply_queryset.limit_num,
+ priv_type=apply_queryset.priv_type,
+ )
+ for db_name in apply_queryset.db_list.split(",")
+ ]
# 表权限
elif apply_queryset.priv_type == 2:
- insert_list = [QueryPrivileges(
- user_name=apply_queryset.user_name,
- user_display=apply_queryset.user_display,
- instance=apply_queryset.instance,
- db_name=apply_queryset.db_list,
- table_name=table_name, valid_date=apply_queryset.valid_date,
- limit_num=apply_queryset.limit_num, priv_type=apply_queryset.priv_type) for table_name in
- apply_queryset.table_list.split(',')]
+ insert_list = [
+ QueryPrivileges(
+ user_name=apply_queryset.user_name,
+ user_display=apply_queryset.user_display,
+ instance=apply_queryset.instance,
+ db_name=apply_queryset.db_list,
+ table_name=table_name,
+ valid_date=apply_queryset.valid_date,
+ limit_num=apply_queryset.limit_num,
+ priv_type=apply_queryset.priv_type,
+ )
+ for table_name in apply_queryset.table_list.split(",")
+ ]
QueryPrivileges.objects.bulk_create(insert_list)
diff --git a/sql/resource_group.py b/sql/resource_group.py
index e9875a14f5..6d9fa0226e 100644
--- a/sql/resource_group.py
+++ b/sql/resource_group.py
@@ -14,29 +14,33 @@
from sql.utils.resource_group import user_instances
from sql.utils.workflow_audit import Audit
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
@superuser_required
def group(request):
"""获取资源组列表"""
- limit = int(request.POST.get('limit'))
- offset = int(request.POST.get('offset'))
+ limit = int(request.POST.get("limit"))
+ offset = int(request.POST.get("offset"))
limit = offset + limit
- search = request.POST.get('search', '')
+ search = request.POST.get("search", "")
# 过滤搜索条件
group_obj = ResourceGroup.objects.filter(group_name__icontains=search, is_deleted=0)
group_count = group_obj.count()
- group_list = group_obj[offset:limit].values("group_id", "group_name", "ding_webhook")
+ group_list = group_obj[offset:limit].values(
+ "group_id", "group_name", "ding_webhook"
+ )
# QuerySet 序列化
rows = [row for row in group_list]
result = {"total": group_count, "rows": rows}
# 返回查询结果
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
def associated_objects(request):
@@ -44,12 +48,12 @@ def associated_objects(request):
获取资源组已关联对象信息
type:(0, '用户'), (1, '实例')
"""
- group_id = int(request.POST.get('group_id'))
- object_type = request.POST.get('type')
- limit = int(request.POST.get('limit'))
- offset = int(request.POST.get('offset'))
+ group_id = int(request.POST.get("group_id"))
+ object_type = request.POST.get("type")
+ limit = int(request.POST.get("limit"))
+ offset = int(request.POST.get("offset"))
limit = offset + limit
- search = request.POST.get('search')
+ search = request.POST.get("search")
# 获取关联数据
resource_group = ResourceGroup.objects.get(group_id=group_id)
@@ -60,29 +64,25 @@ def associated_objects(request):
rows_users = rows_users.filter(display__contains=search)
rows_instances = rows_instances.filter(instance_name__contains=search)
rows_users = rows_users.annotate(
- object_id=F('id'),
+ object_id=F("id"),
object_type=Value(0, output_field=IntegerField()),
- object_name=F('display'),
- group_id=F('resource_group__group_id'),
- group_name=F('resource_group__group_name')
- ).values(
- 'object_type', 'object_id', 'object_name',
- 'group_id', 'group_name')
+ object_name=F("display"),
+ group_id=F("resource_group__group_id"),
+ group_name=F("resource_group__group_name"),
+ ).values("object_type", "object_id", "object_name", "group_id", "group_name")
rows_instances = rows_instances.annotate(
- object_id=F('id'),
+ object_id=F("id"),
object_type=Value(1, output_field=IntegerField()),
- object_name=F('instance_name'),
- group_id=F('resource_group__group_id'),
- group_name=F('resource_group__group_name')
- ).values(
- 'object_type', 'object_id', 'object_name',
- 'group_id', 'group_name')
+ object_name=F("instance_name"),
+ group_id=F("resource_group__group_id"),
+ group_name=F("resource_group__group_name"),
+ ).values("object_type", "object_id", "object_name", "group_id", "group_name")
# 过滤对象类型
- if object_type == '0':
+ if object_type == "0":
rows_obj = rows_users
count = rows_obj.count()
rows = [row for row in rows_obj][offset:limit]
- elif object_type == '1':
+ elif object_type == "1":
rows_obj = rows_instances
count = rows_obj.count()
rows = [row for row in rows_obj][offset:limit]
@@ -90,8 +90,10 @@ def associated_objects(request):
rows = list(chain(rows_users, rows_instances))
count = len(rows)
rows = rows[offset:limit]
- result = {'status': 0, 'msg': 'ok', "total": count, "rows": rows}
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder), content_type='application/json')
+ result = {"status": 0, "msg": "ok", "total": count, "rows": rows}
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder), content_type="application/json"
+ )
def unassociated_objects(request):
@@ -99,33 +101,38 @@ def unassociated_objects(request):
获取资源组未关联对象信息
type:(0, '用户'), (1, '实例')
"""
- group_id = int(request.POST.get('group_id'))
- object_type = int(request.POST.get('object_type'))
+ group_id = int(request.POST.get("group_id"))
+ object_type = int(request.POST.get("object_type"))
# 获取关联数据
resource_group = ResourceGroup.objects.get(group_id=group_id)
if object_type == 0:
associated_user_ids = [user.id for user in resource_group.users_set.all()]
- rows = Users.objects.exclude(pk__in=associated_user_ids).annotate(
- object_id=F('pk'), object_name=F('display')).values('object_id', 'object_name')
+ rows = (
+ Users.objects.exclude(pk__in=associated_user_ids)
+ .annotate(object_id=F("pk"), object_name=F("display"))
+ .values("object_id", "object_name")
+ )
elif object_type == 1:
associated_instance_ids = [ins.id for ins in resource_group.instance_set.all()]
- rows = Instance.objects.exclude(pk__in=associated_instance_ids).annotate(
- object_id=F('pk'), object_name=F('instance_name')
- ).values('object_id', 'object_name')
+ rows = (
+ Instance.objects.exclude(pk__in=associated_instance_ids)
+ .annotate(object_id=F("pk"), object_name=F("instance_name"))
+ .values("object_id", "object_name")
+ )
else:
- raise ValueError('关联对象类型不正确')
+ raise ValueError("关联对象类型不正确")
rows = [row for row in rows]
- result = {'status': 0, 'msg': 'ok', "rows": rows, "total": len(rows)}
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result = {"status": 0, "msg": "ok", "rows": rows, "total": len(rows)}
+ return HttpResponse(json.dumps(result), content_type="application/json")
def instances(request):
"""获取资源组关联实例列表"""
- group_name = request.POST.get('group_name')
+ group_name = request.POST.get("group_name")
group_id = ResourceGroup.objects.get(group_name=group_name).group_id
- tag_code = request.POST.get('tag_code')
- db_type = request.POST.get('db_type')
+ tag_code = request.POST.get("tag_code")
+ db_type = request.POST.get("db_type")
# 先获取资源组关联所有实例列表
ins = ResourceGroup.objects.get(group_id=group_id).instance_set.all()
@@ -134,28 +141,34 @@ def instances(request):
filter_dict = dict()
# db_type
if db_type:
- filter_dict['db_type'] = db_type
+ filter_dict["db_type"] = db_type
if tag_code:
- filter_dict['instance_tag__tag_code'] = tag_code
- filter_dict['instance_tag__active'] = True
- ins = ins.filter(**filter_dict).order_by(Convert('instance_name', 'gbk').asc()).values(
- 'id', 'type', 'db_type', 'instance_name')
+ filter_dict["instance_tag__tag_code"] = tag_code
+ filter_dict["instance_tag__active"] = True
+ ins = (
+ ins.filter(**filter_dict)
+ .order_by(Convert("instance_name", "gbk").asc())
+ .values("id", "type", "db_type", "instance_name")
+ )
rows = [row for row in ins]
- result = {'status': 0, 'msg': 'ok', "data": rows}
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result = {"status": 0, "msg": "ok", "data": rows}
+ return HttpResponse(json.dumps(result), content_type="application/json")
def user_all_instances(request):
"""获取用户所有实例列表(通过资源组间接关联)"""
user = request.user
- type = request.GET.get('type')
- db_type = request.GET.getlist('db_type[]')
- tag_codes = request.GET.getlist('tag_codes[]')
- instances = user_instances(user, type, db_type, tag_codes).order_by(
- Convert('instance_name', 'gbk').asc()).values('id', 'type', 'db_type', 'instance_name')
+ type = request.GET.get("type")
+ db_type = request.GET.getlist("db_type[]")
+ tag_codes = request.GET.getlist("tag_codes[]")
+ instances = (
+ user_instances(user, type, db_type, tag_codes)
+ .order_by(Convert("instance_name", "gbk").asc())
+ .values("id", "type", "db_type", "instance_name")
+ )
rows = [row for row in instances]
- result = {'status': 0, 'msg': 'ok', "data": rows}
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result = {"status": 0, "msg": "ok", "data": rows}
+ return HttpResponse(json.dumps(result), content_type="application/json")
@superuser_required
@@ -164,71 +177,84 @@ def addrelation(request):
添加资源组关联对象
type:(0, '用户'), (1, '实例')
"""
- group_id = int(request.POST.get('group_id'))
- object_type = request.POST.get('object_type')
- object_list = json.loads(request.POST.get('object_info'))
+ group_id = int(request.POST.get("group_id"))
+ object_type = request.POST.get("object_type")
+ object_list = json.loads(request.POST.get("object_info"))
try:
resource_group = ResourceGroup.objects.get(group_id=group_id)
- obj_ids = [int(obj.split(',')[0]) for obj in object_list]
- if object_type == '0': # 用户
+ obj_ids = [int(obj.split(",")[0]) for obj in object_list]
+ if object_type == "0": # 用户
resource_group.users_set.add(*Users.objects.filter(pk__in=obj_ids))
- elif object_type == '1': # 实例
+ elif object_type == "1": # 实例
resource_group.instance_set.add(*Instance.objects.filter(pk__in=obj_ids))
- result = {'status': 0, 'msg': 'ok'}
+ result = {"status": 0, "msg": "ok"}
except Exception as e:
logger.error(traceback.format_exc())
- result = {'status': 1, 'msg': str(e)}
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result = {"status": 1, "msg": str(e)}
+ return HttpResponse(json.dumps(result), content_type="application/json")
def auditors(request):
"""获取资源组的审批流程"""
- group_name = request.POST.get('group_name')
- workflow_type = request.POST['workflow_type']
- result = {'status': 0, 'msg': 'ok', 'data': {'auditors': '', 'auditors_display': ''}}
+ group_name = request.POST.get("group_name")
+ workflow_type = request.POST["workflow_type"]
+ result = {
+ "status": 0,
+ "msg": "ok",
+ "data": {"auditors": "", "auditors_display": ""},
+ }
if group_name:
group_id = ResourceGroup.objects.get(group_name=group_name).group_id
- audit_auth_groups = Audit.settings(group_id=group_id, workflow_type=workflow_type)
+ audit_auth_groups = Audit.settings(
+ group_id=group_id, workflow_type=workflow_type
+ )
else:
- result['status'] = 1
- result['msg'] = '参数错误'
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["status"] = 1
+ result["msg"] = "参数错误"
+ return HttpResponse(json.dumps(result), content_type="application/json")
# 获取权限组名称
if audit_auth_groups:
# 校验配置
- for auth_group_id in audit_auth_groups.split(','):
+ for auth_group_id in audit_auth_groups.split(","):
try:
Group.objects.get(id=auth_group_id)
except Exception:
- result['status'] = 1
- result['msg'] = '审批流程权限组不存在,请重新配置!'
- return HttpResponse(json.dumps(result), content_type='application/json')
- audit_auth_groups_name = '->'.join(
- [Group.objects.get(id=auth_group_id).name for auth_group_id in audit_auth_groups.split(',')])
- result['data']['auditors'] = audit_auth_groups
- result['data']['auditors_display'] = audit_auth_groups_name
+ result["status"] = 1
+ result["msg"] = "审批流程权限组不存在,请重新配置!"
+ return HttpResponse(json.dumps(result), content_type="application/json")
+ audit_auth_groups_name = "->".join(
+ [
+ Group.objects.get(id=auth_group_id).name
+ for auth_group_id in audit_auth_groups.split(",")
+ ]
+ )
+ result["data"]["auditors"] = audit_auth_groups
+ result["data"]["auditors_display"] = audit_auth_groups_name
- return HttpResponse(json.dumps(result), content_type='application/json')
+ return HttpResponse(json.dumps(result), content_type="application/json")
@superuser_required
def changeauditors(request):
"""设置资源组的审批流程"""
- auth_groups = request.POST.get('audit_auth_groups')
- group_name = request.POST.get('group_name')
- workflow_type = request.POST.get('workflow_type')
- result = {'status': 0, 'msg': 'ok', 'data': []}
+ auth_groups = request.POST.get("audit_auth_groups")
+ group_name = request.POST.get("group_name")
+ workflow_type = request.POST.get("workflow_type")
+ result = {"status": 0, "msg": "ok", "data": []}
# 调用工作流修改审核配置
group_id = ResourceGroup.objects.get(group_name=group_name).group_id
- audit_auth_groups = [str(Group.objects.get(name=auth_group).id) for auth_group in auth_groups.split(',')]
+ audit_auth_groups = [
+ str(Group.objects.get(name=auth_group).id)
+ for auth_group in auth_groups.split(",")
+ ]
try:
- Audit.change_settings(group_id, workflow_type, ','.join(audit_auth_groups))
+ Audit.change_settings(group_id, workflow_type, ",".join(audit_auth_groups))
except Exception as msg:
logger.error(traceback.format_exc())
- result['msg'] = str(msg)
- result['status'] = 1
+ result["msg"] = str(msg)
+ result["status"] = 1
# 返回结果
- return HttpResponse(json.dumps(result), content_type='application/json')
+ return HttpResponse(json.dumps(result), content_type="application/json")
diff --git a/sql/slowlog.py b/sql/slowlog.py
index f389f8e451..630e9d9446 100644
--- a/sql/slowlog.py
+++ b/sql/slowlog.py
@@ -14,24 +14,26 @@
from common.utils.extend_json_encoder import ExtendJSONEncoder
from .models import Instance, SlowQuery, SlowQueryHistory, AliyunRdsConfig
-from .aliyun_rds import slowquery_review as aliyun_rds_slowquery_review, \
- slowquery_review_history as aliyun_rds_slowquery_review_history
+from .aliyun_rds import (
+ slowquery_review as aliyun_rds_slowquery_review,
+ slowquery_review_history as aliyun_rds_slowquery_review_history,
+)
import logging
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
# 获取SQL慢日志统计
-@permission_required('sql.menu_slowquery', raise_exception=True)
+@permission_required("sql.menu_slowquery", raise_exception=True)
def slowquery_review(request):
- instance_name = request.POST.get('instance_name')
+ instance_name = request.POST.get("instance_name")
# 服务端权限校验
try:
- user_instances(request.user, db_type=['mysql']).get(instance_name=instance_name)
+ user_instances(request.user, db_type=["mysql"]).get(instance_name=instance_name)
except Exception:
- result = {'status': 1, 'msg': '你所在组未关联该实例', 'data': []}
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result = {"status": 1, "msg": "你所在组未关联该实例", "data": []}
+ return HttpResponse(json.dumps(result), content_type="application/json")
# 判断是RDS还是其他实例
instance_info = Instance.objects.get(instance_name=instance_name)
@@ -39,65 +41,85 @@ def slowquery_review(request):
# 调用阿里云慢日志接口
result = aliyun_rds_slowquery_review(request)
else:
- start_time = request.POST.get('StartTime')
- end_time = request.POST.get('EndTime')
- db_name = request.POST.get('db_name')
- limit = int(request.POST.get('limit'))
- offset = int(request.POST.get('offset'))
+ start_time = request.POST.get("StartTime")
+ end_time = request.POST.get("EndTime")
+ db_name = request.POST.get("db_name")
+ limit = int(request.POST.get("limit"))
+ offset = int(request.POST.get("offset"))
limit = offset + limit
- search = request.POST.get('search')
- sortName = str(request.POST.get('sortName'))
- sortOrder = str(request.POST.get('sortOrder')).lower()
+ search = request.POST.get("search")
+ sortName = str(request.POST.get("sortName"))
+ sortOrder = str(request.POST.get("sortOrder")).lower()
# 时间处理
- end_time = datetime.datetime.strptime(end_time, '%Y-%m-%d') + datetime.timedelta(days=1)
+ end_time = datetime.datetime.strptime(
+ end_time, "%Y-%m-%d"
+ ) + datetime.timedelta(days=1)
filter_kwargs = {"slowqueryhistory__db_max": db_name} if db_name else {}
# 获取慢查数据
- slowsql_obj = SlowQuery.objects.filter(
- slowqueryhistory__hostname_max=(instance_info.host + ':' + str(instance_info.port)),
- slowqueryhistory__ts_min__range=(start_time, end_time),
- fingerprint__icontains=search,
- **filter_kwargs
- ).annotate(SQLText=F('fingerprint'), SQLId=F('checksum')).values('SQLText', 'SQLId').annotate(
- CreateTime=Max('slowqueryhistory__ts_max'),
- DBName=Max('slowqueryhistory__db_max'), # 数据库
- QueryTimeAvg=Sum('slowqueryhistory__query_time_sum') / Sum('slowqueryhistory__ts_cnt'), # 平均执行时长
- MySQLTotalExecutionCounts=Sum('slowqueryhistory__ts_cnt'), # 执行总次数
- MySQLTotalExecutionTimes=Sum('slowqueryhistory__query_time_sum'), # 执行总时长
- ParseTotalRowCounts=Sum('slowqueryhistory__rows_examined_sum'), # 扫描总行数
- ReturnTotalRowCounts=Sum('slowqueryhistory__rows_sent_sum'), # 返回总行数
- ParseRowAvg=Sum('slowqueryhistory__rows_examined_sum') / Sum('slowqueryhistory__ts_cnt'), # 平均扫描行数
- ReturnRowAvg=Sum('slowqueryhistory__rows_sent_sum') / Sum('slowqueryhistory__ts_cnt'), # 平均返回行数
+ slowsql_obj = (
+ SlowQuery.objects.filter(
+ slowqueryhistory__hostname_max=(
+ instance_info.host + ":" + str(instance_info.port)
+ ),
+ slowqueryhistory__ts_min__range=(start_time, end_time),
+ fingerprint__icontains=search,
+ **filter_kwargs
+ )
+ .annotate(SQLText=F("fingerprint"), SQLId=F("checksum"))
+ .values("SQLText", "SQLId")
+ .annotate(
+ CreateTime=Max("slowqueryhistory__ts_max"),
+ DBName=Max("slowqueryhistory__db_max"), # 数据库
+ QueryTimeAvg=Sum("slowqueryhistory__query_time_sum")
+ / Sum("slowqueryhistory__ts_cnt"), # 平均执行时长
+ MySQLTotalExecutionCounts=Sum("slowqueryhistory__ts_cnt"), # 执行总次数
+ MySQLTotalExecutionTimes=Sum(
+ "slowqueryhistory__query_time_sum"
+ ), # 执行总时长
+ ParseTotalRowCounts=Sum("slowqueryhistory__rows_examined_sum"), # 扫描总行数
+ ReturnTotalRowCounts=Sum("slowqueryhistory__rows_sent_sum"), # 返回总行数
+ ParseRowAvg=Sum("slowqueryhistory__rows_examined_sum")
+ / Sum("slowqueryhistory__ts_cnt"), # 平均扫描行数
+ ReturnRowAvg=Sum("slowqueryhistory__rows_sent_sum")
+ / Sum("slowqueryhistory__ts_cnt"), # 平均返回行数
+ )
)
slow_sql_count = slowsql_obj.count()
# 默认“执行总次数”倒序排列
- slow_sql_list = slowsql_obj.order_by('-' + sortName if 'desc'.__eq__(sortOrder) else sortName)[offset:limit]
+ slow_sql_list = slowsql_obj.order_by(
+ "-" + sortName if "desc".__eq__(sortOrder) else sortName
+ )[offset:limit]
# QuerySet 序列化
sql_slow_log = []
for SlowLog in slow_sql_list:
- SlowLog['QueryTimeAvg'] = round(SlowLog['QueryTimeAvg'], 6)
- SlowLog['MySQLTotalExecutionTimes'] = round(SlowLog['MySQLTotalExecutionTimes'], 6)
- SlowLog['ParseRowAvg'] = int(SlowLog['ParseRowAvg'])
- SlowLog['ReturnRowAvg'] = int(SlowLog['ReturnRowAvg'])
+ SlowLog["QueryTimeAvg"] = round(SlowLog["QueryTimeAvg"], 6)
+ SlowLog["MySQLTotalExecutionTimes"] = round(
+ SlowLog["MySQLTotalExecutionTimes"], 6
+ )
+ SlowLog["ParseRowAvg"] = int(SlowLog["ParseRowAvg"])
+ SlowLog["ReturnRowAvg"] = int(SlowLog["ReturnRowAvg"])
sql_slow_log.append(SlowLog)
result = {"total": slow_sql_count, "rows": sql_slow_log}
# 返回查询结果
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
# 获取SQL慢日志明细
-@permission_required('sql.menu_slowquery', raise_exception=True)
+@permission_required("sql.menu_slowquery", raise_exception=True)
def slowquery_review_history(request):
- instance_name = request.POST.get('instance_name')
+ instance_name = request.POST.get("instance_name")
# 服务端权限校验
try:
- user_instances(request.user, db_type=['mysql']).get(instance_name=instance_name)
+ user_instances(request.user, db_type=["mysql"]).get(instance_name=instance_name)
except Exception:
- result = {'status': 1, 'msg': '你所在组未关联该实例', 'data': []}
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result = {"status": 1, "msg": "你所在组未关联该实例", "data": []}
+ return HttpResponse(json.dumps(result), content_type="application/json")
# 判断是RDS还是其他实例
instance_info = Instance.objects.get(instance_name=instance_name)
@@ -105,87 +127,115 @@ def slowquery_review_history(request):
# 调用阿里云慢日志接口
result = aliyun_rds_slowquery_review_history(request)
else:
- start_time = request.POST.get('StartTime')
- end_time = request.POST.get('EndTime')
- db_name = request.POST.get('db_name')
- sql_id = request.POST.get('SQLId')
- limit = int(request.POST.get('limit'))
- offset = int(request.POST.get('offset'))
- search = request.POST.get('search')
- sortName = str(request.POST.get('sortName'))
- sortOrder = str(request.POST.get('sortOrder')).lower()
+ start_time = request.POST.get("StartTime")
+ end_time = request.POST.get("EndTime")
+ db_name = request.POST.get("db_name")
+ sql_id = request.POST.get("SQLId")
+ limit = int(request.POST.get("limit"))
+ offset = int(request.POST.get("offset"))
+ search = request.POST.get("search")
+ sortName = str(request.POST.get("sortName"))
+ sortOrder = str(request.POST.get("sortOrder")).lower()
# 时间处理
- end_time = datetime.datetime.strptime(end_time, '%Y-%m-%d') + datetime.timedelta(days=1)
+ end_time = datetime.datetime.strptime(
+ end_time, "%Y-%m-%d"
+ ) + datetime.timedelta(days=1)
limit = offset + limit
filter_kwargs = {}
filter_kwargs.update({"checksum": sql_id}) if sql_id else None
- filter_kwargs.update({'db_max': db_name}) if db_name else None
+ filter_kwargs.update({"db_max": db_name}) if db_name else None
# SQLId、DBName非必传
# 获取慢查明细数据
slow_sql_record_obj = SlowQueryHistory.objects.filter(
- hostname_max=(instance_info.host + ':' + str(instance_info.port)),
+ hostname_max=(instance_info.host + ":" + str(instance_info.port)),
ts_min__range=(start_time, end_time),
sample__icontains=search,
**filter_kwargs
- ).annotate(ExecutionStartTime=F('ts_min'), # 本次统计(每5分钟一次)该类型sql语句出现的最小时间
- DBName=F('db_max'), # 数据库名
- HostAddress=Concat(V('\''), 'user_max', V('\''), V('@'), V('\''), 'client_max', V('\'')), # 用户名
- SQLText=F('sample'), # SQL语句
- TotalExecutionCounts=F('ts_cnt'), # 本次统计该sql语句出现的次数
- QueryTimePct95=F('query_time_pct_95'), # 本次统计该sql语句95%耗时
- QueryTimes=F('query_time_sum'), # 本次统计该sql语句花费的总时间(秒)
- LockTimes=F('lock_time_sum'), # 本次统计该sql语句锁定总时长(秒)
- ParseRowCounts=F('rows_examined_sum'), # 本次统计该sql语句解析总行数
- ReturnRowCounts=F('rows_sent_sum') # 本次统计该sql语句返回总行数
- )
+ ).annotate(
+ ExecutionStartTime=F("ts_min"), # 本次统计(每5分钟一次)该类型sql语句出现的最小时间
+ DBName=F("db_max"), # 数据库名
+ HostAddress=Concat(
+ V("'"), "user_max", V("'"), V("@"), V("'"), "client_max", V("'")
+ ), # 用户名
+ SQLText=F("sample"), # SQL语句
+ TotalExecutionCounts=F("ts_cnt"), # 本次统计该sql语句出现的次数
+ QueryTimePct95=F("query_time_pct_95"), # 本次统计该sql语句95%耗时
+ QueryTimes=F("query_time_sum"), # 本次统计该sql语句花费的总时间(秒)
+ LockTimes=F("lock_time_sum"), # 本次统计该sql语句锁定总时长(秒)
+ ParseRowCounts=F("rows_examined_sum"), # 本次统计该sql语句解析总行数
+ ReturnRowCounts=F("rows_sent_sum"), # 本次统计该sql语句返回总行数
+ )
slow_sql_record_count = slow_sql_record_obj.count()
- slow_sql_record_list = slow_sql_record_obj.order_by('-' + sortName if 'desc'.__eq__(sortOrder) else sortName)[
- offset:limit].values('ExecutionStartTime', 'DBName', 'HostAddress',
- 'SQLText',
- 'TotalExecutionCounts', 'QueryTimePct95',
- 'QueryTimes', 'LockTimes', 'ParseRowCounts',
- 'ReturnRowCounts'
- )
+ slow_sql_record_list = slow_sql_record_obj.order_by(
+ "-" + sortName if "desc".__eq__(sortOrder) else sortName
+ )[offset:limit].values(
+ "ExecutionStartTime",
+ "DBName",
+ "HostAddress",
+ "SQLText",
+ "TotalExecutionCounts",
+ "QueryTimePct95",
+ "QueryTimes",
+ "LockTimes",
+ "ParseRowCounts",
+ "ReturnRowCounts",
+ )
# QuerySet 序列化
sql_slow_record = []
for SlowRecord in slow_sql_record_list:
- SlowRecord['QueryTimePct95'] = round(SlowRecord['QueryTimePct95'], 6)
- SlowRecord['QueryTimes'] = round(SlowRecord['QueryTimes'], 6)
- SlowRecord['LockTimes'] = round(SlowRecord['LockTimes'], 6)
+ SlowRecord["QueryTimePct95"] = round(SlowRecord["QueryTimePct95"], 6)
+ SlowRecord["QueryTimes"] = round(SlowRecord["QueryTimes"], 6)
+ SlowRecord["LockTimes"] = round(SlowRecord["LockTimes"], 6)
sql_slow_record.append(SlowRecord)
result = {"total": slow_sql_record_count, "rows": sql_slow_record}
# 返回查询结果
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
@cache_page(60 * 10)
def report(request):
"""返回慢SQL历史趋势"""
- checksum = request.GET.get('checksum')
+ checksum = request.GET.get("checksum")
cnt_data = ChartDao().slow_query_review_history_by_cnt(checksum)
pct_data = ChartDao().slow_query_review_history_by_pct_95_time(checksum)
- cnt_x_data = [row[1] for row in cnt_data['rows']]
- cnt_y_data = [int(row[0]) for row in cnt_data['rows']]
- pct_y_data = [str(row[0]) for row in pct_data['rows']]
- line = Line(init_opts=opts.InitOpts(width='800', height='380px'))
+ cnt_x_data = [row[1] for row in cnt_data["rows"]]
+ cnt_y_data = [int(row[0]) for row in cnt_data["rows"]]
+ pct_y_data = [str(row[0]) for row in pct_data["rows"]]
+ line = Line(init_opts=opts.InitOpts(width="800", height="380px"))
line.add_xaxis(cnt_x_data)
- line.add_yaxis("慢查次数", cnt_y_data, is_smooth=True,
- markline_opts=opts.MarkLineOpts(data=[opts.MarkLineItem(type_="max", name='最大值'),
- opts.MarkLineItem(type_="average", name='平均值')]))
+ line.add_yaxis(
+ "慢查次数",
+ cnt_y_data,
+ is_smooth=True,
+ markline_opts=opts.MarkLineOpts(
+ data=[
+ opts.MarkLineItem(type_="max", name="最大值"),
+ opts.MarkLineItem(type_="average", name="平均值"),
+ ]
+ ),
+ )
line.add_yaxis("慢查时长(95%)", pct_y_data, is_smooth=True, is_symbol_show=False)
- line.set_series_opts(areastyle_opts=opts.AreaStyleOpts(opacity=0.5, ))
- line.set_global_opts(title_opts=opts.TitleOpts(title='SQL历史趋势'),
- legend_opts=opts.LegendOpts(selected_mode='single'),
- xaxis_opts=opts.AxisOpts(
- axistick_opts=opts.AxisTickOpts(is_align_with_label=True),
- is_scale=False,
- boundary_gap=False,
- ), )
-
- result = {"status": 0, "msg": '', "data": line.render_embed()}
- return HttpResponse(json.dumps(result), content_type='application/json')
+ line.set_series_opts(
+ areastyle_opts=opts.AreaStyleOpts(
+ opacity=0.5,
+ )
+ )
+ line.set_global_opts(
+ title_opts=opts.TitleOpts(title="SQL历史趋势"),
+ legend_opts=opts.LegendOpts(selected_mode="single"),
+ xaxis_opts=opts.AxisOpts(
+ axistick_opts=opts.AxisTickOpts(is_align_with_label=True),
+ is_scale=False,
+ boundary_gap=False,
+ ),
+ )
+
+ result = {"status": 0, "msg": "", "data": line.render_embed()}
+ return HttpResponse(json.dumps(result), content_type="application/json")
diff --git a/sql/sql_analyze.py b/sql/sql_analyze.py
index c61fa9ece2..15b7bfdc1c 100644
--- a/sql/sql_analyze.py
+++ b/sql/sql_analyze.py
@@ -16,66 +16,76 @@
from common.utils.extend_json_encoder import ExtendJSONEncoder
from .models import Instance
-__author__ = 'hhyo'
+__author__ = "hhyo"
-@permission_required('sql.sql_analyze', raise_exception=True)
+@permission_required("sql.sql_analyze", raise_exception=True)
def generate(request):
"""
解析上传文件为SQL列表
:param request:
:return:
"""
- text = request.POST.get('text')
+ text = request.POST.get("text")
if text is None:
result = {"total": 0, "rows": []}
else:
rows = generate_sql(text)
result = {"total": len(rows), "rows": rows}
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
-@permission_required('sql.sql_analyze', raise_exception=True)
+@permission_required("sql.sql_analyze", raise_exception=True)
def analyze(request):
"""
利用soar分析SQL
:param request:
:return:
"""
- text = request.POST.get('text')
- instance_name = request.POST.get('instance_name')
- db_name = request.POST.get('db_name')
+ text = request.POST.get("text")
+ instance_name = request.POST.get("instance_name")
+ db_name = request.POST.get("db_name")
if not text:
result = {"total": 0, "rows": []}
else:
soar = Soar()
- if instance_name != '' and db_name != '':
+ if instance_name != "" and db_name != "":
try:
- instance_info = user_instances(request.user, db_type=['mysql']).get(instance_name=instance_name)
+ instance_info = user_instances(request.user, db_type=["mysql"]).get(
+ instance_name=instance_name
+ )
except Instance.DoesNotExist:
- return JsonResponse({'status': 1, 'msg': '你所在组未关联该实例!', 'data': []})
- soar_test_dsn = SysConfig().get('soar_test_dsn')
+ return JsonResponse({"status": 1, "msg": "你所在组未关联该实例!", "data": []})
+ soar_test_dsn = SysConfig().get("soar_test_dsn")
# 获取实例连接信息
- online_dsn = "{user}:{pwd}@{host}:{port}/{db}".format(user=instance_info.user,
- pwd=instance_info.password,
- host=instance_info.host,
- port=instance_info.port,
- db=db_name)
+ online_dsn = "{user}:{pwd}@{host}:{port}/{db}".format(
+ user=instance_info.user,
+ pwd=instance_info.password,
+ host=instance_info.host,
+ port=instance_info.port,
+ db=db_name,
+ )
else:
- online_dsn = ''
- soar_test_dsn = ''
- args = {"report-type": "markdown",
- "query": '',
- "online-dsn": online_dsn,
- "test-dsn": soar_test_dsn,
- "allow-online-as-test": "false"}
+ online_dsn = ""
+ soar_test_dsn = ""
+ args = {
+ "report-type": "markdown",
+ "query": "",
+ "online-dsn": online_dsn,
+ "test-dsn": soar_test_dsn,
+ "allow-online-as-test": "false",
+ }
rows = generate_sql(text)
for row in rows:
- args['query'] = row['sql']
+ args["query"] = row["sql"]
cmd_args = soar.generate_args2cmd(args=args, shell=True)
stdout, stderr = soar.execute_cmd(cmd_args, shell=True).communicate()
- row['report'] = stdout if stdout else stderr
+ row["report"] = stdout if stdout else stderr
result = {"total": len(rows), "rows": rows}
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
diff --git a/sql/sql_optimize.py b/sql/sql_optimize.py
index c3cd89404e..75488f2c43 100644
--- a/sql/sql_optimize.py
+++ b/sql/sql_optimize.py
@@ -21,166 +21,182 @@
from sql.sql_tuning import SqlTuning
from sql.utils.resource_group import user_instances
-__author__ = 'hhyo'
+__author__ = "hhyo"
-@permission_required('sql.optimize_sqladvisor', raise_exception=True)
+@permission_required("sql.optimize_sqladvisor", raise_exception=True)
def optimize_sqladvisor(request):
- sql_content = request.POST.get('sql_content')
- instance_name = request.POST.get('instance_name')
- db_name = request.POST.get('db_name')
- verbose = request.POST.get('verbose', 1)
- result = {'status': 0, 'msg': 'ok', 'data': []}
+ sql_content = request.POST.get("sql_content")
+ instance_name = request.POST.get("instance_name")
+ db_name = request.POST.get("db_name")
+ verbose = request.POST.get("verbose", 1)
+ result = {"status": 0, "msg": "ok", "data": []}
# 服务器端参数验证
if sql_content is None or instance_name is None:
- result['status'] = 1
- result['msg'] = '页面提交参数可能为空'
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["status"] = 1
+ result["msg"] = "页面提交参数可能为空"
+ return HttpResponse(json.dumps(result), content_type="application/json")
try:
- instance_info = user_instances(request.user, db_type=['mysql']).get(instance_name=instance_name)
+ instance_info = user_instances(request.user, db_type=["mysql"]).get(
+ instance_name=instance_name
+ )
except Instance.DoesNotExist:
- result['status'] = 1
- result['msg'] = '你所在组未关联该实例!'
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["status"] = 1
+ result["msg"] = "你所在组未关联该实例!"
+ return HttpResponse(json.dumps(result), content_type="application/json")
# 检查sqladvisor程序路径
- sqladvisor_path = SysConfig().get('sqladvisor')
+ sqladvisor_path = SysConfig().get("sqladvisor")
if sqladvisor_path is None:
- result['status'] = 1
- result['msg'] = '请配置SQLAdvisor路径!'
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["status"] = 1
+ result["msg"] = "请配置SQLAdvisor路径!"
+ return HttpResponse(json.dumps(result), content_type="application/json")
# 提交给sqladvisor获取分析报告
sqladvisor = SQLAdvisor()
# 准备参数
- args = {"h": instance_info.host,
- "P": instance_info.port,
- "u": instance_info.user,
- "p": instance_info.password,
- "d": db_name,
- "v": verbose,
- "q": sql_content.strip()
- }
+ args = {
+ "h": instance_info.host,
+ "P": instance_info.port,
+ "u": instance_info.user,
+ "p": instance_info.password,
+ "d": db_name,
+ "v": verbose,
+ "q": sql_content.strip(),
+ }
# 参数检查
args_check_result = sqladvisor.check_args(args)
- if args_check_result['status'] == 1:
- return HttpResponse(json.dumps(args_check_result), content_type='application/json')
+ if args_check_result["status"] == 1:
+ return HttpResponse(
+ json.dumps(args_check_result), content_type="application/json"
+ )
# 参数转换
cmd_args = sqladvisor.generate_args2cmd(args, shell=True)
# 执行命令
try:
stdout, stderr = sqladvisor.execute_cmd(cmd_args, shell=True).communicate()
- result['data'] = f'{stdout}{stderr}'
+ result["data"] = f"{stdout}{stderr}"
except RuntimeError as e:
- result['status'] = 1
- result['msg'] = str(e)
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["status"] = 1
+ result["msg"] = str(e)
+ return HttpResponse(json.dumps(result), content_type="application/json")
-@permission_required('sql.optimize_soar', raise_exception=True)
+@permission_required("sql.optimize_soar", raise_exception=True)
def optimize_soar(request):
- instance_name = request.POST.get('instance_name')
- db_name = request.POST.get('db_name')
- sql = request.POST.get('sql')
- result = {'status': 0, 'msg': 'ok', 'data': []}
+ instance_name = request.POST.get("instance_name")
+ db_name = request.POST.get("db_name")
+ sql = request.POST.get("sql")
+ result = {"status": 0, "msg": "ok", "data": []}
# 服务器端参数验证
if not (instance_name and db_name and sql):
- result['status'] = 1
- result['msg'] = '页面提交参数可能为空'
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["status"] = 1
+ result["msg"] = "页面提交参数可能为空"
+ return HttpResponse(json.dumps(result), content_type="application/json")
try:
- instance_info = user_instances(request.user, db_type=['mysql']).get(instance_name=instance_name)
+ instance_info = user_instances(request.user, db_type=["mysql"]).get(
+ instance_name=instance_name
+ )
except Exception:
- result['status'] = 1
- result['msg'] = '你所在组未关联该实例'
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["status"] = 1
+ result["msg"] = "你所在组未关联该实例"
+ return HttpResponse(json.dumps(result), content_type="application/json")
# 检查测试实例的连接信息和soar程序路径
- soar_test_dsn = SysConfig().get('soar_test_dsn')
- soar_path = SysConfig().get('soar')
+ soar_test_dsn = SysConfig().get("soar_test_dsn")
+ soar_path = SysConfig().get("soar")
if not (soar_path and soar_test_dsn):
- result['status'] = 1
- result['msg'] = '请配置soar_path和test_dsn!'
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["status"] = 1
+ result["msg"] = "请配置soar_path和test_dsn!"
+ return HttpResponse(json.dumps(result), content_type="application/json")
# 目标实例的连接信息
- online_dsn = '{user}:"{pwd}"@{host}:{port}/{db}'.format(user=instance_info.user,
- pwd=instance_info.password,
- host=instance_info.host,
- port=instance_info.port,
- db=db_name)
+ online_dsn = '{user}:"{pwd}"@{host}:{port}/{db}'.format(
+ user=instance_info.user,
+ pwd=instance_info.password,
+ host=instance_info.host,
+ port=instance_info.port,
+ db=db_name,
+ )
# 提交给soar获取分析报告
soar = Soar()
# 准备参数
- args = {"online-dsn": online_dsn,
- "test-dsn": soar_test_dsn,
- "allow-online-as-test": "false",
- "report-type": "markdown",
- "query": sql.strip()
- }
+ args = {
+ "online-dsn": online_dsn,
+ "test-dsn": soar_test_dsn,
+ "allow-online-as-test": "false",
+ "report-type": "markdown",
+ "query": sql.strip(),
+ }
# 参数检查
args_check_result = soar.check_args(args)
- if args_check_result['status'] == 1:
- return HttpResponse(json.dumps(args_check_result), content_type='application/json')
+ if args_check_result["status"] == 1:
+ return HttpResponse(
+ json.dumps(args_check_result), content_type="application/json"
+ )
# 参数转换
cmd_args = soar.generate_args2cmd(args, shell=True)
# 执行命令
try:
stdout, stderr = soar.execute_cmd(cmd_args, shell=True).communicate()
- result['data'] = stdout if stdout else stderr
+ result["data"] = stdout if stdout else stderr
except RuntimeError as e:
- result['status'] = 1
- result['msg'] = str(e)
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["status"] = 1
+ result["msg"] = str(e)
+ return HttpResponse(json.dumps(result), content_type="application/json")
-@permission_required('sql.optimize_sqltuning', raise_exception=True)
+@permission_required("sql.optimize_sqltuning", raise_exception=True)
def optimize_sqltuning(request):
- instance_name = request.POST.get('instance_name')
- db_name = request.POST.get('db_name')
- sqltext = request.POST.get('sql_content')
- option = request.POST.getlist('option[]')
+ instance_name = request.POST.get("instance_name")
+ db_name = request.POST.get("db_name")
+ sqltext = request.POST.get("sql_content")
+ option = request.POST.getlist("option[]")
sqltext = sqlparse.format(sqltext, strip_comments=True)
sqltext = sqlparse.split(sqltext)[0]
if re.match(r"^select|^show|^explain", sqltext, re.I) is None:
- result = {'status': 1, 'msg': '只支持查询SQL!', 'data': []}
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result = {"status": 1, "msg": "只支持查询SQL!", "data": []}
+ return HttpResponse(json.dumps(result), content_type="application/json")
try:
user_instances(request.user).get(instance_name=instance_name)
except Instance.DoesNotExist:
- result = {'status': 1, 'msg': '你所在组未关联该实例!', 'data': []}
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result = {"status": 1, "msg": "你所在组未关联该实例!", "data": []}
+ return HttpResponse(json.dumps(result), content_type="application/json")
# escape
- db_name = MySQLdb.escape_string(db_name).decode('utf-8')
+ db_name = MySQLdb.escape_string(db_name).decode("utf-8")
- sql_tunning = SqlTuning(instance_name=instance_name, db_name=db_name, sqltext=sqltext)
- result = {'status': 0, 'msg': 'ok', 'data': {}}
- if 'sys_parm' in option:
+ sql_tunning = SqlTuning(
+ instance_name=instance_name, db_name=db_name, sqltext=sqltext
+ )
+ result = {"status": 0, "msg": "ok", "data": {}}
+ if "sys_parm" in option:
basic_information = sql_tunning.basic_information()
sys_parameter = sql_tunning.sys_parameter()
optimizer_switch = sql_tunning.optimizer_switch()
- result['data']['basic_information'] = basic_information
- result['data']['sys_parameter'] = sys_parameter
- result['data']['optimizer_switch'] = optimizer_switch
- if 'sql_plan' in option:
+ result["data"]["basic_information"] = basic_information
+ result["data"]["sys_parameter"] = sys_parameter
+ result["data"]["optimizer_switch"] = optimizer_switch
+ if "sql_plan" in option:
plan, optimizer_rewrite_sql = sql_tunning.sqlplan()
- result['data']['optimizer_rewrite_sql'] = optimizer_rewrite_sql
- result['data']['plan'] = plan
- if 'obj_stat' in option:
- result['data']['object_statistics'] = sql_tunning.object_statistics()
- if 'sql_profile' in option:
+ result["data"]["optimizer_rewrite_sql"] = optimizer_rewrite_sql
+ result["data"]["plan"] = plan
+ if "obj_stat" in option:
+ result["data"]["object_statistics"] = sql_tunning.object_statistics()
+ if "sql_profile" in option:
session_status = sql_tunning.exec_sql()
- result['data']['session_status'] = session_status
+ result["data"]["session_status"] = session_status
# 关闭连接
sql_tunning.engine.close()
- result['data']['sqltext'] = sqltext
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ result["data"]["sqltext"] = sqltext
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
def explain(request):
@@ -189,46 +205,48 @@ def explain(request):
:param request:
:return:
"""
- sql_content = request.POST.get('sql_content')
- instance_name = request.POST.get('instance_name')
- db_name = request.POST.get('db_name')
- result = {'status': 0, 'msg': 'ok', 'data': []}
+ sql_content = request.POST.get("sql_content")
+ instance_name = request.POST.get("instance_name")
+ db_name = request.POST.get("db_name")
+ result = {"status": 0, "msg": "ok", "data": []}
# 服务器端参数验证
if sql_content is None or instance_name is None:
- result['status'] = 1
- result['msg'] = '页面提交参数可能为空'
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["status"] = 1
+ result["msg"] = "页面提交参数可能为空"
+ return HttpResponse(json.dumps(result), content_type="application/json")
try:
instance = user_instances(request.user).get(instance_name=instance_name)
except Instance.DoesNotExist:
- result = {'status': 1, 'msg': '实例不存在', 'data': []}
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result = {"status": 1, "msg": "实例不存在", "data": []}
+ return HttpResponse(json.dumps(result), content_type="application/json")
# 删除注释语句,进行语法判断,执行第一条有效sql
sql_content = sqlparse.format(sql_content.strip(), strip_comments=True)
try:
sql_content = sqlparse.split(sql_content)[0]
except IndexError:
- result['status'] = 1
- result['msg'] = '没有有效的SQL语句'
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["status"] = 1
+ result["msg"] = "没有有效的SQL语句"
+ return HttpResponse(json.dumps(result), content_type="application/json")
else:
# 过滤非explain的语句
if not re.match(r"^explain", sql_content, re.I):
- result['status'] = 1
- result['msg'] = '仅支持explain开头的语句,请检查'
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["status"] = 1
+ result["msg"] = "仅支持explain开头的语句,请检查"
+ return HttpResponse(json.dumps(result), content_type="application/json")
# 执行获取执行计划语句
query_engine = get_engine(instance=instance)
sql_result = query_engine.query(str(db_name), sql_content).to_sep_dict()
- result['data'] = sql_result
+ result["data"] = sql_result
# 返回查询结果
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
def optimize_sqltuningadvisor(request):
@@ -237,45 +255,47 @@ def optimize_sqltuningadvisor(request):
:param request:
:return:
"""
- sql_content = request.POST.get('sql_content')
- instance_name = request.POST.get('instance_name')
- db_name = request.POST.get('schema_name')
- result = {'status': 0, 'msg': 'ok', 'data': []}
+ sql_content = request.POST.get("sql_content")
+ instance_name = request.POST.get("instance_name")
+ db_name = request.POST.get("schema_name")
+ result = {"status": 0, "msg": "ok", "data": []}
# 服务器端参数验证
if sql_content is None or instance_name is None:
- result['status'] = 1
- result['msg'] = '页面提交参数可能为空'
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["status"] = 1
+ result["msg"] = "页面提交参数可能为空"
+ return HttpResponse(json.dumps(result), content_type="application/json")
try:
instance = user_instances(request.user).get(instance_name=instance_name)
except Instance.DoesNotExist:
- result = {'status': 1, 'msg': '实例不存在', 'data': []}
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result = {"status": 1, "msg": "实例不存在", "data": []}
+ return HttpResponse(json.dumps(result), content_type="application/json")
# 不删除注释语句,已获取加hints的SQL优化建议,进行语法判断,执行第一条有效sql
sql_content = sqlparse.format(sql_content.strip(), strip_comments=False)
# 对单引号加转义符,支持plsql语法
- sql_content = sql_content.replace("'", "''");
+ sql_content = sql_content.replace("'", "''")
try:
sql_content = sqlparse.split(sql_content)[0]
except IndexError:
- result['status'] = 1
- result['msg'] = '没有有效的SQL语句'
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["status"] = 1
+ result["msg"] = "没有有效的SQL语句"
+ return HttpResponse(json.dumps(result), content_type="application/json")
else:
# 过滤非Oracle语句
- if not instance.db_type == 'oracle':
- result['status'] = 1
- result['msg'] = 'SQLTuningAdvisor仅支持oracle数据库的检查'
- return HttpResponse(json.dumps(result), content_type='application/json')
+ if not instance.db_type == "oracle":
+ result["status"] = 1
+ result["msg"] = "SQLTuningAdvisor仅支持oracle数据库的检查"
+ return HttpResponse(json.dumps(result), content_type="application/json")
# 执行获取优化报告
query_engine = get_engine(instance=instance)
sql_result = query_engine.sqltuningadvisor(str(db_name), sql_content).to_sep_dict()
- result['data'] = sql_result
+ result["data"] = sql_result
# 返回查询结果
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
diff --git a/sql/sql_tuning.py b/sql/sql_tuning.py
index 81251dea42..973406cba9 100644
--- a/sql/sql_tuning.py
+++ b/sql/sql_tuning.py
@@ -15,19 +15,21 @@ def __init__(self, instance_name, db_name, sqltext):
self.engine = query_engine
self.db_name = db_name
self.sqltext = sqltext
- self.sql_variable = '''
+ self.sql_variable = """
select
lower(variable_name),
variable_value
from performance_schema.global_variables
where upper(variable_name) in ('%s')
- order by variable_name;''' % ('\',\''.join(SQLTuning.SYS_PARM_FILTER))
- self.sql_optimizer_switch = '''
+ order by variable_name;""" % (
+ "','".join(SQLTuning.SYS_PARM_FILTER)
+ )
+ self.sql_optimizer_switch = """
select variable_value
from performance_schema.global_variables
where upper(variable_name) = 'OPTIMIZER_SWITCH';
- '''
- self.sql_table_info = '''
+ """
+ self.sql_table_info = """
select
table_name,
engine,
@@ -39,8 +41,8 @@ def __init__(self, instance_name, db_name, sqltext):
round((index_length) / 1024 / 1024, 2) as index_mb
from information_schema.tables
where table_schema = '%s' and table_name = '%s'
- '''
- self.sql_table_index = '''
+ """
+ self.sql_table_index = """
select
table_name,
index_name,
@@ -54,11 +56,11 @@ def __init__(self, instance_name, db_name, sqltext):
from information_schema.statistics
where table_schema = '%s' and table_name = '%s'
order by 1, 3;
- '''
+ """
def __extract_tables(self):
"""获取sql语句中的表名"""
- return [i['name'].strip('`') for i in extract_tables(self.sqltext)]
+ return [i["name"].strip("`") for i in extract_tables(self.sqltext)]
def basic_information(self):
return self.engine.query(sql="select @@version", close_conn=False).to_sep_dict()
@@ -67,7 +69,7 @@ def sys_parameter(self):
# 获取mysql版本信息
server_version = self.engine.server_version
if server_version < (5, 7, 0):
- sql = self.sql_variable.replace('performance_schema', 'information_schema')
+ sql = self.sql_variable.replace("performance_schema", "information_schema")
else:
sql = self.sql_variable
return self.engine.query(sql=sql, close_conn=False).to_sep_dict()
@@ -76,41 +78,57 @@ def optimizer_switch(self):
# 获取mysql版本信息
server_version = self.engine.server_version
if server_version < (5, 7, 0):
- sql = self.sql_optimizer_switch.replace('performance_schema', 'information_schema')
+ sql = self.sql_optimizer_switch.replace(
+ "performance_schema", "information_schema"
+ )
else:
sql = self.sql_optimizer_switch
return self.engine.query(sql=sql, close_conn=False).to_sep_dict()
def sqlplan(self):
- plan = self.engine.query(self.db_name, "explain " + self.sqltext, close_conn=False).to_sep_dict()
- optimizer_rewrite_sql = self.engine.query(sql="show warnings", close_conn=False).to_sep_dict()
+ plan = self.engine.query(
+ self.db_name, "explain " + self.sqltext, close_conn=False
+ ).to_sep_dict()
+ optimizer_rewrite_sql = self.engine.query(
+ sql="show warnings", close_conn=False
+ ).to_sep_dict()
return plan, optimizer_rewrite_sql
# 获取关联表信息存在缺陷,只能获取到一张表
def object_statistics(self):
object_statistics = []
for index, table_name in enumerate(self.__extract_tables()):
- object_statistics.append({
- "structure": self.engine.query(
- db_name=self.db_name, sql=f"show create table `{table_name}`;",
- close_conn=False).to_sep_dict(),
- "table_info": self.engine.query(
- sql=self.sql_table_info % (self.db_name, table_name),
- close_conn=False).to_sep_dict(),
- "index_info": self.engine.query(
- sql=self.sql_table_index % (self.db_name, table_name),
- close_conn=False).to_sep_dict()
- })
+ object_statistics.append(
+ {
+ "structure": self.engine.query(
+ db_name=self.db_name,
+ sql=f"show create table `{table_name}`;",
+ close_conn=False,
+ ).to_sep_dict(),
+ "table_info": self.engine.query(
+ sql=self.sql_table_info % (self.db_name, table_name),
+ close_conn=False,
+ ).to_sep_dict(),
+ "index_info": self.engine.query(
+ sql=self.sql_table_index % (self.db_name, table_name),
+ close_conn=False,
+ ).to_sep_dict(),
+ }
+ )
return object_statistics
def exec_sql(self):
- result = {"EXECUTE_TIME": 0,
- "BEFORE_STATUS": {'column_list': [], 'rows': []},
- "AFTER_STATUS": {'column_list': [], 'rows': []},
- "SESSION_STATUS(DIFFERENT)": {'column_list': ['status_name', 'before', 'after', 'diff'], 'rows': []},
- "PROFILING_DETAIL": {'column_list': [], 'rows': []},
- "PROFILING_SUMMARY": {'column_list': [], 'rows': []}
- }
+ result = {
+ "EXECUTE_TIME": 0,
+ "BEFORE_STATUS": {"column_list": [], "rows": []},
+ "AFTER_STATUS": {"column_list": [], "rows": []},
+ "SESSION_STATUS(DIFFERENT)": {
+ "column_list": ["status_name", "before", "after", "diff"],
+ "rows": [],
+ },
+ "PROFILING_DETAIL": {"column_list": [], "rows": []},
+ "PROFILING_SUMMARY": {"column_list": [], "rows": []},
+ }
sql_profiling = """select concat(upper(left(variable_name,1)),
substring(lower(variable_name),
2,
@@ -121,43 +139,60 @@ def exec_sql(self):
# 获取mysql版本信息
server_version = self.engine.server_version
if server_version < (5, 7, 0):
- sql = sql_profiling.replace('performance_schema', 'information_schema')
+ sql = sql_profiling.replace("performance_schema", "information_schema")
else:
sql = sql_profiling
self.engine.query(sql="set profiling=1", close_conn=False).to_sep_dict()
- records = self.engine.query(sql="select ifnull(max(query_id),0) from INFORMATION_SCHEMA.PROFILING",
- close_conn=False).to_sep_dict()
- query_id = records['rows'][0][0] + 3 # skip next sql
+ records = self.engine.query(
+ sql="select ifnull(max(query_id),0) from INFORMATION_SCHEMA.PROFILING",
+ close_conn=False,
+ ).to_sep_dict()
+ query_id = records["rows"][0][0] + 3 # skip next sql
# 获取执行前信息
- result['BEFORE_STATUS'] = self.engine.query(sql=sql, close_conn=False).to_sep_dict()
+ result["BEFORE_STATUS"] = self.engine.query(
+ sql=sql, close_conn=False
+ ).to_sep_dict()
# 执行查询语句,统计执行时间
t_start = time.time()
self.engine.query(sql=self.sqltext, close_conn=False).to_sep_dict()
t_end = time.time()
cost_time = "%5s" % "{:.4f}".format(t_end - t_start)
- result['EXECUTE_TIME'] = cost_time
+ result["EXECUTE_TIME"] = cost_time
# 获取执行后信息
- result['AFTER_STATUS'] = self.engine.query(sql=sql, close_conn=False).to_sep_dict()
+ result["AFTER_STATUS"] = self.engine.query(
+ sql=sql, close_conn=False
+ ).to_sep_dict()
# 获取PROFILING_DETAIL信息
- result['PROFILING_DETAIL'] = self.engine.query(
- sql="select STATE,DURATION,CPU_USER,CPU_SYSTEM,BLOCK_OPS_IN,BLOCK_OPS_OUT ,MESSAGES_SENT ,MESSAGES_RECEIVED ,PAGE_FAULTS_MAJOR ,PAGE_FAULTS_MINOR ,SWAPS from INFORMATION_SCHEMA.PROFILING where query_id=" + str(
- query_id) + " order by seq", close_conn=False).to_sep_dict()
- result['PROFILING_SUMMARY'] = self.engine.query(
- sql="SELECT STATE,SUM(DURATION) AS Total_R,ROUND(100*SUM(DURATION)/(SELECT SUM(DURATION) FROM INFORMATION_SCHEMA.PROFILING WHERE QUERY_ID=" + str(
- query_id) + "),2) AS Pct_R,COUNT(*) AS Calls,SUM(DURATION)/COUNT(*) AS R_Call FROM INFORMATION_SCHEMA.PROFILING WHERE QUERY_ID=" + str(
- query_id) + " GROUP BY STATE ORDER BY Total_R DESC", close_conn=False).to_sep_dict()
+ result["PROFILING_DETAIL"] = self.engine.query(
+ sql="select STATE,DURATION,CPU_USER,CPU_SYSTEM,BLOCK_OPS_IN,BLOCK_OPS_OUT ,MESSAGES_SENT ,MESSAGES_RECEIVED ,PAGE_FAULTS_MAJOR ,PAGE_FAULTS_MINOR ,SWAPS from INFORMATION_SCHEMA.PROFILING where query_id="
+ + str(query_id)
+ + " order by seq",
+ close_conn=False,
+ ).to_sep_dict()
+ result["PROFILING_SUMMARY"] = self.engine.query(
+ sql="SELECT STATE,SUM(DURATION) AS Total_R,ROUND(100*SUM(DURATION)/(SELECT SUM(DURATION) FROM INFORMATION_SCHEMA.PROFILING WHERE QUERY_ID="
+ + str(query_id)
+ + "),2) AS Pct_R,COUNT(*) AS Calls,SUM(DURATION)/COUNT(*) AS R_Call FROM INFORMATION_SCHEMA.PROFILING WHERE QUERY_ID="
+ + str(query_id)
+ + " GROUP BY STATE ORDER BY Total_R DESC",
+ close_conn=False,
+ ).to_sep_dict()
# 处理执行前后对比信息
- before_status_rows = [list(item) for item in result['BEFORE_STATUS']['rows']]
- after_status_rows = [list(item) for item in result['AFTER_STATUS']['rows']]
+ before_status_rows = [list(item) for item in result["BEFORE_STATUS"]["rows"]]
+ after_status_rows = [list(item) for item in result["AFTER_STATUS"]["rows"]]
for index, item in enumerate(before_status_rows):
if before_status_rows[index][1] != after_status_rows[index][1]:
before_status_rows[index].append(after_status_rows[index][1])
before_status_rows[index].append(
- str(float(after_status_rows[index][1]) - float(before_status_rows[index][1])))
+ str(
+ float(after_status_rows[index][1])
+ - float(before_status_rows[index][1])
+ )
+ )
diff_rows = [item for item in before_status_rows if len(item) == 4]
- result['SESSION_STATUS(DIFFERENT)']['rows'] = diff_rows
+ result["SESSION_STATUS(DIFFERENT)"]["rows"] = diff_rows
return result
diff --git a/sql/sql_workflow.py b/sql/sql_workflow.py
index 4a38f38545..322c4d67c1 100644
--- a/sql/sql_workflow.py
+++ b/sql/sql_workflow.py
@@ -21,172 +21,204 @@
from sql.models import ResourceGroup
from sql.utils.resource_group import user_groups, user_instances
from sql.utils.tasks import add_sql_schedule, del_schedule
-from sql.utils.sql_review import can_timingtask, can_cancel, can_execute, on_correct_time_period, can_view, can_rollback
+from sql.utils.sql_review import (
+ can_timingtask,
+ can_cancel,
+ can_execute,
+ on_correct_time_period,
+ can_view,
+ can_rollback,
+)
from sql.utils.workflow_audit import Audit
from .models import SqlWorkflow, SqlWorkflowContent, Instance
from django_q.tasks import async_task
from sql.engines import get_engine
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
-@permission_required('sql.menu_sqlworkflow', raise_exception=True)
+@permission_required("sql.menu_sqlworkflow", raise_exception=True)
def sql_workflow_list(request):
return _sql_workflow_list(request)
-@permission_required('sql.audit_user', raise_exception=True)
+
+@permission_required("sql.audit_user", raise_exception=True)
def sql_workflow_list_audit(request):
return _sql_workflow_list(request)
+
def _sql_workflow_list(request):
"""
获取审核列表
:param request:
:return:
"""
- nav_status = request.POST.get('navStatus')
- instance_id = request.POST.get('instance_id')
- resource_group_id = request.POST.get('group_id')
- start_date = request.POST.get('start_date')
- end_date = request.POST.get('end_date')
- limit = int(request.POST.get('limit',0))
- offset = int(request.POST.get('offset',0))
+ nav_status = request.POST.get("navStatus")
+ instance_id = request.POST.get("instance_id")
+ resource_group_id = request.POST.get("group_id")
+ start_date = request.POST.get("start_date")
+ end_date = request.POST.get("end_date")
+ limit = int(request.POST.get("limit", 0))
+ offset = int(request.POST.get("offset", 0))
limit = offset + limit
limit = limit if limit else None
- search = request.POST.get('search')
+ search = request.POST.get("search")
user = request.user
# 组合筛选项
filter_dict = dict()
# 工单状态
if nav_status:
- filter_dict['status'] = nav_status
+ filter_dict["status"] = nav_status
# 实例
if instance_id:
- filter_dict['instance_id'] = instance_id
+ filter_dict["instance_id"] = instance_id
# 资源组
if resource_group_id:
- filter_dict['group_id'] = resource_group_id
+ filter_dict["group_id"] = resource_group_id
# 时间
if start_date and end_date:
- end_date = datetime.datetime.strptime(end_date, '%Y-%m-%d') + datetime.timedelta(days=1)
- filter_dict['create_time__range'] = (start_date, end_date)
+ end_date = datetime.datetime.strptime(
+ end_date, "%Y-%m-%d"
+ ) + datetime.timedelta(days=1)
+ filter_dict["create_time__range"] = (start_date, end_date)
# 管理员,审计员,可查看所有工单
- if user.is_superuser or user.has_perm('sql.audit_user'):
+ if user.is_superuser or user.has_perm("sql.audit_user"):
pass
# 非管理员,拥有审核权限、资源组粒度执行权限的,可以查看组内所有工单
- elif user.has_perm('sql.sql_review') or user.has_perm('sql.sql_execute_for_resource_group'):
+ elif user.has_perm("sql.sql_review") or user.has_perm(
+ "sql.sql_execute_for_resource_group"
+ ):
# 先获取用户所在资源组列表
group_list = user_groups(user)
group_ids = [group.group_id for group in group_list]
- filter_dict['group_id__in'] = group_ids
+ filter_dict["group_id__in"] = group_ids
# 其他人只能查看自己提交的工单
else:
- filter_dict['engineer'] = user.username
+ filter_dict["engineer"] = user.username
# 过滤组合筛选项
workflow = SqlWorkflow.objects.filter(**filter_dict)
# 过滤搜索项,模糊检索项包括提交人名称、工单名
if search:
- workflow = workflow.filter(Q(engineer_display__icontains=search) | Q(workflow_name__icontains=search))
+ workflow = workflow.filter(
+ Q(engineer_display__icontains=search) | Q(workflow_name__icontains=search)
+ )
count = workflow.count()
- workflow_list = workflow.order_by('-create_time')[offset:limit].values(
- "id", "workflow_name", "engineer_display",
- "status", "is_backup", "create_time",
- "instance__instance_name", "db_name",
- "group_name", "syntax_type")
+ workflow_list = workflow.order_by("-create_time")[offset:limit].values(
+ "id",
+ "workflow_name",
+ "engineer_display",
+ "status",
+ "is_backup",
+ "create_time",
+ "instance__instance_name",
+ "db_name",
+ "group_name",
+ "syntax_type",
+ )
# QuerySet 序列化
rows = [row for row in workflow_list]
result = {"total": count, "rows": rows}
# 返回查询结果
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
-@permission_required('sql.sql_submit', raise_exception=True)
+@permission_required("sql.sql_submit", raise_exception=True)
def check(request):
"""SQL检测按钮, 此处没有产生工单"""
- sql_content = request.POST.get('sql_content')
- instance_name = request.POST.get('instance_name')
+ sql_content = request.POST.get("sql_content")
+ instance_name = request.POST.get("instance_name")
instance = Instance.objects.get(instance_name=instance_name)
- db_name = request.POST.get('db_name')
+ db_name = request.POST.get("db_name")
- result = {'status': 0, 'msg': 'ok', 'data': {}}
+ result = {"status": 0, "msg": "ok", "data": {}}
# 服务器端参数验证
if sql_content is None or instance_name is None or db_name is None:
- result['status'] = 1
- result['msg'] = '页面提交参数可能为空'
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["status"] = 1
+ result["msg"] = "页面提交参数可能为空"
+ return HttpResponse(json.dumps(result), content_type="application/json")
# 交给engine进行检测
try:
check_engine = get_engine(instance=instance)
- check_result = check_engine.execute_check(db_name=db_name, sql=sql_content.strip())
+ check_result = check_engine.execute_check(
+ db_name=db_name, sql=sql_content.strip()
+ )
except Exception as e:
- result['status'] = 1
- result['msg'] = str(e)
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["status"] = 1
+ result["msg"] = str(e)
+ return HttpResponse(json.dumps(result), content_type="application/json")
# 处理检测结果
- result['data']['rows'] = check_result.to_dict()
- result['data']['CheckWarningCount'] = check_result.warning_count
- result['data']['CheckErrorCount'] = check_result.error_count
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result["data"]["rows"] = check_result.to_dict()
+ result["data"]["CheckWarningCount"] = check_result.warning_count
+ result["data"]["CheckErrorCount"] = check_result.error_count
+ return HttpResponse(json.dumps(result), content_type="application/json")
-@permission_required('sql.sql_submit', raise_exception=True)
+@permission_required("sql.sql_submit", raise_exception=True)
def submit(request):
"""正式提交SQL, 此处生成工单"""
- sql_content = request.POST.get('sql_content').strip()
- workflow_title = request.POST.get('workflow_name')
- demand_url = request.POST.get('demand_url', '')
+ sql_content = request.POST.get("sql_content").strip()
+ workflow_title = request.POST.get("workflow_name")
+ demand_url = request.POST.get("demand_url", "")
# 检查用户是否有权限涉及到资源组等, 比较复杂, 可以把检查权限改成一个独立的方法
- group_name = request.POST.get('group_name')
+ group_name = request.POST.get("group_name")
group_id = ResourceGroup.objects.get(group_name=group_name).group_id
- instance_name = request.POST.get('instance_name')
+ instance_name = request.POST.get("instance_name")
instance = Instance.objects.get(instance_name=instance_name)
- db_name = request.POST.get('db_name')
- is_backup = True if request.POST.get('is_backup') == 'True' else False
- cc_users = request.POST.getlist('cc_users')
- run_date_start = request.POST.get('run_date_start')
- run_date_end = request.POST.get('run_date_end')
+ db_name = request.POST.get("db_name")
+ is_backup = True if request.POST.get("is_backup") == "True" else False
+ cc_users = request.POST.getlist("cc_users")
+ run_date_start = request.POST.get("run_date_start")
+ run_date_end = request.POST.get("run_date_end")
# 服务器端参数验证
if None in [sql_content, db_name, instance_name, db_name, is_backup, demand_url]:
- context = {'errMsg': '页面提交参数可能为空'}
- return render(request, 'error.html', context)
+ context = {"errMsg": "页面提交参数可能为空"}
+ return render(request, "error.html", context)
# 验证组权限(用户是否在该组、该组是否有指定实例)
try:
- user_instances(request.user, tag_codes=['can_write']).get(instance_name=instance_name)
+ user_instances(request.user, tag_codes=["can_write"]).get(
+ instance_name=instance_name
+ )
except instance.DoesNotExist:
- context = {'errMsg': '你所在组未关联该实例!'}
- return render(request, 'error.html', context)
+ context = {"errMsg": "你所在组未关联该实例!"}
+ return render(request, "error.html", context)
# 再次交给engine进行检测,防止绕过
try:
check_engine = get_engine(instance=instance)
- check_result = check_engine.execute_check(db_name=db_name, sql=sql_content.strip())
+ check_result = check_engine.execute_check(
+ db_name=db_name, sql=sql_content.strip()
+ )
except Exception as e:
- context = {'errMsg': str(e)}
- return render(request, 'error.html', context)
+ context = {"errMsg": str(e)}
+ return render(request, "error.html", context)
# 未开启备份选项,并且engine支持备份,强制设置备份
sys_config = SysConfig()
- if not sys_config.get('enable_backup_switch') and check_engine.auto_backup:
+ if not sys_config.get("enable_backup_switch") and check_engine.auto_backup:
is_backup = True
# 按照系统配置确定是自动驳回还是放行
- auto_review_wrong = sys_config.get('auto_review_wrong', '') # 1表示出现警告就驳回,2和空表示出现错误才驳回
- workflow_status = 'workflow_manreviewing'
- if check_result.warning_count > 0 and auto_review_wrong == '1':
- workflow_status = 'workflow_autoreviewwrong'
- elif check_result.error_count > 0 and auto_review_wrong in ('', '1', '2'):
- workflow_status = 'workflow_autoreviewwrong'
+ auto_review_wrong = sys_config.get(
+ "auto_review_wrong", ""
+ ) # 1表示出现警告就驳回,2和空表示出现错误才驳回
+ workflow_status = "workflow_manreviewing"
+ if check_result.warning_count > 0 and auto_review_wrong == "1":
+ workflow_status = "workflow_autoreviewwrong"
+ elif check_result.error_count > 0 and auto_review_wrong in ("", "1", "2"):
+ workflow_status = "workflow_autoreviewwrong"
# 调用工作流生成工单
# 使用事务保持数据一致性
@@ -200,7 +232,9 @@ def submit(request):
group_name=group_name,
engineer=request.user.username,
engineer_display=request.user.display,
- audit_auth_groups=Audit.settings(group_id, WorkflowDict.workflow_type['sqlreview']),
+ audit_auth_groups=Audit.settings(
+ group_id, WorkflowDict.workflow_type["sqlreview"]
+ ),
status=workflow_status,
is_backup=is_backup,
instance=instance,
@@ -209,44 +243,55 @@ def submit(request):
syntax_type=check_result.syntax_type,
create_time=timezone.now(),
run_date_start=run_date_start or None,
- run_date_end=run_date_end or None
+ run_date_end=run_date_end or None,
+ )
+ SqlWorkflowContent.objects.create(
+ workflow=sql_workflow,
+ sql_content=sql_content,
+ review_content=check_result.json(),
+ execute_result="",
)
- SqlWorkflowContent.objects.create(workflow=sql_workflow,
- sql_content=sql_content,
- review_content=check_result.json(),
- execute_result=''
- )
workflow_id = sql_workflow.id
# 自动审核通过了,才调用工作流
- if workflow_status == 'workflow_manreviewing':
+ if workflow_status == "workflow_manreviewing":
# 调用工作流插入审核信息, SQL上线权限申请workflow_type=2
- Audit.add(WorkflowDict.workflow_type['sqlreview'], workflow_id)
+ Audit.add(WorkflowDict.workflow_type["sqlreview"], workflow_id)
except Exception as msg:
logger.error(f"提交工单报错,错误信息:{traceback.format_exc()}")
- context = {'errMsg': msg}
+ context = {"errMsg": msg}
logger.error(traceback.format_exc())
- return render(request, 'error.html', context)
+ return render(request, "error.html", context)
else:
# 自动审核通过且开启了Apply阶段通知参数才发送消息通知
- is_notified = 'Apply' in sys_config.get('notify_phase_control').split(',') \
- if sys_config.get('notify_phase_control') else True
- if workflow_status == 'workflow_manreviewing' and is_notified:
+ is_notified = (
+ "Apply" in sys_config.get("notify_phase_control").split(",")
+ if sys_config.get("notify_phase_control")
+ else True
+ )
+ if workflow_status == "workflow_manreviewing" and is_notified:
# 获取审核信息
- audit_id = Audit.detail_by_workflow_id(workflow_id=workflow_id,
- workflow_type=WorkflowDict.workflow_type['sqlreview']).audit_id
- async_task(notify_for_audit, audit_id=audit_id, cc_users=cc_users, timeout=60,
- task_name=f'sqlreview-submit-{workflow_id}')
+ audit_id = Audit.detail_by_workflow_id(
+ workflow_id=workflow_id,
+ workflow_type=WorkflowDict.workflow_type["sqlreview"],
+ ).audit_id
+ async_task(
+ notify_for_audit,
+ audit_id=audit_id,
+ cc_users=cc_users,
+ timeout=60,
+ task_name=f"sqlreview-submit-{workflow_id}",
+ )
- return HttpResponseRedirect(reverse('sql:detail', args=(workflow_id,)))
+ return HttpResponseRedirect(reverse("sql:detail", args=(workflow_id,)))
def detail_content(request):
"""获取工单内容"""
- workflow_id = request.GET.get('workflow_id')
+ workflow_id = request.GET.get("workflow_id")
workflow_detail = get_object_or_404(SqlWorkflow, pk=workflow_id)
if not can_view(request.user, workflow_id):
raise PermissionDenied
- if workflow_detail.status in ['workflow_finish', 'workflow_exception']:
+ if workflow_detail.status in ["workflow_finish", "workflow_exception"]:
rows = workflow_detail.sqlworkflowcontent.execute_result
else:
rows = workflow_detail.sqlworkflowcontent.review_content
@@ -262,32 +307,34 @@ def detail_content(request):
review_result.rows += [ReviewResult(inception_result=r)]
rows = review_result.json()
except IndexError:
- review_result.rows += [ReviewResult(
- id=1,
- sql=workflow_detail.sqlworkflowcontent.sql_content,
- errormessage="Json decode failed."
- "执行结果Json解析失败, 请联系管理员"
- )]
+ review_result.rows += [
+ ReviewResult(
+ id=1,
+ sql=workflow_detail.sqlworkflowcontent.sql_content,
+ errormessage="Json decode failed." "执行结果Json解析失败, 请联系管理员",
+ )
+ ]
rows = review_result.json()
except json.decoder.JSONDecodeError:
- review_result.rows += [ReviewResult(
- id=1,
- sql=workflow_detail.sqlworkflowcontent.sql_content,
- # 迫于无法单元测试这里加上英文报错信息
- errormessage="Json decode failed."
- "执行结果Json解析失败, 请联系管理员"
- )]
+ review_result.rows += [
+ ReviewResult(
+ id=1,
+ sql=workflow_detail.sqlworkflowcontent.sql_content,
+ # 迫于无法单元测试这里加上英文报错信息
+ errormessage="Json decode failed." "执行结果Json解析失败, 请联系管理员",
+ )
+ ]
rows = review_result.json()
else:
rows = workflow_detail.sqlworkflowcontent.review_content
result = {"rows": json.loads(rows)}
- return HttpResponse(json.dumps(result), content_type='application/json')
+ return HttpResponse(json.dumps(result), content_type="application/json")
def backup_sql(request):
"""获取回滚语句"""
- workflow_id = request.GET.get('workflow_id')
+ workflow_id = request.GET.get("workflow_id")
if not can_rollback(request.user, workflow_id):
raise PermissionDenied
workflow = get_object_or_404(SqlWorkflow, pk=workflow_id)
@@ -297,89 +344,109 @@ def backup_sql(request):
list_backup_sql = query_engine.get_rollback(workflow=workflow)
except Exception as msg:
logger.error(traceback.format_exc())
- return JsonResponse({'status': 1, 'msg': f'{msg}', 'rows': []})
+ return JsonResponse({"status": 1, "msg": f"{msg}", "rows": []})
- result = {'status': 0, 'msg': '', 'rows': list_backup_sql}
- return HttpResponse(json.dumps(result), content_type='application/json')
+ result = {"status": 0, "msg": "", "rows": list_backup_sql}
+ return HttpResponse(json.dumps(result), content_type="application/json")
-@permission_required('sql.sql_review', raise_exception=True)
+@permission_required("sql.sql_review", raise_exception=True)
def alter_run_date(request):
"""
审核人修改可执行时间
:param request:
:return:
"""
- workflow_id = int(request.POST.get('workflow_id', 0))
- run_date_start = request.POST.get('run_date_start')
- run_date_end = request.POST.get('run_date_end')
+ workflow_id = int(request.POST.get("workflow_id", 0))
+ run_date_start = request.POST.get("run_date_start")
+ run_date_end = request.POST.get("run_date_end")
if workflow_id == 0:
- context = {'errMsg': 'workflow_id参数为空.'}
- return render(request, 'error.html', context)
+ context = {"errMsg": "workflow_id参数为空."}
+ return render(request, "error.html", context)
user = request.user
if Audit.can_review(user, workflow_id, 2) is False:
- context = {'errMsg': '你无权操作当前工单!'}
- return render(request, 'error.html', context)
+ context = {"errMsg": "你无权操作当前工单!"}
+ return render(request, "error.html", context)
try:
# 存进数据库里
- SqlWorkflow(id=workflow_id,
- run_date_start=run_date_start or None,
- run_date_end=run_date_end or None
- ).save(update_fields=['run_date_start', 'run_date_end'])
+ SqlWorkflow(
+ id=workflow_id,
+ run_date_start=run_date_start or None,
+ run_date_end=run_date_end or None,
+ ).save(update_fields=["run_date_start", "run_date_end"])
except Exception as msg:
- context = {'errMsg': msg}
- return render(request, 'error.html', context)
+ context = {"errMsg": msg}
+ return render(request, "error.html", context)
- return HttpResponseRedirect(reverse('sql:detail', args=(workflow_id,)))
+ return HttpResponseRedirect(reverse("sql:detail", args=(workflow_id,)))
-@permission_required('sql.sql_review', raise_exception=True)
+@permission_required("sql.sql_review", raise_exception=True)
def passed(request):
"""
审核通过,不执行
:param request:
:return:
"""
- workflow_id = int(request.POST.get('workflow_id', 0))
- audit_remark = request.POST.get('audit_remark', '')
+ workflow_id = int(request.POST.get("workflow_id", 0))
+ audit_remark = request.POST.get("audit_remark", "")
if workflow_id == 0:
- context = {'errMsg': 'workflow_id参数为空.'}
- return render(request, 'error.html', context)
+ context = {"errMsg": "workflow_id参数为空."}
+ return render(request, "error.html", context)
user = request.user
if Audit.can_review(user, workflow_id, 2) is False:
- context = {'errMsg': '你无权操作当前工单!'}
- return render(request, 'error.html', context)
+ context = {"errMsg": "你无权操作当前工单!"}
+ return render(request, "error.html", context)
# 使用事务保持数据一致性
try:
with transaction.atomic():
# 调用工作流接口审核
- audit_id = Audit.detail_by_workflow_id(workflow_id=workflow_id,
- workflow_type=WorkflowDict.workflow_type['sqlreview']).audit_id
- audit_result = Audit.audit(audit_id, WorkflowDict.workflow_status['audit_success'],
- user.username, audit_remark)
+ audit_id = Audit.detail_by_workflow_id(
+ workflow_id=workflow_id,
+ workflow_type=WorkflowDict.workflow_type["sqlreview"],
+ ).audit_id
+ audit_result = Audit.audit(
+ audit_id,
+ WorkflowDict.workflow_status["audit_success"],
+ user.username,
+ audit_remark,
+ )
# 按照审核结果更新业务表审核状态
- if audit_result['data']['workflow_status'] == WorkflowDict.workflow_status['audit_success']:
+ if (
+ audit_result["data"]["workflow_status"]
+ == WorkflowDict.workflow_status["audit_success"]
+ ):
# 将流程状态修改为审核通过
- SqlWorkflow(id=workflow_id, status='workflow_review_pass').save(update_fields=['status'])
+ SqlWorkflow(id=workflow_id, status="workflow_review_pass").save(
+ update_fields=["status"]
+ )
except Exception as msg:
logger.error(f"审核工单报错,错误信息:{traceback.format_exc()}")
- context = {'errMsg': msg}
- return render(request, 'error.html', context)
+ context = {"errMsg": msg}
+ return render(request, "error.html", context)
else:
# 开启了Pass阶段通知参数才发送消息通知
sys_config = SysConfig()
- is_notified = 'Pass' in sys_config.get('notify_phase_control').split(',') \
- if sys_config.get('notify_phase_control') else True
+ is_notified = (
+ "Pass" in sys_config.get("notify_phase_control").split(",")
+ if sys_config.get("notify_phase_control")
+ else True
+ )
if is_notified:
- async_task(notify_for_audit, audit_id=audit_id, audit_remark=audit_remark, timeout=60,
- task_name=f'sqlreview-pass-{workflow_id}')
+ async_task(
+ notify_for_audit,
+ audit_id=audit_id,
+ audit_remark=audit_remark,
+ timeout=60,
+ task_name=f"sqlreview-pass-{workflow_id}",
+ )
- return HttpResponseRedirect(reverse('sql:detail', args=(workflow_id,)))
+ return HttpResponseRedirect(reverse("sql:detail", args=(workflow_id,)))
def execute(request):
@@ -389,63 +456,84 @@ def execute(request):
:return:
"""
# 校验多个权限
- if not (request.user.has_perm('sql.sql_execute') or request.user.has_perm('sql.sql_execute_for_resource_group')):
+ if not (
+ request.user.has_perm("sql.sql_execute")
+ or request.user.has_perm("sql.sql_execute_for_resource_group")
+ ):
raise PermissionDenied
- workflow_id = int(request.POST.get('workflow_id', 0))
+ workflow_id = int(request.POST.get("workflow_id", 0))
if workflow_id == 0:
- context = {'errMsg': 'workflow_id参数为空.'}
- return render(request, 'error.html', context)
+ context = {"errMsg": "workflow_id参数为空."}
+ return render(request, "error.html", context)
if can_execute(request.user, workflow_id) is False:
- context = {'errMsg': '你无权操作当前工单!'}
- return render(request, 'error.html', context)
+ context = {"errMsg": "你无权操作当前工单!"}
+ return render(request, "error.html", context)
if on_correct_time_period(workflow_id) is False:
- context = {'errMsg': '不在可执行时间范围内,如果需要修改执行时间请重新提交工单!'}
- return render(request, 'error.html', context)
+ context = {"errMsg": "不在可执行时间范围内,如果需要修改执行时间请重新提交工单!"}
+ return render(request, "error.html", context)
# 获取审核信息
- audit_id = Audit.detail_by_workflow_id(workflow_id=workflow_id,
- workflow_type=WorkflowDict.workflow_type['sqlreview']).audit_id
+ audit_id = Audit.detail_by_workflow_id(
+ workflow_id=workflow_id, workflow_type=WorkflowDict.workflow_type["sqlreview"]
+ ).audit_id
# 根据执行模式进行对应修改
- mode = request.POST.get('mode')
+ mode = request.POST.get("mode")
# 交由系统执行
if mode == "auto":
# 修改工单状态为排队中
- SqlWorkflow(id=workflow_id, status="workflow_queuing").save(update_fields=['status'])
+ SqlWorkflow(id=workflow_id, status="workflow_queuing").save(
+ update_fields=["status"]
+ )
# 删除定时执行任务
schedule_name = f"sqlreview-timing-{workflow_id}"
del_schedule(schedule_name)
# 加入执行队列
- async_task('sql.utils.execute_sql.execute', workflow_id, request.user,
- hook='sql.utils.execute_sql.execute_callback',
- timeout=-1, task_name=f'sqlreview-execute-{workflow_id}')
+ async_task(
+ "sql.utils.execute_sql.execute",
+ workflow_id,
+ request.user,
+ hook="sql.utils.execute_sql.execute_callback",
+ timeout=-1,
+ task_name=f"sqlreview-execute-{workflow_id}",
+ )
# 增加工单日志
- Audit.add_log(audit_id=audit_id,
- operation_type=5,
- operation_type_desc='执行工单',
- operation_info='工单执行排队中',
- operator=request.user.username,
- operator_display=request.user.display)
+ Audit.add_log(
+ audit_id=audit_id,
+ operation_type=5,
+ operation_type_desc="执行工单",
+ operation_info="工单执行排队中",
+ operator=request.user.username,
+ operator_display=request.user.display,
+ )
# 线下手工执行
elif mode == "manual":
# 将流程状态修改为执行结束
- SqlWorkflow(id=workflow_id, status="workflow_finish", finish_time=datetime.datetime.now()
- ).save(update_fields=['status', 'finish_time'])
+ SqlWorkflow(
+ id=workflow_id,
+ status="workflow_finish",
+ finish_time=datetime.datetime.now(),
+ ).save(update_fields=["status", "finish_time"])
# 增加工单日志
- Audit.add_log(audit_id=audit_id,
- operation_type=6,
- operation_type_desc='手工工单',
- operation_info='确认手工执行结束',
- operator=request.user.username,
- operator_display=request.user.display)
+ Audit.add_log(
+ audit_id=audit_id,
+ operation_type=6,
+ operation_type_desc="手工工单",
+ operation_info="确认手工执行结束",
+ operator=request.user.username,
+ operator_display=request.user.display,
+ )
# 开启了Execute阶段通知参数才发送消息通知
sys_config = SysConfig()
- is_notified = 'Execute' in sys_config.get('notify_phase_control').split(',') \
- if sys_config.get('notify_phase_control') else True
+ is_notified = (
+ "Execute" in sys_config.get("notify_phase_control").split(",")
+ if sys_config.get("notify_phase_control")
+ else True
+ )
if is_notified:
notify_for_execute(SqlWorkflow.objects.get(id=workflow_id))
- return HttpResponseRedirect(reverse('sql:detail', args=(workflow_id,)))
+ return HttpResponseRedirect(reverse("sql:detail", args=(workflow_id,)))
def timing_task(request):
@@ -455,53 +543,58 @@ def timing_task(request):
:return:
"""
# 校验多个权限
- if not (request.user.has_perm('sql.sql_execute') or request.user.has_perm('sql.sql_execute_for_resource_group')):
+ if not (
+ request.user.has_perm("sql.sql_execute")
+ or request.user.has_perm("sql.sql_execute_for_resource_group")
+ ):
raise PermissionDenied
- workflow_id = request.POST.get('workflow_id')
- run_date = request.POST.get('run_date')
+ workflow_id = request.POST.get("workflow_id")
+ run_date = request.POST.get("run_date")
if run_date is None or workflow_id is None:
- context = {'errMsg': '时间不能为空'}
- return render(request, 'error.html', context)
- elif run_date < datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'):
- context = {'errMsg': '时间不能小于当前时间'}
- return render(request, 'error.html', context)
+ context = {"errMsg": "时间不能为空"}
+ return render(request, "error.html", context)
+ elif run_date < datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"):
+ context = {"errMsg": "时间不能小于当前时间"}
+ return render(request, "error.html", context)
workflow_detail = SqlWorkflow.objects.get(id=workflow_id)
if can_timingtask(request.user, workflow_id) is False:
- context = {'errMsg': '你无权操作当前工单!'}
- return render(request, 'error.html', context)
+ context = {"errMsg": "你无权操作当前工单!"}
+ return render(request, "error.html", context)
run_date = datetime.datetime.strptime(run_date, "%Y-%m-%d %H:%M")
schedule_name = f"sqlreview-timing-{workflow_id}"
if on_correct_time_period(workflow_id, run_date) is False:
- context = {'errMsg': '不在可执行时间范围内,如果需要修改执 行时间请重新提交工单!'}
- return render(request, 'error.html', context)
+ context = {"errMsg": "不在可执行时间范围内,如果需要修改执 行时间请重新提交工单!"}
+ return render(request, "error.html", context)
# 使用事务保持数据一致性
try:
with transaction.atomic():
# 将流程状态修改为定时执行
- workflow_detail.status = 'workflow_timingtask'
+ workflow_detail.status = "workflow_timingtask"
workflow_detail.save()
# 调用添加定时任务
add_sql_schedule(schedule_name, run_date, workflow_id)
# 增加工单日志
- audit_id = Audit.detail_by_workflow_id(workflow_id=workflow_id,
- workflow_type=WorkflowDict.workflow_type[
- 'sqlreview']).audit_id
- Audit.add_log(audit_id=audit_id,
- operation_type=4,
- operation_type_desc='定时执行',
- operation_info="定时执行时间:{}".format(run_date),
- operator=request.user.username,
- operator_display=request.user.display
- )
+ audit_id = Audit.detail_by_workflow_id(
+ workflow_id=workflow_id,
+ workflow_type=WorkflowDict.workflow_type["sqlreview"],
+ ).audit_id
+ Audit.add_log(
+ audit_id=audit_id,
+ operation_type=4,
+ operation_type_desc="定时执行",
+ operation_info="定时执行时间:{}".format(run_date),
+ operator=request.user.username,
+ operator_display=request.user.display,
+ )
except Exception as msg:
logger.error(f"定时执行工单报错,错误信息:{traceback.format_exc()}")
- context = {'errMsg': msg}
- return render(request, 'error.html', context)
- return HttpResponseRedirect(reverse('sql:detail', args=(workflow_id,)))
+ context = {"errMsg": msg}
+ return render(request, "error.html", context)
+ return HttpResponseRedirect(reverse("sql:detail", args=(workflow_id,)))
def cancel(request):
@@ -510,94 +603,115 @@ def cancel(request):
:param request:
:return:
"""
- workflow_id = int(request.POST.get('workflow_id', 0))
+ workflow_id = int(request.POST.get("workflow_id", 0))
if workflow_id == 0:
- context = {'errMsg': 'workflow_id参数为空.'}
- return render(request, 'error.html', context)
+ context = {"errMsg": "workflow_id参数为空."}
+ return render(request, "error.html", context)
workflow_detail = SqlWorkflow.objects.get(id=workflow_id)
- audit_remark = request.POST.get('cancel_remark')
+ audit_remark = request.POST.get("cancel_remark")
if audit_remark is None:
- context = {'errMsg': '终止原因不能为空'}
- return render(request, 'error.html', context)
+ context = {"errMsg": "终止原因不能为空"}
+ return render(request, "error.html", context)
user = request.user
if can_cancel(request.user, workflow_id) is False:
- context = {'errMsg': '你无权操作当前工单!'}
- return render(request, 'error.html', context)
+ context = {"errMsg": "你无权操作当前工单!"}
+ return render(request, "error.html", context)
# 使用事务保持数据一致性
try:
with transaction.atomic():
# 调用工作流接口取消或者驳回
- audit_id = Audit.detail_by_workflow_id(workflow_id=workflow_id,
- workflow_type=WorkflowDict.workflow_type[
- 'sqlreview']).audit_id
+ audit_id = Audit.detail_by_workflow_id(
+ workflow_id=workflow_id,
+ workflow_type=WorkflowDict.workflow_type["sqlreview"],
+ ).audit_id
# 仅待审核的需要调用工作流,审核通过的不需要
- if workflow_detail.status != 'workflow_manreviewing':
+ if workflow_detail.status != "workflow_manreviewing":
# 增加工单日志
if user.username == workflow_detail.engineer:
- Audit.add_log(audit_id=audit_id,
- operation_type=3,
- operation_type_desc='取消执行',
- operation_info="取消原因:{}".format(audit_remark),
- operator=request.user.username,
- operator_display=request.user.display
- )
+ Audit.add_log(
+ audit_id=audit_id,
+ operation_type=3,
+ operation_type_desc="取消执行",
+ operation_info="取消原因:{}".format(audit_remark),
+ operator=request.user.username,
+ operator_display=request.user.display,
+ )
else:
- Audit.add_log(audit_id=audit_id,
- operation_type=2,
- operation_type_desc='审批不通过',
- operation_info="审批备注:{}".format(audit_remark),
- operator=request.user.username,
- operator_display=request.user.display
- )
+ Audit.add_log(
+ audit_id=audit_id,
+ operation_type=2,
+ operation_type_desc="审批不通过",
+ operation_info="审批备注:{}".format(audit_remark),
+ operator=request.user.username,
+ operator_display=request.user.display,
+ )
else:
if user.username == workflow_detail.engineer:
- Audit.audit(audit_id,
- WorkflowDict.workflow_status['audit_abort'],
- user.username, audit_remark)
+ Audit.audit(
+ audit_id,
+ WorkflowDict.workflow_status["audit_abort"],
+ user.username,
+ audit_remark,
+ )
# 非提交人需要校验审核权限
- elif user.has_perm('sql.sql_review'):
- Audit.audit(audit_id,
- WorkflowDict.workflow_status['audit_reject'],
- user.username, audit_remark)
+ elif user.has_perm("sql.sql_review"):
+ Audit.audit(
+ audit_id,
+ WorkflowDict.workflow_status["audit_reject"],
+ user.username,
+ audit_remark,
+ )
else:
raise PermissionDenied
# 删除定时执行task
- if workflow_detail.status == 'workflow_timingtask':
+ if workflow_detail.status == "workflow_timingtask":
schedule_name = f"sqlreview-timing-{workflow_id}"
del_schedule(schedule_name)
# 将流程状态修改为人工终止流程
- workflow_detail.status = 'workflow_abort'
+ workflow_detail.status = "workflow_abort"
workflow_detail.save()
except Exception as msg:
logger.error(f"取消工单报错,错误信息:{traceback.format_exc()}")
- context = {'errMsg': msg}
- return render(request, 'error.html', context)
+ context = {"errMsg": msg}
+ return render(request, "error.html", context)
else:
# 发送取消、驳回通知,开启了Cancel阶段通知参数才发送消息通知
sys_config = SysConfig()
- is_notified = 'Cancel' in sys_config.get('notify_phase_control').split(',') \
- if sys_config.get('notify_phase_control') else True
+ is_notified = (
+ "Cancel" in sys_config.get("notify_phase_control").split(",")
+ if sys_config.get("notify_phase_control")
+ else True
+ )
if is_notified:
- audit_detail = Audit.detail_by_workflow_id(workflow_id=workflow_id,
- workflow_type=WorkflowDict.workflow_type['sqlreview'])
+ audit_detail = Audit.detail_by_workflow_id(
+ workflow_id=workflow_id,
+ workflow_type=WorkflowDict.workflow_type["sqlreview"],
+ )
if audit_detail.current_status in (
- WorkflowDict.workflow_status['audit_abort'], WorkflowDict.workflow_status['audit_reject']):
- async_task(notify_for_audit, audit_id=audit_detail.audit_id, audit_remark=audit_remark, timeout=60,
- task_name=f'sqlreview-cancel-{workflow_id}')
- return HttpResponseRedirect(reverse('sql:detail', args=(workflow_id,)))
+ WorkflowDict.workflow_status["audit_abort"],
+ WorkflowDict.workflow_status["audit_reject"],
+ ):
+ async_task(
+ notify_for_audit,
+ audit_id=audit_detail.audit_id,
+ audit_remark=audit_remark,
+ timeout=60,
+ task_name=f"sqlreview-cancel-{workflow_id}",
+ )
+ return HttpResponseRedirect(reverse("sql:detail", args=(workflow_id,)))
def get_workflow_status(request):
"""
获取某个工单的当前状态
"""
- workflow_id = request.POST['workflow_id']
- if workflow_id == '' or workflow_id is None:
- context = {"status": -1, 'msg': 'workflow_id参数为空.', "data": ""}
- return HttpResponse(json.dumps(context), content_type='application/json')
+ workflow_id = request.POST["workflow_id"]
+ if workflow_id == "" or workflow_id is None:
+ context = {"status": -1, "msg": "workflow_id参数为空.", "data": ""}
+ return HttpResponse(json.dumps(context), content_type="application/json")
workflow_id = int(workflow_id)
workflow_detail = get_object_or_404(SqlWorkflow, pk=workflow_id)
@@ -607,9 +721,9 @@ def get_workflow_status(request):
def osc_control(request):
"""用于mysql控制osc执行"""
- workflow_id = request.POST.get('workflow_id')
- sqlsha1 = request.POST.get('sqlsha1')
- command = request.POST.get('command')
+ workflow_id = request.POST.get("workflow_id")
+ sqlsha1 = request.POST.get("sqlsha1")
+ command = request.POST.get("command")
workflow = SqlWorkflow.objects.get(id=workflow_id)
execute_engine = get_engine(workflow.instance)
try:
@@ -620,5 +734,7 @@ def osc_control(request):
rows = []
error = str(e)
result = {"total": len(rows), "rows": rows, "msg": error}
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
diff --git a/sql/templatetags/format_tags.py b/sql/templatetags/format_tags.py
index 64991ac17d..902eb52371 100644
--- a/sql/templatetags/format_tags.py
+++ b/sql/templatetags/format_tags.py
@@ -10,7 +10,7 @@
@register.simple_tag
def format_str(string):
# 换行
- return mark_safe(string.replace(',', '
').replace('\n', '
'))
+ return mark_safe(string.replace(",", "
").replace("\n", "
"))
# split
@@ -26,7 +26,7 @@ def split(string, sep):
# in
@register.filter
def is_in(var, args):
- return True if str(var) in args.split(',') else False
+ return True if str(var) in args.split(",") else False
@register.filter
@@ -34,4 +34,4 @@ def key_value(data, key):
try:
return data[key]
except KeyError:
- return ''
+ return ""
diff --git a/sql/tests.py b/sql/tests.py
index 5a6700a0b0..c2fb567cc0 100644
--- a/sql/tests.py
+++ b/sql/tests.py
@@ -16,9 +16,21 @@
from sql.notify import notify_for_audit, notify_for_execute, notify_for_my2sql
from sql.utils.execute_sql import execute_callback
from sql.query import kill_query_conn
-from sql.models import Users, Instance, QueryPrivilegesApply, QueryPrivileges, SqlWorkflow, SqlWorkflowContent, \
- ResourceGroup, ParamTemplate, WorkflowAudit, QueryLog, WorkflowLog, WorkflowAuditSetting, \
- ArchiveConfig
+from sql.models import (
+ Users,
+ Instance,
+ QueryPrivilegesApply,
+ QueryPrivileges,
+ SqlWorkflow,
+ SqlWorkflowContent,
+ ResourceGroup,
+ ParamTemplate,
+ WorkflowAudit,
+ QueryLog,
+ WorkflowLog,
+ WorkflowAuditSetting,
+ ArchiveConfig,
+)
User = Users
@@ -32,53 +44,63 @@ def setUp(self):
"""
self.sys_config = SysConfig()
self.client = Client()
- self.superuser = User.objects.create(username='super', is_superuser=True)
+ self.superuser = User.objects.create(username="super", is_superuser=True)
self.client.force_login(self.superuser)
- self.ins = Instance.objects.create(instance_name='some_ins', type='slave', db_type='mysql',
- host='some_host',
- port=3306, user='ins_user', password='some_str')
- self.res_group = ResourceGroup.objects.create(group_id=1, group_name='group_name')
+ self.ins = Instance.objects.create(
+ instance_name="some_ins",
+ type="slave",
+ db_type="mysql",
+ host="some_host",
+ port=3306,
+ user="ins_user",
+ password="some_str",
+ )
+ self.res_group = ResourceGroup.objects.create(
+ group_id=1, group_name="group_name"
+ )
self.wf = SqlWorkflow.objects.create(
- workflow_name='some_name',
+ workflow_name="some_name",
group_id=1,
- group_name='g1',
- engineer_display='',
- audit_auth_groups='some_audit_group',
- status='workflow_finish',
+ group_name="g1",
+ engineer_display="",
+ audit_auth_groups="some_audit_group",
+ status="workflow_finish",
is_backup=True,
instance=self.ins,
- db_name='some_db',
- syntax_type=1
+ db_name="some_db",
+ syntax_type=1,
+ )
+ SqlWorkflowContent.objects.create(
+ workflow=self.wf, sql_content="some_sql", execute_result=""
)
- SqlWorkflowContent.objects.create(workflow=self.wf,
- sql_content='some_sql',
- execute_result='')
self.query_apply = QueryPrivilegesApply.objects.create(
group_id=1,
- group_name='some_name',
- title='some_title1',
- user_name='some_user',
+ group_name="some_name",
+ title="some_title1",
+ user_name="some_user",
instance=self.ins,
- db_list='some_db,some_db2',
+ db_list="some_db,some_db2",
limit_num=100,
- valid_date='2020-01-1',
+ valid_date="2020-01-1",
priv_type=1,
status=0,
- audit_auth_groups='some_audit_group'
+ audit_auth_groups="some_audit_group",
)
self.audit = WorkflowAudit.objects.create(
group_id=1,
- group_name='some_group',
+ group_name="some_group",
workflow_id=1,
workflow_type=1,
- workflow_title='申请标题',
- workflow_remark='申请备注',
- audit_auth_groups='1,2,3',
- current_audit='1',
- next_audit='2',
- current_status=0)
- self.wl = WorkflowLog.objects.create(audit_id=self.audit.audit_id,
- operation_type=1)
+ workflow_title="申请标题",
+ workflow_remark="申请备注",
+ audit_auth_groups="1,2,3",
+ current_audit="1",
+ next_audit="2",
+ current_status=0,
+ )
+ self.wl = WorkflowLog.objects.create(
+ audit_id=self.audit.audit_id, operation_type=1
+ )
def tearDown(self):
self.sys_config.purge()
@@ -93,170 +115,170 @@ def tearDown(self):
def test_index(self):
"""测试index页面"""
data = {}
- r = self.client.get('/index/', data=data)
- self.assertRedirects(r, f'/sqlworkflow/', fetch_redirect_response=False)
+ r = self.client.get("/index/", data=data)
+ self.assertRedirects(r, f"/sqlworkflow/", fetch_redirect_response=False)
def test_dashboard(self):
"""测试dashboard页面"""
data = {}
- r = self.client.get('/dashboard/', data=data)
+ r = self.client.get("/dashboard/", data=data)
self.assertEqual(r.status_code, 200)
- self.assertContains(r, 'SQL上线工单')
+ self.assertContains(r, "SQL上线工单")
def test_sqlworkflow(self):
"""测试sqlworkflow页面"""
data = {}
- r = self.client.get('/sqlworkflow/', data=data)
+ r = self.client.get("/sqlworkflow/", data=data)
self.assertEqual(r.status_code, 200)
def test_submitsql(self):
"""测试submitsql页面"""
data = {}
- r = self.client.get('/submitsql/', data=data)
+ r = self.client.get("/submitsql/", data=data)
self.assertEqual(r.status_code, 200)
def test_rollback(self):
"""测试rollback页面"""
data = {"workflow_id": self.wf.id}
- r = self.client.get('/rollback/', data=data)
+ r = self.client.get("/rollback/", data=data)
self.assertEqual(r.status_code, 200)
def test_sqlanalyze(self):
"""测试sqlanalyze页面"""
data = {}
- r = self.client.get('/sqlanalyze/', data=data)
+ r = self.client.get("/sqlanalyze/", data=data)
self.assertEqual(r.status_code, 200)
def test_sqlquery(self):
"""测试sqlquery页面"""
data = {}
- r = self.client.get('/sqlquery/', data=data)
+ r = self.client.get("/sqlquery/", data=data)
self.assertEqual(r.status_code, 200)
def test_queryapplylist(self):
"""测试queryapplylist页面"""
data = {}
- r = self.client.get('/queryapplylist/', data=data)
+ r = self.client.get("/queryapplylist/", data=data)
self.assertEqual(r.status_code, 200)
def test_queryuserprivileges(self):
"""测试queryuserprivileges页面"""
data = {}
- r = self.client.get(f'/queryuserprivileges/', data=data)
+ r = self.client.get(f"/queryuserprivileges/", data=data)
self.assertEqual(r.status_code, 200)
def test_sqladvisor(self):
"""测试sqladvisor页面"""
data = {}
- r = self.client.get(f'/sqladvisor/', data=data)
+ r = self.client.get(f"/sqladvisor/", data=data)
self.assertEqual(r.status_code, 200)
def test_slowquery(self):
"""测试slowquery页面"""
data = {}
- r = self.client.get(f'/slowquery/', data=data)
+ r = self.client.get(f"/slowquery/", data=data)
self.assertEqual(r.status_code, 200)
def test_instance(self):
"""测试instance页面"""
data = {}
- r = self.client.get(f'/instance/', data=data)
+ r = self.client.get(f"/instance/", data=data)
self.assertEqual(r.status_code, 200)
def test_instanceaccount(self):
"""测试instanceaccount页面"""
data = {}
- r = self.client.get(f'/instanceaccount/', data=data)
+ r = self.client.get(f"/instanceaccount/", data=data)
self.assertEqual(r.status_code, 200)
def test_database(self):
"""测试database页面"""
data = {}
- r = self.client.get(f'/database/', data=data)
+ r = self.client.get(f"/database/", data=data)
self.assertEqual(r.status_code, 200)
def test_dbdiagnostic(self):
"""测试dbdiagnostic页面"""
data = {}
- r = self.client.get(f'/dbdiagnostic/', data=data)
+ r = self.client.get(f"/dbdiagnostic/", data=data)
self.assertEqual(r.status_code, 200)
def test_instanceparam(self):
"""测试instance_param页面"""
data = {}
- r = self.client.get(f'/instanceparam/', data=data)
+ r = self.client.get(f"/instanceparam/", data=data)
self.assertEqual(r.status_code, 200)
def test_my2sql(self):
"""测试my2sql页面"""
data = {}
- r = self.client.get(f'/my2sql/', data=data)
+ r = self.client.get(f"/my2sql/", data=data)
self.assertEqual(r.status_code, 200)
def test_schemasync(self):
"""测试schemasync页面"""
data = {}
- r = self.client.get(f'/schemasync/', data=data)
+ r = self.client.get(f"/schemasync/", data=data)
self.assertEqual(r.status_code, 200)
def test_archive(self):
"""测试archive页面"""
data = {}
- r = self.client.get(f'/archive/', data=data)
+ r = self.client.get(f"/archive/", data=data)
self.assertEqual(r.status_code, 200)
def test_config(self):
"""测试config页面"""
data = {}
- r = self.client.get(f'/config/', data=data)
+ r = self.client.get(f"/config/", data=data)
self.assertEqual(r.status_code, 200)
def test_group(self):
"""测试group页面"""
data = {}
- r = self.client.get(f'/group/', data=data)
+ r = self.client.get(f"/group/", data=data)
self.assertEqual(r.status_code, 200)
def test_audit(self):
"""测试audit页面"""
data = {}
- r = self.client.get(f'/audit/', data=data)
+ r = self.client.get(f"/audit/", data=data)
self.assertEqual(r.status_code, 200)
def test_audit_sqlquery(self):
"""测试audit_sqlquery页面"""
data = {}
- r = self.client.get(f'/audit_sqlquery/', data=data)
+ r = self.client.get(f"/audit_sqlquery/", data=data)
self.assertEqual(r.status_code, 200)
def test_audit_sqlworkflow(self):
"""测试audit_sqlworkflow页面"""
data = {}
- r = self.client.get(f'/audit_sqlworkflow/', data=data)
+ r = self.client.get(f"/audit_sqlworkflow/", data=data)
self.assertEqual(r.status_code, 200)
def test_groupmgmt(self):
"""测试groupmgmt页面"""
data = {}
- r = self.client.get(f'/grouprelations/{self.res_group.group_id}/', data=data)
+ r = self.client.get(f"/grouprelations/{self.res_group.group_id}/", data=data)
self.assertEqual(r.status_code, 200)
def test_workflows(self):
"""测试workflows页面"""
data = {}
- r = self.client.get(f'/workflow/', data=data)
+ r = self.client.get(f"/workflow/", data=data)
self.assertEqual(r.status_code, 200)
def test_workflowsdetail(self):
"""测试workflows页面"""
data = {}
- r = self.client.get(f'/workflow/{self.audit.audit_id}/', data=data)
- self.assertRedirects(r, f'/queryapplydetail/1/', fetch_redirect_response=False)
+ r = self.client.get(f"/workflow/{self.audit.audit_id}/", data=data)
+ self.assertRedirects(r, f"/queryapplydetail/1/", fetch_redirect_response=False)
def test_dbaprinciples(self):
"""测试workflows页面"""
data = {}
- r = self.client.get(f'/dbaprinciples/', data=data)
+ r = self.client.get(f"/dbaprinciples/", data=data)
self.assertEqual(r.status_code, 200)
@@ -268,10 +290,10 @@ def setUp(self):
创建默认组给注册关联用户, 打开注册
"""
archer_config = SysConfig()
- archer_config.set('sign_up_enabled', 'true')
+ archer_config.set("sign_up_enabled", "true")
archer_config.get_all_config()
self.client = Client()
- Group.objects.create(id=1, name='默认组')
+ Group.objects.create(id=1, name="默认组")
def tearDown(self):
SysConfig().purge()
@@ -282,100 +304,133 @@ def test_sing_up_not_username(self):
"""
用户名不能为空
"""
- response = self.client.post('/signup/', data={})
+ response = self.client.post("/signup/", data={})
data = json.loads(response.content)
- content = {'status': 1, 'msg': '用户名和密码不能为空', 'data': None}
+ content = {"status": 1, "msg": "用户名和密码不能为空", "data": None}
self.assertEqual(data, content)
def test_sing_up_not_password(self):
"""
密码不能为空
"""
- response = self.client.post('/signup/', data={'username': 'test'})
+ response = self.client.post("/signup/", data={"username": "test"})
data = json.loads(response.content)
- content = {'status': 1, 'msg': '用户名和密码不能为空', 'data': None}
+ content = {"status": 1, "msg": "用户名和密码不能为空", "data": None}
self.assertEqual(data, content)
def test_sing_up_not_display(self):
"""
中文名不能为空
"""
- response = self.client.post('/signup/', data={'username': 'test', 'password': '123456test',
- 'password2': '123456test', 'display': '',
- 'email': '123@123.com'})
+ response = self.client.post(
+ "/signup/",
+ data={
+ "username": "test",
+ "password": "123456test",
+ "password2": "123456test",
+ "display": "",
+ "email": "123@123.com",
+ },
+ )
data = json.loads(response.content)
- content = {'status': 1, 'msg': '请填写中文名', 'data': None}
+ content = {"status": 1, "msg": "请填写中文名", "data": None}
self.assertEqual(data, content)
def test_sing_up_2password(self):
"""
两次输入密码不一致
"""
- response = self.client.post('/signup/', data={'username': 'test', 'password': '123456', 'password2': '12345'})
+ response = self.client.post(
+ "/signup/",
+ data={"username": "test", "password": "123456", "password2": "12345"},
+ )
data = json.loads(response.content)
- content = {'status': 1, 'msg': '两次输入密码不一致', 'data': None}
+ content = {"status": 1, "msg": "两次输入密码不一致", "data": None}
self.assertEqual(data, content)
def test_sing_up_duplicate_uesrname(self):
"""
用户名已存在
"""
- User.objects.create(username='test', password='123456')
- response = self.client.post('/signup/',
- data={'username': 'test', 'password': '123456', 'password2': '123456'})
+ User.objects.create(username="test", password="123456")
+ response = self.client.post(
+ "/signup/",
+ data={"username": "test", "password": "123456", "password2": "123456"},
+ )
data = json.loads(response.content)
- content = {'status': 1, 'msg': '用户名已存在', 'data': None}
+ content = {"status": 1, "msg": "用户名已存在", "data": None}
self.assertEqual(data, content)
def test_sing_up_invalid(self):
"""
密码无效
"""
- self.client.post('/signup/',
- data={'username': 'test', 'password': '123456',
- 'password2': '123456test', 'display': 'test', 'email': '123@123.com'})
+ self.client.post(
+ "/signup/",
+ data={
+ "username": "test",
+ "password": "123456",
+ "password2": "123456test",
+ "display": "test",
+ "email": "123@123.com",
+ },
+ )
with self.assertRaises(User.DoesNotExist):
- User.objects.get(username='test')
+ User.objects.get(username="test")
- @patch('common.auth.init_user')
+ @patch("common.auth.init_user")
def test_sing_up_valid(self, mock_init):
"""
正常注册
"""
- self.client.post('/signup/',
- data={'username': 'test', 'password': '123456test',
- 'password2': '123456test', 'display': 'test', 'email': '123@123.com'})
- user = User.objects.get(username='test')
+ self.client.post(
+ "/signup/",
+ data={
+ "username": "test",
+ "password": "123456test",
+ "password2": "123456test",
+ "display": "test",
+ "email": "123@123.com",
+ },
+ )
+ user = User.objects.get(username="test")
self.assertTrue(user)
# 注册后登录
- r = self.client.post('/authenticate/', data={'username': 'test', 'password': '123456test'}, follow=False)
+ r = self.client.post(
+ "/authenticate/",
+ data={"username": "test", "password": "123456test"},
+ follow=False,
+ )
r_json = r.json()
- self.assertEqual(0, r_json['status'])
+ self.assertEqual(0, r_json["status"])
# 只允许初始化用户一次
mock_init.assert_called_once()
class TestUser(TestCase):
def setUp(self):
- self.u1 = User(username='test_user', display='中文显示', is_active=True)
- self.u1.set_password('test_password')
+ self.u1 = User(username="test_user", display="中文显示", is_active=True)
+ self.u1.set_password("test_password")
self.u1.save()
def tearDown(self):
self.u1.delete()
- @patch('common.auth.init_user')
+ @patch("common.auth.init_user")
def testLogin(self, mock_init):
"""login 页面测试"""
- r = self.client.get('/login/')
+ r = self.client.get("/login/")
self.assertEqual(r.status_code, 200)
- self.assertTemplateUsed(r, 'login.html')
- r = self.client.post('/authenticate/', data={'username': 'test_user', 'password': 'test_password'})
+ self.assertTemplateUsed(r, "login.html")
+ r = self.client.post(
+ "/authenticate/",
+ data={"username": "test_user", "password": "test_password"},
+ )
r_json = r.json()
- self.assertEqual(0, r_json['status'])
+ self.assertEqual(0, r_json["status"])
# 登录后直接跳首页
- r = self.client.get('/login/', follow=True)
- self.assertRedirects(r, '/sqlworkflow/')
+ r = self.client.get("/login/", follow=True)
+ self.assertRedirects(r, "/sqlworkflow/")
# init 只调用一次
mock_init.assert_called_once()
@@ -401,18 +456,22 @@ class TestQueryPrivilegesCheck(TestCase):
"""测试权限校验"""
def setUp(self):
- self.superuser = User.objects.create(username='super', is_superuser=True)
- self.user_can_query_all = User.objects.create(username='normaluser')
- query_all_instance_perm = Permission.objects.get(codename='query_all_instances')
+ self.superuser = User.objects.create(username="super", is_superuser=True)
+ self.user_can_query_all = User.objects.create(username="normaluser")
+ query_all_instance_perm = Permission.objects.get(codename="query_all_instances")
self.user_can_query_all.user_permissions.add(query_all_instance_perm)
- self.user = User.objects.create(username='user')
+ self.user = User.objects.create(username="user")
# 使用 travis.ci 时实例和测试service保持一致
- self.slave = Instance.objects.create(instance_name='test_instance', type='slave', db_type='mysql',
- host=settings.DATABASES['default']['HOST'],
- port=settings.DATABASES['default']['PORT'],
- user=settings.DATABASES['default']['USER'],
- password=settings.DATABASES['default']['PASSWORD'])
- self.db_name = settings.DATABASES['default']['TEST']['NAME']
+ self.slave = Instance.objects.create(
+ instance_name="test_instance",
+ type="slave",
+ db_type="mysql",
+ host=settings.DATABASES["default"]["HOST"],
+ port=settings.DATABASES["default"]["PORT"],
+ user=settings.DATABASES["default"]["USER"],
+ password=settings.DATABASES["default"]["PASSWORD"],
+ )
+ self.db_name = settings.DATABASES["default"]["TEST"]["NAME"]
self.sys_config = SysConfig()
self.client = Client()
@@ -428,9 +487,11 @@ def test_db_priv_super(self):
测试超级管理员验证数据库权限
:return:
"""
- self.sys_config.set('admin_query_limit', '50')
+ self.sys_config.set("admin_query_limit", "50")
self.sys_config.get_all_config()
- r = sql.query_privileges._db_priv(user=self.superuser, instance=self.slave, db_name=self.db_name)
+ r = sql.query_privileges._db_priv(
+ user=self.superuser, instance=self.slave, db_name=self.db_name
+ )
self.assertEqual(r, 50)
def test_db_priv_user_priv_not_exist(self):
@@ -438,7 +499,9 @@ def test_db_priv_user_priv_not_exist(self):
测试普通用户验证数据库权限,用户无权限
:return:
"""
- r = sql.query_privileges._db_priv(user=self.user, instance=self.slave, db_name=self.db_name)
+ r = sql.query_privileges._db_priv(
+ user=self.user, instance=self.slave, db_name=self.db_name
+ )
self.assertFalse(r)
def test_db_priv_user_priv_exist(self):
@@ -446,13 +509,17 @@ def test_db_priv_user_priv_exist(self):
测试普通用户验证数据库权限,用户有权限
:return:
"""
- QueryPrivileges.objects.create(user_name=self.user.username,
- instance=self.slave,
- db_name=self.db_name,
- valid_date=date.today() + timedelta(days=1),
- limit_num=10,
- priv_type=1)
- r = sql.query_privileges._db_priv(user=self.user, instance=self.slave, db_name=self.db_name)
+ QueryPrivileges.objects.create(
+ user_name=self.user.username,
+ instance=self.slave,
+ db_name=self.db_name,
+ valid_date=date.today() + timedelta(days=1),
+ limit_num=10,
+ priv_type=1,
+ )
+ r = sql.query_privileges._db_priv(
+ user=self.user, instance=self.slave, db_name=self.db_name
+ )
self.assertTrue(r)
def test_tb_priv_super(self):
@@ -460,10 +527,14 @@ def test_tb_priv_super(self):
测试超级管理员验证表权限
:return:
"""
- self.sys_config.set('admin_query_limit', '50')
+ self.sys_config.set("admin_query_limit", "50")
self.sys_config.get_all_config()
- r = sql.query_privileges._tb_priv(user=self.superuser, instance=self.slave, db_name=self.db_name,
- tb_name='table_name')
+ r = sql.query_privileges._tb_priv(
+ user=self.superuser,
+ instance=self.slave,
+ db_name=self.db_name,
+ tb_name="table_name",
+ )
self.assertEqual(r, 50)
def test_tb_priv_user_priv_not_exist(self):
@@ -471,8 +542,12 @@ def test_tb_priv_user_priv_not_exist(self):
测试普通用户验证表权限,用户无权限
:return:
"""
- r = sql.query_privileges._tb_priv(user=self.user, instance=self.slave, db_name=self.db_name,
- tb_name='table_name')
+ r = sql.query_privileges._tb_priv(
+ user=self.user,
+ instance=self.slave,
+ db_name=self.db_name,
+ tb_name="table_name",
+ )
self.assertFalse(r)
def test_tb_priv_user_priv_exist(self):
@@ -480,29 +555,37 @@ def test_tb_priv_user_priv_exist(self):
测试普通用户验证表权限,用户有权限
:return:
"""
- QueryPrivileges.objects.create(user_name=self.user.username,
- instance=self.slave,
- db_name=self.db_name,
- table_name='table_name',
- valid_date=date.today() + timedelta(days=1),
- limit_num=10,
- priv_type=2)
- r = sql.query_privileges._tb_priv(user=self.user, instance=self.slave, db_name=self.db_name,
- tb_name='table_name')
+ QueryPrivileges.objects.create(
+ user_name=self.user.username,
+ instance=self.slave,
+ db_name=self.db_name,
+ table_name="table_name",
+ valid_date=date.today() + timedelta(days=1),
+ limit_num=10,
+ priv_type=2,
+ )
+ r = sql.query_privileges._tb_priv(
+ user=self.user,
+ instance=self.slave,
+ db_name=self.db_name,
+ tb_name="table_name",
+ )
self.assertTrue(r)
- @patch('sql.query_privileges._db_priv')
+ @patch("sql.query_privileges._db_priv")
def test_priv_limit_from_db(self, __db_priv):
"""
测试用户获取查询数量限制,通过库名获取
:return:
"""
__db_priv.return_value = 10
- r = sql.query_privileges._priv_limit(user=self.user, instance=self.slave, db_name=self.db_name)
+ r = sql.query_privileges._priv_limit(
+ user=self.user, instance=self.slave, db_name=self.db_name
+ )
self.assertEqual(r, 10)
- @patch('sql.query_privileges._tb_priv')
- @patch('sql.query_privileges._db_priv')
+ @patch("sql.query_privileges._tb_priv")
+ @patch("sql.query_privileges._db_priv")
def test_priv_limit_from_tb(self, __db_priv, __tb_priv):
"""
测试用户获取查询数量限制,通过表名获取
@@ -510,190 +593,286 @@ def test_priv_limit_from_tb(self, __db_priv, __tb_priv):
"""
__db_priv.return_value = 10
__tb_priv.return_value = 1
- r = sql.query_privileges._priv_limit(user=self.user, instance=self.slave, db_name=self.db_name, tb_name='test')
+ r = sql.query_privileges._priv_limit(
+ user=self.user, instance=self.slave, db_name=self.db_name, tb_name="test"
+ )
self.assertEqual(r, 1)
- @patch('sql.engines.goinception.GoInceptionEngine.query_print')
+ @patch("sql.engines.goinception.GoInceptionEngine.query_print")
def test_table_ref(self, _query_print):
"""
测试通过goInception获取查询语句的table_ref
:return:
"""
- _query_print.return_value = {'id': 2, 'statement': 'select * from sql_users limit 100', 'errlevel': 0,
- 'query_tree': '{"text":"select * from sql_users limit 100","resultFields":null,"SQLCache":true,"CalcFoundRows":false,"StraightJoin":false,"Priority":0,"Distinct":false,"From":{"text":"","TableRefs":{"text":"","resultFields":null,"Left":{"text":"","Source":{"text":"","resultFields":null,"Schema":{"O":"","L":""},"Name":{"O":"sql_users","L":"sql_users"},"DBInfo":null,"TableInfo":null,"IndexHints":null},"AsName":{"O":"","L":""}},"Right":null,"Tp":0,"On":null,"Using":null,"NaturalJoin":false,"StraightJoin":false}},"Where":null,"Fields":{"text":"","Fields":[{"text":"","Offset":33,"WildCard":{"text":"","Table":{"O":"","L":""},"Schema":{"O":"","L":""}},"Expr":null,"AsName":{"O":"","L":""},"Auxiliary":false}]},"GroupBy":null,"Having":null,"OrderBy":null,"Limit":{"text":"","Count":{"text":"","k":2,"collation":0,"decimal":0,"length":0,"i":100,"b":null,"x":null,"Type":{"Tp":8,"Flag":160,"Flen":3,"Decimal":0,"Charset":"binary","Collate":"binary","Elems":null},"flag":0,"projectionOffset":-1},"Offset":null},"LockTp":0,"TableHints":null,"IsAfterUnionDistinct":false,"IsInBraces":false}',
- 'errmsg': None}
- r = sql.query_privileges._table_ref('select * from sql_users limit 100;', self.slave, self.db_name)
- self.assertListEqual(r, [{'schema': 'test_archery', 'name': 'sql_users'}])
+ _query_print.return_value = {
+ "id": 2,
+ "statement": "select * from sql_users limit 100",
+ "errlevel": 0,
+ "query_tree": '{"text":"select * from sql_users limit 100","resultFields":null,"SQLCache":true,"CalcFoundRows":false,"StraightJoin":false,"Priority":0,"Distinct":false,"From":{"text":"","TableRefs":{"text":"","resultFields":null,"Left":{"text":"","Source":{"text":"","resultFields":null,"Schema":{"O":"","L":""},"Name":{"O":"sql_users","L":"sql_users"},"DBInfo":null,"TableInfo":null,"IndexHints":null},"AsName":{"O":"","L":""}},"Right":null,"Tp":0,"On":null,"Using":null,"NaturalJoin":false,"StraightJoin":false}},"Where":null,"Fields":{"text":"","Fields":[{"text":"","Offset":33,"WildCard":{"text":"","Table":{"O":"","L":""},"Schema":{"O":"","L":""}},"Expr":null,"AsName":{"O":"","L":""},"Auxiliary":false}]},"GroupBy":null,"Having":null,"OrderBy":null,"Limit":{"text":"","Count":{"text":"","k":2,"collation":0,"decimal":0,"length":0,"i":100,"b":null,"x":null,"Type":{"Tp":8,"Flag":160,"Flen":3,"Decimal":0,"Charset":"binary","Collate":"binary","Elems":null},"flag":0,"projectionOffset":-1},"Offset":null},"LockTp":0,"TableHints":null,"IsAfterUnionDistinct":false,"IsInBraces":false}',
+ "errmsg": None,
+ }
+ r = sql.query_privileges._table_ref(
+ "select * from sql_users limit 100;", self.slave, self.db_name
+ )
+ self.assertListEqual(r, [{"schema": "test_archery", "name": "sql_users"}])
- @patch('sql.engines.goinception.GoInceptionEngine.query_print')
+ @patch("sql.engines.goinception.GoInceptionEngine.query_print")
def test_table_ref_wrong(self, _query_print):
"""
测试通过goInception获取查询语句的table_ref
:return:
"""
- _query_print.side_effect = RuntimeError('语法错误')
+ _query_print.side_effect = RuntimeError("语法错误")
with self.assertRaises(RuntimeError):
- sql.query_privileges._table_ref('select * from archery.sql_users;', self.slave, self.db_name)
+ sql.query_privileges._table_ref(
+ "select * from archery.sql_users;", self.slave, self.db_name
+ )
def test_query_priv_check_super(self):
"""
测试用户权限校验,超级管理员不做校验,直接返回系统配置的limit
:return:
"""
- r = sql.query_privileges.query_priv_check(user=self.superuser,
- instance=self.slave, db_name=self.db_name,
- sql_content="select * from archery.sql_users;",
- limit_num=100)
- self.assertDictEqual(r, {'status': 0, 'msg': 'ok', 'data': {'priv_check': True, 'limit_num': 100}})
- r = sql.query_privileges.query_priv_check(user=self.user_can_query_all,
- instance=self.slave, db_name=self.db_name,
- sql_content="select * from archery.sql_users;",
- limit_num=100)
- self.assertDictEqual(r, {'status': 0, 'msg': 'ok', 'data': {'priv_check': True, 'limit_num': 100}})
+ r = sql.query_privileges.query_priv_check(
+ user=self.superuser,
+ instance=self.slave,
+ db_name=self.db_name,
+ sql_content="select * from archery.sql_users;",
+ limit_num=100,
+ )
+ self.assertDictEqual(
+ r,
+ {"status": 0, "msg": "ok", "data": {"priv_check": True, "limit_num": 100}},
+ )
+ r = sql.query_privileges.query_priv_check(
+ user=self.user_can_query_all,
+ instance=self.slave,
+ db_name=self.db_name,
+ sql_content="select * from archery.sql_users;",
+ limit_num=100,
+ )
+ self.assertDictEqual(
+ r,
+ {"status": 0, "msg": "ok", "data": {"priv_check": True, "limit_num": 100}},
+ )
def test_query_priv_check_explain_or_show_create(self):
"""测试用户权限校验,explain和show create不做校验"""
- r = sql.query_privileges.query_priv_check(user=self.user,
- instance=self.slave, db_name=self.db_name,
- sql_content="show create table archery.sql_users;",
- limit_num=100)
+ r = sql.query_privileges.query_priv_check(
+ user=self.user,
+ instance=self.slave,
+ db_name=self.db_name,
+ sql_content="show create table archery.sql_users;",
+ limit_num=100,
+ )
self.assertTrue(r)
- @patch('sql.query_privileges._table_ref', return_value=[{'schema': 'archery', 'name': 'sql_users'}])
- @patch('sql.query_privileges._tb_priv', return_value=False)
- @patch('sql.query_privileges._db_priv', return_value=False)
+ @patch(
+ "sql.query_privileges._table_ref",
+ return_value=[{"schema": "archery", "name": "sql_users"}],
+ )
+ @patch("sql.query_privileges._tb_priv", return_value=False)
+ @patch("sql.query_privileges._db_priv", return_value=False)
def test_query_priv_check_no_priv(self, __db_priv, __tb_priv, __table_ref):
"""
测试用户权限校验,mysql实例、普通用户 无库表权限,inception语法树正常打印
:return:
"""
- r = sql.query_privileges.query_priv_check(user=self.user,
- instance=self.slave, db_name=self.db_name,
- sql_content="select * from archery.sql_users;",
- limit_num=100)
- self.assertDictEqual(r, {'status': 2, 'msg': '你无archery.sql_users表的查询权限!请先到查询权限管理进行申请',
- 'data': {'priv_check': True, 'limit_num': 0}})
+ r = sql.query_privileges.query_priv_check(
+ user=self.user,
+ instance=self.slave,
+ db_name=self.db_name,
+ sql_content="select * from archery.sql_users;",
+ limit_num=100,
+ )
+ self.assertDictEqual(
+ r,
+ {
+ "status": 2,
+ "msg": "你无archery.sql_users表的查询权限!请先到查询权限管理进行申请",
+ "data": {"priv_check": True, "limit_num": 0},
+ },
+ )
- @patch('sql.query_privileges._table_ref', return_value=[{'schema': 'archery', 'name': 'sql_users'}])
- @patch('sql.query_privileges._tb_priv', return_value=False)
- @patch('sql.query_privileges._db_priv', return_value=1000)
+ @patch(
+ "sql.query_privileges._table_ref",
+ return_value=[{"schema": "archery", "name": "sql_users"}],
+ )
+ @patch("sql.query_privileges._tb_priv", return_value=False)
+ @patch("sql.query_privileges._db_priv", return_value=1000)
def test_query_priv_check_db_priv_exist(self, __db_priv, __tb_priv, __table_ref):
"""
测试用户权限校验,mysql实例、普通用户 有库权限,inception语法树正常打印
:return:
"""
- r = sql.query_privileges.query_priv_check(user=self.user,
- instance=self.slave, db_name=self.db_name,
- sql_content="select * from archery.sql_users;",
- limit_num=100)
- self.assertDictEqual(r, {'data': {'limit_num': 100, 'priv_check': True}, 'msg': 'ok', 'status': 0})
+ r = sql.query_privileges.query_priv_check(
+ user=self.user,
+ instance=self.slave,
+ db_name=self.db_name,
+ sql_content="select * from archery.sql_users;",
+ limit_num=100,
+ )
+ self.assertDictEqual(
+ r,
+ {"data": {"limit_num": 100, "priv_check": True}, "msg": "ok", "status": 0},
+ )
- @patch('sql.query_privileges._table_ref', return_value=[{'schema': 'archery', 'name': 'sql_users'}])
- @patch('sql.query_privileges._tb_priv', return_value=10)
- @patch('sql.query_privileges._db_priv', return_value=False)
+ @patch(
+ "sql.query_privileges._table_ref",
+ return_value=[{"schema": "archery", "name": "sql_users"}],
+ )
+ @patch("sql.query_privileges._tb_priv", return_value=10)
+ @patch("sql.query_privileges._db_priv", return_value=False)
def test_query_priv_check_tb_priv_exist(self, __db_priv, __tb_priv, __table_ref):
"""
测试用户权限校验,mysql实例、普通用户 ,有表权限,inception语法树正常打印
:return:
"""
- r = sql.query_privileges.query_priv_check(user=self.user,
- instance=self.slave, db_name=self.db_name,
- sql_content="select * from archery.sql_users;",
- limit_num=100)
- self.assertDictEqual(r, {'data': {'limit_num': 10, 'priv_check': True}, 'msg': 'ok', 'status': 0})
+ r = sql.query_privileges.query_priv_check(
+ user=self.user,
+ instance=self.slave,
+ db_name=self.db_name,
+ sql_content="select * from archery.sql_users;",
+ limit_num=100,
+ )
+ self.assertDictEqual(
+ r, {"data": {"limit_num": 10, "priv_check": True}, "msg": "ok", "status": 0}
+ )
- @patch('sql.query_privileges._table_ref')
- @patch('sql.query_privileges._tb_priv', return_value=False)
- @patch('sql.query_privileges._db_priv', return_value=False)
- def test_query_priv_check_table_ref_Exception_and_no_db_priv(self, __db_priv, __tb_priv, __table_ref):
+ @patch("sql.query_privileges._table_ref")
+ @patch("sql.query_privileges._tb_priv", return_value=False)
+ @patch("sql.query_privileges._db_priv", return_value=False)
+ def test_query_priv_check_table_ref_Exception_and_no_db_priv(
+ self, __db_priv, __tb_priv, __table_ref
+ ):
"""
测试用户权限校验,mysql实例、普通用户 ,inception语法树抛出异常
:return:
"""
- __table_ref.side_effect = RuntimeError('语法错误')
+ __table_ref.side_effect = RuntimeError("语法错误")
self.sys_config.get_all_config()
- r = sql.query_privileges.query_priv_check(user=self.user,
- instance=self.slave, db_name=self.db_name,
- sql_content="select * from archery.sql_users;",
- limit_num=100)
- self.assertDictEqual(r, {'status': 1,
- 'msg': "无法校验查询语句权限,请联系管理员,错误信息:语法错误",
- 'data': {'priv_check': True, 'limit_num': 0}})
-
- @patch('sql.query_privileges._db_priv', return_value=1000)
+ r = sql.query_privileges.query_priv_check(
+ user=self.user,
+ instance=self.slave,
+ db_name=self.db_name,
+ sql_content="select * from archery.sql_users;",
+ limit_num=100,
+ )
+ self.assertDictEqual(
+ r,
+ {
+ "status": 1,
+ "msg": "无法校验查询语句权限,请联系管理员,错误信息:语法错误",
+ "data": {"priv_check": True, "limit_num": 0},
+ },
+ )
+
+ @patch("sql.query_privileges._db_priv", return_value=1000)
def test_query_priv_check_not_mysql_db_priv_exist(self, __db_priv):
"""
测试用户权限校验,非mysql实例、普通用户 有库权限
:return:
"""
- mssql_instance = Instance(instance_name='mssql', type='slave', db_type='mssql',
- host='some_host', port=3306, user='some_user', password='some_str')
- r = sql.query_privileges.query_priv_check(user=self.user,
- instance=mssql_instance, db_name=self.db_name,
- sql_content="select * from archery.sql_users;",
- limit_num=100)
- self.assertDictEqual(r, {'data': {'limit_num': 100, 'priv_check': True}, 'msg': 'ok', 'status': 0})
+ mssql_instance = Instance(
+ instance_name="mssql",
+ type="slave",
+ db_type="mssql",
+ host="some_host",
+ port=3306,
+ user="some_user",
+ password="some_str",
+ )
+ r = sql.query_privileges.query_priv_check(
+ user=self.user,
+ instance=mssql_instance,
+ db_name=self.db_name,
+ sql_content="select * from archery.sql_users;",
+ limit_num=100,
+ )
+ self.assertDictEqual(
+ r,
+ {"data": {"limit_num": 100, "priv_check": True}, "msg": "ok", "status": 0},
+ )
- @patch('sql.query_privileges._db_priv', return_value=False)
+ @patch("sql.query_privileges._db_priv", return_value=False)
def test_query_priv_check_not_mysql_db_priv_not_exist(self, __db_priv):
"""
测试用户权限校验,非mysql实例、普通用户 无库权限
:return:
"""
- mssql_instance = Instance(instance_name='mssql', type='slave', db_type='oracle',
- host='some_host', port=3306, user='some_user', password='some_str')
- r = sql.query_privileges.query_priv_check(user=self.user,
- instance=mssql_instance, db_name=self.db_name,
- sql_content="select * from archery.sql_users;",
- limit_num=100)
- self.assertDictEqual(r, {'data': {'limit_num': 0, 'priv_check': True},
- 'msg': '你无archery数据库的查询权限!请先到查询权限管理进行申请',
- 'status': 2})
+ mssql_instance = Instance(
+ instance_name="mssql",
+ type="slave",
+ db_type="oracle",
+ host="some_host",
+ port=3306,
+ user="some_user",
+ password="some_str",
+ )
+ r = sql.query_privileges.query_priv_check(
+ user=self.user,
+ instance=mssql_instance,
+ db_name=self.db_name,
+ sql_content="select * from archery.sql_users;",
+ limit_num=100,
+ )
+ self.assertDictEqual(
+ r,
+ {
+ "data": {"limit_num": 0, "priv_check": True},
+ "msg": "你无archery数据库的查询权限!请先到查询权限管理进行申请",
+ "status": 2,
+ },
+ )
class TestQueryPrivilegesApply(TestCase):
"""测试权限列表、权限管理"""
def setUp(self):
- self.superuser = User.objects.create(username='super', is_superuser=True)
- self.user = User.objects.create(username='user')
+ self.superuser = User.objects.create(username="super", is_superuser=True)
+ self.user = User.objects.create(username="user")
# 使用 travis.ci 时实例和测试service保持一致
- self.slave = Instance.objects.create(instance_name='test_instance', type='slave', db_type='mysql',
- host=settings.DATABASES['default']['HOST'],
- port=settings.DATABASES['default']['PORT'],
- user=settings.DATABASES['default']['USER'],
- password=settings.DATABASES['default']['PASSWORD'])
- self.db_name = settings.DATABASES['default']['TEST']['NAME']
+ self.slave = Instance.objects.create(
+ instance_name="test_instance",
+ type="slave",
+ db_type="mysql",
+ host=settings.DATABASES["default"]["HOST"],
+ port=settings.DATABASES["default"]["PORT"],
+ user=settings.DATABASES["default"]["USER"],
+ password=settings.DATABASES["default"]["PASSWORD"],
+ )
+ self.db_name = settings.DATABASES["default"]["TEST"]["NAME"]
self.sys_config = SysConfig()
self.client = Client()
tomorrow = datetime.today() + timedelta(days=1)
- self.group = ResourceGroup.objects.create(group_id=1, group_name='group_name')
+ self.group = ResourceGroup.objects.create(group_id=1, group_name="group_name")
self.query_apply_1 = QueryPrivilegesApply.objects.create(
group_id=self.group.group_id,
group_name=self.group.group_name,
- title='some_title1',
- user_name='some_user',
+ title="some_title1",
+ user_name="some_user",
instance=self.slave,
- db_list='some_db,some_db2',
+ db_list="some_db,some_db2",
limit_num=100,
valid_date=tomorrow,
priv_type=1,
status=0,
- audit_auth_groups='some_audit_group'
+ audit_auth_groups="some_audit_group",
)
self.query_apply_2 = QueryPrivilegesApply.objects.create(
group_id=2,
- group_name='some_group2',
- title='some_title2',
- user_name='some_user',
+ group_name="some_group2",
+ title="some_title2",
+ user_name="some_user",
instance=self.slave,
- db_list='some_db',
- table_list='some_table,some_tb2',
+ db_list="some_db",
+ table_list="some_table,some_tb2",
limit_num=100,
valid_date=tomorrow,
priv_type=2,
status=0,
- audit_auth_groups='some_audit_group'
+ audit_auth_groups="some_audit_group",
)
def tearDown(self):
@@ -708,189 +887,263 @@ def tearDown(self):
def test_query_audit_call_back(self):
"""测试权限申请工单回调"""
# 工单状态改为审核失败, 验证工单状态
- sql.query_privileges._query_apply_audit_call_back(self.query_apply_1.apply_id, 2)
+ sql.query_privileges._query_apply_audit_call_back(
+ self.query_apply_1.apply_id, 2
+ )
self.query_apply_1.refresh_from_db()
self.assertEqual(self.query_apply_1.status, 2)
- for db in self.query_apply_1.db_list.split(','):
- self.assertEqual(len(QueryPrivileges.objects.filter(
- user_name=self.query_apply_1.user_name,
- db_name=db,
- limit_num=100)), 0)
+ for db in self.query_apply_1.db_list.split(","):
+ self.assertEqual(
+ len(
+ QueryPrivileges.objects.filter(
+ user_name=self.query_apply_1.user_name,
+ db_name=db,
+ limit_num=100,
+ )
+ ),
+ 0,
+ )
# 工单改为审核成功, 验证工单状态和权限状态
- sql.query_privileges._query_apply_audit_call_back(self.query_apply_1.apply_id, 1)
+ sql.query_privileges._query_apply_audit_call_back(
+ self.query_apply_1.apply_id, 1
+ )
self.query_apply_1.refresh_from_db()
self.assertEqual(self.query_apply_1.status, 1)
- for db in self.query_apply_1.db_list.split(','):
- self.assertEqual(len(QueryPrivileges.objects.filter(
- user_name=self.query_apply_1.user_name,
- db_name=db,
- limit_num=100)), 1)
+ for db in self.query_apply_1.db_list.split(","):
+ self.assertEqual(
+ len(
+ QueryPrivileges.objects.filter(
+ user_name=self.query_apply_1.user_name,
+ db_name=db,
+ limit_num=100,
+ )
+ ),
+ 1,
+ )
# 表权限申请测试, 只测试审核成功
- sql.query_privileges._query_apply_audit_call_back(self.query_apply_2.apply_id, 1)
+ sql.query_privileges._query_apply_audit_call_back(
+ self.query_apply_2.apply_id, 1
+ )
self.query_apply_2.refresh_from_db()
self.assertEqual(self.query_apply_2.status, 1)
- for tb in self.query_apply_2.table_list.split(','):
- self.assertEqual(len(QueryPrivileges.objects.filter(
- user_name=self.query_apply_2.user_name,
- db_name=self.query_apply_2.db_list,
- table_name=tb,
- limit_num=self.query_apply_2.limit_num)), 1)
+ for tb in self.query_apply_2.table_list.split(","):
+ self.assertEqual(
+ len(
+ QueryPrivileges.objects.filter(
+ user_name=self.query_apply_2.user_name,
+ db_name=self.query_apply_2.db_list,
+ table_name=tb,
+ limit_num=self.query_apply_2.limit_num,
+ )
+ ),
+ 1,
+ )
def test_query_priv_apply_list_super_with_search(self):
"""
测试权限申请列表,管理员查看所有用户,并且搜索
"""
- data = {
- "limit": 14,
- "offset": 0,
- "search": 'some_title1'
- }
+ data = {"limit": 14, "offset": 0, "search": "some_title1"}
self.client.force_login(self.superuser)
- r = self.client.post(path='/query/applylist/', data=data)
- self.assertEqual(json.loads(r.content)['total'], 1)
- keys = list(json.loads(r.content)['rows'][0].keys())
- self.assertListEqual(keys,
- ['apply_id', 'title', 'instance__instance_name', 'db_list', 'priv_type', 'table_list',
- 'limit_num', 'valid_date', 'user_display', 'status', 'create_time', 'group_name'])
+ r = self.client.post(path="/query/applylist/", data=data)
+ self.assertEqual(json.loads(r.content)["total"], 1)
+ keys = list(json.loads(r.content)["rows"][0].keys())
+ self.assertListEqual(
+ keys,
+ [
+ "apply_id",
+ "title",
+ "instance__instance_name",
+ "db_list",
+ "priv_type",
+ "table_list",
+ "limit_num",
+ "valid_date",
+ "user_display",
+ "status",
+ "create_time",
+ "group_name",
+ ],
+ )
def test_query_priv_apply_list_with_query_review_perm(self):
"""
测试权限申请列表,普通用户,拥有sql.query_review权限,在组内
"""
- data = {
- "limit": 14,
- "offset": 0,
- "search": ''
- }
+ data = {"limit": 14, "offset": 0, "search": ""}
- menu_queryapplylist = Permission.objects.get(codename='menu_queryapplylist')
+ menu_queryapplylist = Permission.objects.get(codename="menu_queryapplylist")
self.user.user_permissions.add(menu_queryapplylist)
- query_review = Permission.objects.get(codename='query_review')
+ query_review = Permission.objects.get(codename="query_review")
self.user.user_permissions.add(query_review)
self.user.resource_group.add(self.group)
self.client.force_login(self.user)
- r = self.client.post(path='/query/applylist/', data=data)
- self.assertEqual(json.loads(r.content)['total'], 1)
- keys = list(json.loads(r.content)['rows'][0].keys())
- self.assertListEqual(keys,
- ['apply_id', 'title', 'instance__instance_name', 'db_list', 'priv_type', 'table_list',
- 'limit_num', 'valid_date', 'user_display', 'status', 'create_time', 'group_name'])
+ r = self.client.post(path="/query/applylist/", data=data)
+ self.assertEqual(json.loads(r.content)["total"], 1)
+ keys = list(json.loads(r.content)["rows"][0].keys())
+ self.assertListEqual(
+ keys,
+ [
+ "apply_id",
+ "title",
+ "instance__instance_name",
+ "db_list",
+ "priv_type",
+ "table_list",
+ "limit_num",
+ "valid_date",
+ "user_display",
+ "status",
+ "create_time",
+ "group_name",
+ ],
+ )
def test_query_priv_apply_list_no_query_review_perm(self):
"""
测试权限申请列表,普通用户,无sql.query_review权限,在组内
"""
- data = {
- "limit": 14,
- "offset": 0,
- "search": ''
- }
+ data = {"limit": 14, "offset": 0, "search": ""}
- menu_queryapplylist = Permission.objects.get(codename='menu_queryapplylist')
+ menu_queryapplylist = Permission.objects.get(codename="menu_queryapplylist")
self.user.user_permissions.add(menu_queryapplylist)
self.user.resource_group.add(self.group)
self.client.force_login(self.user)
- r = self.client.post(path='/query/applylist/', data=data)
+ r = self.client.post(path="/query/applylist/", data=data)
self.assertEqual(json.loads(r.content), {"total": 0, "rows": []})
def test_user_query_priv_with_search(self):
"""
测试权限申请列表,管理员查看所有用户,并且搜索
"""
- data = {
- "limit": 14,
- "offset": 0,
- "search": 'user'
- }
- QueryPrivileges.objects.create(user_name=self.user.username,
- user_display='user2',
- instance=self.slave,
- db_name=self.db_name,
- table_name='table_name',
- valid_date=date.today() + timedelta(days=1),
- limit_num=10,
- priv_type=2)
+ data = {"limit": 14, "offset": 0, "search": "user"}
+ QueryPrivileges.objects.create(
+ user_name=self.user.username,
+ user_display="user2",
+ instance=self.slave,
+ db_name=self.db_name,
+ table_name="table_name",
+ valid_date=date.today() + timedelta(days=1),
+ limit_num=10,
+ priv_type=2,
+ )
self.client.force_login(self.superuser)
- r = self.client.post(path='/query/userprivileges/', data=data)
- self.assertEqual(json.loads(r.content)['total'], 1)
- keys = list(json.loads(r.content)['rows'][0].keys())
- self.assertListEqual(keys,
- ['privilege_id', 'user_display', 'instance__instance_name', 'db_name', 'priv_type',
- 'table_name', 'limit_num', 'valid_date'])
+ r = self.client.post(path="/query/userprivileges/", data=data)
+ self.assertEqual(json.loads(r.content)["total"], 1)
+ keys = list(json.loads(r.content)["rows"][0].keys())
+ self.assertListEqual(
+ keys,
+ [
+ "privilege_id",
+ "user_display",
+ "instance__instance_name",
+ "db_name",
+ "priv_type",
+ "table_name",
+ "limit_num",
+ "valid_date",
+ ],
+ )
def test_user_query_priv_with_query_mgtpriv(self):
"""
测试权限申请列表,普通用户,拥有sql.query_mgtpriv权限,在组内
"""
- data = {
- "limit": 14,
- "offset": 0,
- "search": 'user'
- }
- QueryPrivileges.objects.create(user_name='some_name',
- user_display='user2',
- instance=self.slave,
- db_name=self.db_name,
- table_name='table_name',
- valid_date=date.today() + timedelta(days=1),
- limit_num=10,
- priv_type=2)
- menu_queryapplylist = Permission.objects.get(codename='menu_queryapplylist')
+ data = {"limit": 14, "offset": 0, "search": "user"}
+ QueryPrivileges.objects.create(
+ user_name="some_name",
+ user_display="user2",
+ instance=self.slave,
+ db_name=self.db_name,
+ table_name="table_name",
+ valid_date=date.today() + timedelta(days=1),
+ limit_num=10,
+ priv_type=2,
+ )
+ menu_queryapplylist = Permission.objects.get(codename="menu_queryapplylist")
self.user.user_permissions.add(menu_queryapplylist)
- query_mgtpriv = Permission.objects.get(codename='query_mgtpriv')
+ query_mgtpriv = Permission.objects.get(codename="query_mgtpriv")
self.user.user_permissions.add(query_mgtpriv)
self.user.resource_group.add(self.group)
self.client.force_login(self.user)
- r = self.client.post(path='/query/userprivileges/', data=data)
- self.assertEqual(json.loads(r.content)['total'], 1)
- keys = list(json.loads(r.content)['rows'][0].keys())
- self.assertListEqual(keys,
- ['privilege_id', 'user_display', 'instance__instance_name', 'db_name', 'priv_type',
- 'table_name', 'limit_num', 'valid_date'])
+ r = self.client.post(path="/query/userprivileges/", data=data)
+ self.assertEqual(json.loads(r.content)["total"], 1)
+ keys = list(json.loads(r.content)["rows"][0].keys())
+ self.assertListEqual(
+ keys,
+ [
+ "privilege_id",
+ "user_display",
+ "instance__instance_name",
+ "db_name",
+ "priv_type",
+ "table_name",
+ "limit_num",
+ "valid_date",
+ ],
+ )
def test_user_query_priv_no_query_mgtpriv(self):
"""
测试权限申请列表,普通用户,没有sql.query_mgtpriv权限,在组内
"""
- data = {
- "limit": 14,
- "offset": 0,
- "search": 'user'
- }
- QueryPrivileges.objects.create(user_name='some_name',
- user_display='user2',
- instance=self.slave,
- db_name=self.db_name,
- table_name='table_name',
- valid_date=date.today() + timedelta(days=1),
- limit_num=10,
- priv_type=2)
- menu_queryapplylist = Permission.objects.get(codename='menu_queryapplylist')
+ data = {"limit": 14, "offset": 0, "search": "user"}
+ QueryPrivileges.objects.create(
+ user_name="some_name",
+ user_display="user2",
+ instance=self.slave,
+ db_name=self.db_name,
+ table_name="table_name",
+ valid_date=date.today() + timedelta(days=1),
+ limit_num=10,
+ priv_type=2,
+ )
+ menu_queryapplylist = Permission.objects.get(codename="menu_queryapplylist")
self.user.user_permissions.add(menu_queryapplylist)
self.user.resource_group.add(self.group)
self.client.force_login(self.user)
- r = self.client.post(path='/query/userprivileges/', data=data)
+ r = self.client.post(path="/query/userprivileges/", data=data)
self.assertEqual(json.loads(r.content), {"total": 0, "rows": []})
class TestQuery(TransactionTestCase):
def setUp(self):
- self.slave1 = Instance(instance_name='test_slave_instance', type='slave', db_type='mysql',
- host='testhost', port=3306, user='mysql_user', password='mysql_password')
- self.slave2 = Instance(instance_name='test_instance_non_mysql', type='slave', db_type='mssql',
- host='some_host2', port=3306, user='some_user', password='some_str')
+ self.slave1 = Instance(
+ instance_name="test_slave_instance",
+ type="slave",
+ db_type="mysql",
+ host="testhost",
+ port=3306,
+ user="mysql_user",
+ password="mysql_password",
+ )
+ self.slave2 = Instance(
+ instance_name="test_instance_non_mysql",
+ type="slave",
+ db_type="mssql",
+ host="some_host2",
+ port=3306,
+ user="some_user",
+ password="some_str",
+ )
self.slave1.save()
self.slave2.save()
- self.superuser1 = User.objects.create(username='super1', is_superuser=True)
- self.u1 = User.objects.create(username='test_user', display='中文显示', is_active=True)
- self.u2 = User.objects.create(username='test_user2', display='中文显示', is_active=True)
- self.query_log = QueryLog.objects.create(instance_name=self.slave1.instance_name,
- db_name='some_db',
- sqllog='select 1;',
- effect_row=10,
- cost_time=1,
- username=self.superuser1.username)
- sql_query_perm = Permission.objects.get(codename='query_submit')
+ self.superuser1 = User.objects.create(username="super1", is_superuser=True)
+ self.u1 = User.objects.create(
+ username="test_user", display="中文显示", is_active=True
+ )
+ self.u2 = User.objects.create(
+ username="test_user2", display="中文显示", is_active=True
+ )
+ self.query_log = QueryLog.objects.create(
+ instance_name=self.slave1.instance_name,
+ db_name="some_db",
+ sqllog="select 1;",
+ effect_row=10,
+ cost_time=1,
+ username=self.superuser1.username,
+ )
+ sql_query_perm = Permission.objects.get(codename="query_submit")
self.u2.user_permissions.add(sql_query_perm)
def tearDown(self):
@@ -902,90 +1155,139 @@ def tearDown(self):
self.slave1.delete()
self.slave2.delete()
archer_config = SysConfig()
- archer_config.set('disable_star', False)
+ archer_config.set("disable_star", False)
- @patch('sql.query.user_instances')
- @patch('sql.query.get_engine')
- @patch('sql.query.query_priv_check')
+ @patch("sql.query.user_instances")
+ @patch("sql.query.get_engine")
+ @patch("sql.query.query_priv_check")
def testCorrectSQL(self, _priv_check, _get_engine, _user_instances):
c = Client()
- some_sql = 'select some from some_table limit 100;'
- some_db = 'some_db'
+ some_sql = "select some from some_table limit 100;"
+ some_db = "some_db"
some_limit = 100
c.force_login(self.u1)
- r = c.post('/query/', data={'instance_name': self.slave1.instance_name,
- 'sql_content': some_sql,
- 'db_name': some_db,
- 'limit_num': some_limit})
+ r = c.post(
+ "/query/",
+ data={
+ "instance_name": self.slave1.instance_name,
+ "sql_content": some_sql,
+ "db_name": some_db,
+ "limit_num": some_limit,
+ },
+ )
self.assertEqual(r.status_code, 403)
c.force_login(self.u2)
- q_result = ResultSet(full_sql=some_sql, rows=['value'])
- q_result.column_list = ['some']
+ q_result = ResultSet(full_sql=some_sql, rows=["value"])
+ q_result.column_list = ["some"]
_get_engine.return_value.query_check.return_value = {
- 'msg': '', 'bad_query': False, 'filtered_sql': some_sql, 'has_star': False}
+ "msg": "",
+ "bad_query": False,
+ "filtered_sql": some_sql,
+ "has_star": False,
+ }
_get_engine.return_value.filter_sql.return_value = some_sql
_get_engine.return_value.query.return_value = q_result
_get_engine.return_value.seconds_behind_master = 100
- _priv_check.return_value = {'status': 0, 'data': {'limit_num': 100, 'priv_check': True}}
+ _priv_check.return_value = {
+ "status": 0,
+ "data": {"limit_num": 100, "priv_check": True},
+ }
_user_instances.return_value.get.return_value = self.slave1
- r = c.post('/query/', data={'instance_name': self.slave1.instance_name,
- 'sql_content': some_sql,
- 'db_name': some_db,
- 'limit_num': some_limit})
+ r = c.post(
+ "/query/",
+ data={
+ "instance_name": self.slave1.instance_name,
+ "sql_content": some_sql,
+ "db_name": some_db,
+ "limit_num": some_limit,
+ },
+ )
_get_engine.return_value.query.assert_called_once_with(
- some_db, some_sql, some_limit, schema_name=None, tb_name=None, max_execution_time=60000)
+ some_db,
+ some_sql,
+ some_limit,
+ schema_name=None,
+ tb_name=None,
+ max_execution_time=60000,
+ )
r_json = r.json()
- self.assertEqual(r_json['data']['rows'], ['value'])
- self.assertEqual(r_json['data']['column_list'], ['some'])
- self.assertEqual(r_json['data']['seconds_behind_master'], 100)
+ self.assertEqual(r_json["data"]["rows"], ["value"])
+ self.assertEqual(r_json["data"]["column_list"], ["some"])
+ self.assertEqual(r_json["data"]["seconds_behind_master"], 100)
- @patch('sql.query.user_instances')
- @patch('sql.query.get_engine')
- @patch('sql.query.query_priv_check')
+ @patch("sql.query.user_instances")
+ @patch("sql.query.get_engine")
+ @patch("sql.query.query_priv_check")
def testSQLWithoutLimit(self, _priv_check, _get_engine, _user_instances):
c = Client()
some_limit = 100
- sql_without_limit = 'select some from some_table'
- sql_with_limit = 'select some from some_table limit {0};'.format(some_limit)
- some_db = 'some_db'
+ sql_without_limit = "select some from some_table"
+ sql_with_limit = "select some from some_table limit {0};".format(some_limit)
+ some_db = "some_db"
c.force_login(self.u2)
- q_result = ResultSet(full_sql=sql_without_limit, rows=['value'])
- q_result.column_list = ['some']
+ q_result = ResultSet(full_sql=sql_without_limit, rows=["value"])
+ q_result.column_list = ["some"]
_get_engine.return_value.query_check.return_value = {
- 'msg': '', 'bad_query': False, 'filtered_sql': sql_without_limit, 'has_star': False}
+ "msg": "",
+ "bad_query": False,
+ "filtered_sql": sql_without_limit,
+ "has_star": False,
+ }
_get_engine.return_value.filter_sql.return_value = sql_with_limit
_get_engine.return_value.query.return_value = q_result
- _priv_check.return_value = {'status': 0, 'data': {'limit_num': 100, 'priv_check': True}}
+ _priv_check.return_value = {
+ "status": 0,
+ "data": {"limit_num": 100, "priv_check": True},
+ }
_user_instances.return_value.get.return_value = self.slave1
- r = c.post('/query/', data={'instance_name': self.slave1.instance_name,
- 'sql_content': sql_without_limit,
- 'db_name': some_db,
- 'limit_num': some_limit})
+ r = c.post(
+ "/query/",
+ data={
+ "instance_name": self.slave1.instance_name,
+ "sql_content": sql_without_limit,
+ "db_name": some_db,
+ "limit_num": some_limit,
+ },
+ )
_get_engine.return_value.query.assert_called_once_with(
- some_db, sql_with_limit, some_limit, schema_name=None, tb_name=None, max_execution_time=60000)
+ some_db,
+ sql_with_limit,
+ some_limit,
+ schema_name=None,
+ tb_name=None,
+ max_execution_time=60000,
+ )
r_json = r.json()
- self.assertEqual(r_json['data']['rows'], ['value'])
- self.assertEqual(r_json['data']['column_list'], ['some'])
+ self.assertEqual(r_json["data"]["rows"], ["value"])
+ self.assertEqual(r_json["data"]["column_list"], ["some"])
- @patch('sql.query.query_priv_check')
+ @patch("sql.query.query_priv_check")
def testStarOptionOn(self, _priv_check):
c = Client()
c.force_login(self.u2)
some_limit = 100
- sql_with_star = 'select * from some_table'
- some_db = 'some_db'
- _priv_check.return_value = {'status': 0, 'data': {'limit_num': 100, 'priv_check': True}}
+ sql_with_star = "select * from some_table"
+ some_db = "some_db"
+ _priv_check.return_value = {
+ "status": 0,
+ "data": {"limit_num": 100, "priv_check": True},
+ }
archer_config = SysConfig()
- archer_config.set('disable_star', True)
- r = c.post('/query/', data={'instance_name': self.slave1.instance_name,
- 'sql_content': sql_with_star,
- 'db_name': some_db,
- 'limit_num': some_limit})
- archer_config.set('disable_star', False)
+ archer_config.set("disable_star", True)
+ r = c.post(
+ "/query/",
+ data={
+ "instance_name": self.slave1.instance_name,
+ "sql_content": sql_with_star,
+ "db_name": some_db,
+ "limit_num": some_limit,
+ },
+ )
+ archer_config.set("disable_star", False)
r_json = r.json()
- self.assertEqual(1, r_json['status'])
+ self.assertEqual(1, r_json["status"])
- @patch('sql.query.get_engine')
+ @patch("sql.query.get_engine")
def test_kill_query_conn(self, _get_engine):
kill_query_conn(self.slave1.id, 10)
_get_engine.return_value.kill_connection.return_value = ResultSet()
@@ -994,110 +1296,121 @@ def test_query_log(self):
"""测试获取查询历史"""
c = Client()
c.force_login(self.superuser1)
- QueryLog(id=self.query_log.id, favorite=True, alias='test_a').save(update_fields=['favorite', 'alias'])
- data = {"star": "true",
- "query_log_id": self.query_log.id,
- "limit": 14,
- "offset": 0, }
- r = c.get('/query/querylog/', data=data)
- self.assertEqual(r.json()['total'], 1)
+ QueryLog(id=self.query_log.id, favorite=True, alias="test_a").save(
+ update_fields=["favorite", "alias"]
+ )
+ data = {
+ "star": "true",
+ "query_log_id": self.query_log.id,
+ "limit": 14,
+ "offset": 0,
+ }
+ r = c.get("/query/querylog/", data=data)
+ self.assertEqual(r.json()["total"], 1)
def test_star(self):
"""测试查询语句收藏"""
c = Client()
c.force_login(self.superuser1)
- r = c.post('/query/favorite/', data={'query_log_id': self.query_log.id,
- 'star': 'true',
- 'alias': 'test_alias'})
+ r = c.post(
+ "/query/favorite/",
+ data={
+ "query_log_id": self.query_log.id,
+ "star": "true",
+ "alias": "test_alias",
+ },
+ )
query_log = QueryLog.objects.get(id=self.query_log.id)
self.assertTrue(query_log.favorite)
- self.assertEqual(query_log.alias, 'test_alias')
+ self.assertEqual(query_log.alias, "test_alias")
def test_un_star(self):
"""测试查询语句取消收藏"""
c = Client()
c.force_login(self.superuser1)
- r = c.post('/query/favorite/', data={'query_log_id': self.query_log.id,
- 'star': 'false',
- 'alias': ''})
+ r = c.post(
+ "/query/favorite/",
+ data={"query_log_id": self.query_log.id, "star": "false", "alias": ""},
+ )
r_json = r.json()
query_log = QueryLog.objects.get(id=self.query_log.id)
self.assertFalse(query_log.favorite)
- self.assertEqual(query_log.alias, '')
+ self.assertEqual(query_log.alias, "")
class TestWorkflowView(TransactionTestCase):
-
def setUp(self):
self.now = datetime.now()
- can_view_permission = Permission.objects.get(codename='menu_sqlworkflow')
- can_execute_permission = Permission.objects.get(codename='sql_execute')
- can_execute_resource_permission = Permission.objects.get(codename='sql_execute_for_resource_group')
- self.u1 = User(username='some_user', display='用户1')
+ can_view_permission = Permission.objects.get(codename="menu_sqlworkflow")
+ can_execute_permission = Permission.objects.get(codename="sql_execute")
+ can_execute_resource_permission = Permission.objects.get(
+ codename="sql_execute_for_resource_group"
+ )
+ self.u1 = User(username="some_user", display="用户1")
self.u1.save()
self.u1.user_permissions.add(can_view_permission)
- self.u2 = User(username='some_user2', display='用户2')
+ self.u2 = User(username="some_user2", display="用户2")
self.u2.save()
self.u2.user_permissions.add(can_view_permission)
- self.u3 = User(username='some_user3', display='用户3')
+ self.u3 = User(username="some_user3", display="用户3")
self.u3.save()
self.u3.user_permissions.add(can_view_permission)
- self.executor1 = User(username='some_executor', display='执行者')
+ self.executor1 = User(username="some_executor", display="执行者")
self.executor1.save()
- self.executor1.user_permissions.add(can_view_permission, can_execute_permission,
- can_execute_resource_permission)
- self.superuser1 = User(username='super1', is_superuser=True)
+ self.executor1.user_permissions.add(
+ can_view_permission, can_execute_permission, can_execute_resource_permission
+ )
+ self.superuser1 = User(username="super1", is_superuser=True)
self.superuser1.save()
- self.master1 = Instance(instance_name='test_master_instance', type='master', db_type='mysql',
- host='testhost', port=3306, user='mysql_user', password='mysql_password')
+ self.master1 = Instance(
+ instance_name="test_master_instance",
+ type="master",
+ db_type="mysql",
+ host="testhost",
+ port=3306,
+ user="mysql_user",
+ password="mysql_password",
+ )
self.master1.save()
self.wf1 = SqlWorkflow.objects.create(
- workflow_name='some_name',
+ workflow_name="some_name",
group_id=1,
- group_name='g1',
+ group_name="g1",
engineer=self.u1.username,
engineer_display=self.u1.display,
- audit_auth_groups='some_group',
+ audit_auth_groups="some_group",
create_time=self.now - timedelta(days=1),
- status='workflow_finish',
+ status="workflow_finish",
is_backup=True,
instance=self.master1,
- db_name='some_db',
+ db_name="some_db",
syntax_type=1,
)
self.wfc1 = SqlWorkflowContent.objects.create(
workflow=self.wf1,
- sql_content='some_sql',
- execute_result=json.dumps([{
- 'id': 1,
- 'sql': 'some_content'
- }])
+ sql_content="some_sql",
+ execute_result=json.dumps([{"id": 1, "sql": "some_content"}]),
)
self.wf2 = SqlWorkflow.objects.create(
- workflow_name='some_name2',
+ workflow_name="some_name2",
group_id=1,
- group_name='g1',
+ group_name="g1",
engineer=self.u2.username,
engineer_display=self.u2.display,
- audit_auth_groups='some_group',
+ audit_auth_groups="some_group",
create_time=self.now - timedelta(days=1),
- status='workflow_manreviewing',
+ status="workflow_manreviewing",
is_backup=True,
instance=self.master1,
- db_name='some_db',
- syntax_type=1
+ db_name="some_db",
+ syntax_type=1,
)
self.wfc2 = SqlWorkflowContent.objects.create(
workflow=self.wf2,
- sql_content='some_sql',
- execute_result=json.dumps([{
- 'id': 1,
- 'sql': 'some_content'
- }])
- )
- self.resource_group1 = ResourceGroup(
- group_name='some_group'
+ sql_content="some_sql",
+ execute_result=json.dumps([{"id": 1, "sql": "some_content"}]),
)
+ self.resource_group1 = ResourceGroup(group_name="some_group")
self.resource_group1.save()
def tearDown(self):
@@ -1113,20 +1426,21 @@ def testWorkflowStatus(self):
"""测试获取工单状态"""
c = Client(header={})
c.force_login(self.u1)
- r = c.post('/getWorkflowStatus/', {'workflow_id': self.wf1.id})
+ r = c.post("/getWorkflowStatus/", {"workflow_id": self.wf1.id})
r_json = r.json()
- self.assertEqual(r_json['status'], 'workflow_finish')
+ self.assertEqual(r_json["status"], "workflow_finish")
def test_check_param_is_None(self):
"""测试工单检测,参数内容为空"""
c = Client()
c.force_login(self.superuser1)
data = {"instance_name": self.master1.instance_name}
- r = c.post('/simplecheck/', data=data)
- self.assertDictEqual(json.loads(r.content),
- {'status': 1, 'msg': '页面提交参数可能为空', 'data': {}})
+ r = c.post("/simplecheck/", data=data)
+ self.assertDictEqual(
+ json.loads(r.content), {"status": 1, "msg": "页面提交参数可能为空", "data": {}}
+ )
- @patch('sql.sql_workflow.get_engine')
+ @patch("sql.sql_workflow.get_engine")
def test_check_inception_Exception(self, _get_engine):
"""测试工单检测,inception报错"""
c = Client()
@@ -1136,14 +1450,13 @@ def test_check_inception_Exception(self, _get_engine):
"instance_name": self.master1.instance_name,
"db_name": "archery",
}
- _get_engine.side_effect = RuntimeError('RuntimeError')
- r = c.post('/simplecheck/', data=data)
- self.assertDictEqual(json.loads(r.content),
- {'status': 1,
- 'msg': "RuntimeError",
- 'data': {}})
-
- @patch('sql.sql_workflow.get_engine')
+ _get_engine.side_effect = RuntimeError("RuntimeError")
+ r = c.post("/simplecheck/", data=data)
+ self.assertDictEqual(
+ json.loads(r.content), {"status": 1, "msg": "RuntimeError", "data": {}}
+ )
+
+ @patch("sql.sql_workflow.get_engine")
def test_check(self, _get_engine):
"""测试工单检测,正常返回"""
c = Client()
@@ -1154,378 +1467,459 @@ def test_check(self, _get_engine):
"db_name": "archery",
}
column_list = [
- 'id', 'stage', 'errlevel', 'stagestatus', 'errormessage', 'sql', 'affected_rows', 'sequence',
- 'backup_dbname', 'execute_time', 'sqlsha1', 'backup_time', 'actual_affected_rows']
-
- rows = [ReviewResult(id=1,
- stage='CHECKED',
- errlevel=0,
- stagestatus='Audit Completed',
- errormessage='',
- sql='use `archer`',
- affected_rows=0,
- actual_affected_rows=0,
- sequence='0_0_00000000',
- backup_dbname='',
- execute_time='0',
- sqlsha1='')]
+ "id",
+ "stage",
+ "errlevel",
+ "stagestatus",
+ "errormessage",
+ "sql",
+ "affected_rows",
+ "sequence",
+ "backup_dbname",
+ "execute_time",
+ "sqlsha1",
+ "backup_time",
+ "actual_affected_rows",
+ ]
+
+ rows = [
+ ReviewResult(
+ id=1,
+ stage="CHECKED",
+ errlevel=0,
+ stagestatus="Audit Completed",
+ errormessage="",
+ sql="use `archer`",
+ affected_rows=0,
+ actual_affected_rows=0,
+ sequence="0_0_00000000",
+ backup_dbname="",
+ execute_time="0",
+ sqlsha1="",
+ )
+ ]
_get_engine.return_value.execute_check.return_value = ReviewSet(
- warning_count=0,
- error_count=0,
- column_list=column_list,
- rows=rows)
- r = c.post('/simplecheck/', data=data)
- self.assertListEqual(list(json.loads(r.content)['data'].keys()),
- ["rows", "CheckWarningCount", "CheckErrorCount"])
- self.assertListEqual(list(json.loads(r.content)['data']['rows'][0].keys()), column_list)
+ warning_count=0, error_count=0, column_list=column_list, rows=rows
+ )
+ r = c.post("/simplecheck/", data=data)
+ self.assertListEqual(
+ list(json.loads(r.content)["data"].keys()),
+ ["rows", "CheckWarningCount", "CheckErrorCount"],
+ )
+ self.assertListEqual(
+ list(json.loads(r.content)["data"]["rows"][0].keys()), column_list
+ )
def test_submit_param_is_None(self):
"""测试SQL提交,参数内容为空"""
c = Client()
c.force_login(self.superuser1)
- data = {"sql_content": "update sql_users set email='' where id>0;",
- "workflow_name": "【回滚工单】原工单Id:163+,3434434343",
- "group_name": self.resource_group1.group_name,
- "instance_name": self.master1.instance_name,
- "run_date_start": "",
- "run_date_end": "",
- "workflow_auditors": "11"}
- r = c.post('/autoreview/', data=data)
- self.assertContains(r, '页面提交参数可能为空')
-
- @patch('sql.sql_workflow.async_task')
- @patch('sql.sql_workflow.Audit')
- @patch('sql.sql_workflow.get_engine')
- @patch('sql.sql_workflow.user_instances')
- def test_submit_audit_wrong(self, _user_instances, _get_engine, _audit, _async_task):
+ data = {
+ "sql_content": "update sql_users set email='' where id>0;",
+ "workflow_name": "【回滚工单】原工单Id:163+,3434434343",
+ "group_name": self.resource_group1.group_name,
+ "instance_name": self.master1.instance_name,
+ "run_date_start": "",
+ "run_date_end": "",
+ "workflow_auditors": "11",
+ }
+ r = c.post("/autoreview/", data=data)
+ self.assertContains(r, "页面提交参数可能为空")
+
+ @patch("sql.sql_workflow.async_task")
+ @patch("sql.sql_workflow.Audit")
+ @patch("sql.sql_workflow.get_engine")
+ @patch("sql.sql_workflow.user_instances")
+ def test_submit_audit_wrong(
+ self, _user_instances, _get_engine, _audit, _async_task
+ ):
"""测试SQL提交,获取审核信息报错"""
c = Client()
c.force_login(self.superuser1)
- data = {"sql_content": "update sql_users set email='' where id>0;",
- "workflow_name": "【回滚工单】原工单Id:163+,3434434343",
- "group_name": self.resource_group1.group_name,
- "instance_name": self.master1.instance_name,
- "db_name": "archery",
- "demand_url": 'test_url',
- "run_date_start": "",
- "run_date_end": "",
- "workflow_auditors": "11"}
+ data = {
+ "sql_content": "update sql_users set email='' where id>0;",
+ "workflow_name": "【回滚工单】原工单Id:163+,3434434343",
+ "group_name": self.resource_group1.group_name,
+ "instance_name": self.master1.instance_name,
+ "db_name": "archery",
+ "demand_url": "test_url",
+ "run_date_start": "",
+ "run_date_end": "",
+ "workflow_auditors": "11",
+ }
_user_instances.return_value.get.return_value = self.master1
_get_engine.return_value.execute_check.return_value = ReviewSet(
- syntax_type=1,
- warning_count=1,
- error_count=1)
- _audit.settings = ValueError('error')
+ syntax_type=1, warning_count=1, error_count=1
+ )
+ _audit.settings = ValueError("error")
_audit.add.return_value = None
_async_task.return_value = None
- r = c.post('/autoreview/', data=data)
- self.assertContains(r, 'ValueError')
+ r = c.post("/autoreview/", data=data)
+ self.assertContains(r, "ValueError")
- @patch('sql.sql_workflow.async_task')
- @patch('sql.sql_workflow.Audit')
- @patch('sql.sql_workflow.get_engine')
- @patch('sql.sql_workflow.user_instances')
+ @patch("sql.sql_workflow.async_task")
+ @patch("sql.sql_workflow.Audit")
+ @patch("sql.sql_workflow.get_engine")
+ @patch("sql.sql_workflow.user_instances")
def test_submit(self, _user_instances, _get_engine, _audit, _async_task):
"""测试SQL提交,正常提交"""
c = Client()
c.force_login(self.superuser1)
- data = {"sql_content": "update sql_users set email='' where id>0;",
- "workflow_name": "【回滚工单】原工单Id:163+,3434434343",
- "group_name": self.resource_group1.group_name,
- "instance_name": self.master1.instance_name,
- "db_name": "archery",
- "demand_url": 'test_url',
- "run_date_start": "",
- "run_date_end": "",
- "workflow_auditors": "11"}
+ data = {
+ "sql_content": "update sql_users set email='' where id>0;",
+ "workflow_name": "【回滚工单】原工单Id:163+,3434434343",
+ "group_name": self.resource_group1.group_name,
+ "instance_name": self.master1.instance_name,
+ "db_name": "archery",
+ "demand_url": "test_url",
+ "run_date_start": "",
+ "run_date_end": "",
+ "workflow_auditors": "11",
+ }
_user_instances.return_value.get.return_value = self.master1
_get_engine.return_value.execute_check.return_value = ReviewSet(
- syntax_type=1,
- warning_count=1,
- error_count=1)
- _audit.settings.return_value = 'some_group,another_group'
+ syntax_type=1, warning_count=1, error_count=1
+ )
+ _audit.settings.return_value = "some_group,another_group"
_audit.add.return_value = None
_async_task.return_value = None
- r = c.post('/autoreview/', data=data)
- workflow_id = SqlWorkflow.objects.order_by('-id').first().id
- self.assertRedirects(r, f'/detail/{workflow_id}/', fetch_redirect_response=False)
+ r = c.post("/autoreview/", data=data)
+ workflow_id = SqlWorkflow.objects.order_by("-id").first().id
+ self.assertRedirects(
+ r, f"/detail/{workflow_id}/", fetch_redirect_response=False
+ )
- @patch('sql.utils.workflow_audit.Audit.can_review')
+ @patch("sql.utils.workflow_audit.Audit.can_review")
def test_alter_run_date_no_perm(self, _can_review):
"""测试修改可执行时间,无权限"""
- sql_review = Permission.objects.get(codename='sql_review')
+ sql_review = Permission.objects.get(codename="sql_review")
self.u1.user_permissions.add(sql_review)
_can_review.return_value = False
c = Client()
c.force_login(self.u1)
data = {"workflow_id": self.wf1.id}
- r = c.post('/alter_run_date/', data=data)
- self.assertContains(r, '你无权操作当前工单')
+ r = c.post("/alter_run_date/", data=data)
+ self.assertContains(r, "你无权操作当前工单")
- @patch('sql.utils.workflow_audit.Audit.can_review')
+ @patch("sql.utils.workflow_audit.Audit.can_review")
def test_alter_run_date(self, _can_review):
"""测试修改可执行时间,有权限"""
- sql_review = Permission.objects.get(codename='sql_review')
+ sql_review = Permission.objects.get(codename="sql_review")
self.u1.user_permissions.add(sql_review)
_can_review.return_value = True
c = Client()
c.force_login(self.u1)
data = {"workflow_id": self.wf1.id}
- r = c.post('/alter_run_date/', data=data)
- self.assertRedirects(r, f'/detail/{self.wf1.id}/', fetch_redirect_response=False)
+ r = c.post("/alter_run_date/", data=data)
+ self.assertRedirects(
+ r, f"/detail/{self.wf1.id}/", fetch_redirect_response=False
+ )
- @patch('sql.utils.workflow_audit.Audit.logs')
- @patch('sql.utils.workflow_audit.Audit.detail_by_workflow_id')
- @patch('sql.utils.workflow_audit.Audit.review_info')
- @patch('sql.utils.workflow_audit.Audit.can_review')
+ @patch("sql.utils.workflow_audit.Audit.logs")
+ @patch("sql.utils.workflow_audit.Audit.detail_by_workflow_id")
+ @patch("sql.utils.workflow_audit.Audit.review_info")
+ @patch("sql.utils.workflow_audit.Audit.can_review")
def testWorkflowDetailView(self, _can_review, _review_info, _detail_by_id, _logs):
"""测试工单详情"""
- _review_info.return_value = ('some_auth_group', 'current_auth_group')
+ _review_info.return_value = ("some_auth_group", "current_auth_group")
_can_review.return_value = False
_detail_by_id.return_value.audit_id = 123
- _logs.return_value.latest('id').operation_info = ''
+ _logs.return_value.latest("id").operation_info = ""
c = Client()
c.force_login(self.u1)
- r = c.get('/detail/{}/'.format(self.wf1.id))
+ r = c.get("/detail/{}/".format(self.wf1.id))
expected_status_display = r"""id="workflow_detail_disaply">已正常结束"""
self.assertContains(r, expected_status_display)
exepcted_status = r"""id="workflow_detail_status">workflow_finish"""
self.assertContains(r, exepcted_status)
# 测试执行详情解析失败
- self.wfc1.execute_result = 'cannotbedecode:1,:'
+ self.wfc1.execute_result = "cannotbedecode:1,:"
self.wfc1.save()
- r = c.get('/detail/{}/'.format(self.wf1.id))
+ r = c.get("/detail/{}/".format(self.wf1.id))
self.assertContains(r, expected_status_display)
self.assertContains(r, exepcted_status)
# 执行详情为空
self.wfc1.review_content = [
- {"id": 1, "stage": "CHECKED", "errlevel": 0, "stagestatus": "Audit completed", "errormessage": "None",
- "sql": "use archery", "affected_rows": 0, "sequence": "'0_0_0'", "backup_dbname": "None",
- "execute_time": "0", "sqlsha1": "", "actual_affected_rows": ""}]
- self.wfc1.execute_result = ''
+ {
+ "id": 1,
+ "stage": "CHECKED",
+ "errlevel": 0,
+ "stagestatus": "Audit completed",
+ "errormessage": "None",
+ "sql": "use archery",
+ "affected_rows": 0,
+ "sequence": "'0_0_0'",
+ "backup_dbname": "None",
+ "execute_time": "0",
+ "sqlsha1": "",
+ "actual_affected_rows": "",
+ }
+ ]
+ self.wfc1.execute_result = ""
self.wfc1.save()
- r = c.get('/detail/{}/'.format(self.wf1.id))
+ r = c.get("/detail/{}/".format(self.wf1.id))
def testWorkflowListView(self):
"""测试工单列表"""
c = Client()
c.force_login(self.superuser1)
- r = c.post('/sqlworkflow_list/', {'limit': 10, 'offset': 0, 'navStatus': ''})
+ r = c.post("/sqlworkflow_list/", {"limit": 10, "offset": 0, "navStatus": ""})
r_json = r.json()
- self.assertEqual(r_json['total'], 2)
+ self.assertEqual(r_json["total"], 2)
# 列表按创建时间倒序排列, 第二个是wf1 , 是已正常结束
- self.assertEqual(r_json['rows'][1]['status'], 'workflow_finish')
+ self.assertEqual(r_json["rows"][1]["status"], "workflow_finish")
# u1拿到u1的
c.force_login(self.u1)
- r = c.post('/sqlworkflow_list/', {'limit': 10, 'offset': 0, 'navStatus': ''})
+ r = c.post("/sqlworkflow_list/", {"limit": 10, "offset": 0, "navStatus": ""})
r_json = r.json()
- self.assertEqual(r_json['total'], 1)
- self.assertEqual(r_json['rows'][0]['id'], self.wf1.id)
+ self.assertEqual(r_json["total"], 1)
+ self.assertEqual(r_json["rows"][0]["id"], self.wf1.id)
# u3拿到None
c.force_login(self.u3)
- r = c.post('/sqlworkflow_list/', {'limit': 10, 'offset': 0, 'navStatus': ''})
+ r = c.post("/sqlworkflow_list/", {"limit": 10, "offset": 0, "navStatus": ""})
r_json = r.json()
- self.assertEqual(r_json['total'], 0)
+ self.assertEqual(r_json["total"], 0)
def testWorkflowListViewFilter(self):
"""测试工单列表筛选"""
c = Client()
c.force_login(self.superuser1)
# 工单状态
- r = c.post('/sqlworkflow_list/', {'limit': 10, 'offset': 0, 'navStatus': 'workflow_finish'})
+ r = c.post(
+ "/sqlworkflow_list/",
+ {"limit": 10, "offset": 0, "navStatus": "workflow_finish"},
+ )
r_json = r.json()
- self.assertEqual(r_json['total'], 1)
+ self.assertEqual(r_json["total"], 1)
# 列表按创建时间倒序排列
- self.assertEqual(r_json['rows'][0]['status'], 'workflow_finish')
+ self.assertEqual(r_json["rows"][0]["status"], "workflow_finish")
# 实例
- r = c.post('/sqlworkflow_list/', {'limit': 10, 'offset': 0, 'instance_id': self.wf1.instance_id})
+ r = c.post(
+ "/sqlworkflow_list/",
+ {"limit": 10, "offset": 0, "instance_id": self.wf1.instance_id},
+ )
r_json = r.json()
- self.assertEqual(r_json['total'], 2)
+ self.assertEqual(r_json["total"], 2)
# 列表按创建时间倒序排列, 第二个是wf1
- self.assertEqual(r_json['rows'][1]['workflow_name'], self.wf1.workflow_name)
+ self.assertEqual(r_json["rows"][1]["workflow_name"], self.wf1.workflow_name)
# 资源组
- r = c.post('/sqlworkflow_list/', {'limit': 10, 'offset': 0, 'resource_group_id': self.wf1.group_id})
+ r = c.post(
+ "/sqlworkflow_list/",
+ {"limit": 10, "offset": 0, "resource_group_id": self.wf1.group_id},
+ )
r_json = r.json()
- self.assertEqual(r_json['total'], 2)
+ self.assertEqual(r_json["total"], 2)
# 列表按创建时间倒序排列, 第二个是wf1
- self.assertEqual(r_json['rows'][1]['workflow_name'], self.wf1.workflow_name)
+ self.assertEqual(r_json["rows"][1]["workflow_name"], self.wf1.workflow_name)
# 时间
- start_date = datetime.strftime(self.now, '%Y-%m-%d')
- end_date = datetime.strftime(self.now, '%Y-%m-%d')
- r = c.post('/sqlworkflow_list/', {'limit': 10, 'offset': 0, 'start_date': start_date, 'end_date': end_date})
+ start_date = datetime.strftime(self.now, "%Y-%m-%d")
+ end_date = datetime.strftime(self.now, "%Y-%m-%d")
+ r = c.post(
+ "/sqlworkflow_list/",
+ {"limit": 10, "offset": 0, "start_date": start_date, "end_date": end_date},
+ )
r_json = r.json()
- self.assertEqual(r_json['total'], 2)
+ self.assertEqual(r_json["total"], 2)
- @patch('sql.utils.workflow_audit.Audit.detail_by_workflow_id')
- @patch('sql.utils.workflow_audit.Audit.audit')
- @patch('sql.utils.workflow_audit.Audit.can_review')
+ @patch("sql.utils.workflow_audit.Audit.detail_by_workflow_id")
+ @patch("sql.utils.workflow_audit.Audit.audit")
+ @patch("sql.utils.workflow_audit.Audit.can_review")
def testWorkflowPassedView(self, _can_review, _audit, _detail_by_id):
"""测试审核工单"""
c = Client()
c.force_login(self.superuser1)
- r = c.post('/passed/')
- self.assertContains(r, 'workflow_id参数为空.')
+ r = c.post("/passed/")
+ self.assertContains(r, "workflow_id参数为空.")
_can_review.return_value = False
- r = c.post('/passed/', {'workflow_id': self.wf1.id})
- self.assertContains(r, '你无权操作当前工单!')
+ r = c.post("/passed/", {"workflow_id": self.wf1.id})
+ self.assertContains(r, "你无权操作当前工单!")
_can_review.return_value = True
_detail_by_id.return_value.audit_id = 123
- _audit.return_value = {
- "data": {
- "workflow_status": 1 # TODO 改为audit_success
- }
- }
- r = c.post('/passed/', data={'workflow_id': self.wf1.id, 'audit_remark': 'some_audit'}, follow=False)
- self.assertRedirects(r, '/detail/{}/'.format(self.wf1.id), fetch_redirect_response=False)
+ _audit.return_value = {"data": {"workflow_status": 1}} # TODO 改为audit_success
+ r = c.post(
+ "/passed/",
+ data={"workflow_id": self.wf1.id, "audit_remark": "some_audit"},
+ follow=False,
+ )
+ self.assertRedirects(
+ r, "/detail/{}/".format(self.wf1.id), fetch_redirect_response=False
+ )
self.wf1.refresh_from_db()
- self.assertEqual(self.wf1.status, 'workflow_review_pass')
+ self.assertEqual(self.wf1.status, "workflow_review_pass")
- @patch('sql.sql_workflow.Audit.add_log')
- @patch('sql.sql_workflow.Audit.detail_by_workflow_id')
- @patch('sql.sql_workflow.can_execute')
+ @patch("sql.sql_workflow.Audit.add_log")
+ @patch("sql.sql_workflow.Audit.detail_by_workflow_id")
+ @patch("sql.sql_workflow.can_execute")
def test_workflow_execute(self, mock_can_excute, mock_detail_by_id, mock_add_log):
"""测试工单执行"""
c = Client()
c.force_login(self.executor1)
- r = c.post('/execute/')
- self.assertContains(r, 'workflow_id参数为空.')
+ r = c.post("/execute/")
+ self.assertContains(r, "workflow_id参数为空.")
mock_can_excute.return_value = False
- r = c.post('/execute/', data={'workflow_id': self.wf2.id})
- self.assertContains(r, '你无权操作当前工单!')
+ r = c.post("/execute/", data={"workflow_id": self.wf2.id})
+ self.assertContains(r, "你无权操作当前工单!")
mock_can_excute.return_value = True
mock_detail_by_id = 123
- r = c.post('/execute/', data={'workflow_id': self.wf2.id, 'mode': 'manual'})
+ r = c.post("/execute/", data={"workflow_id": self.wf2.id, "mode": "manual"})
self.wf2.refresh_from_db()
- self.assertEqual('workflow_finish', self.wf2.status)
+ self.assertEqual("workflow_finish", self.wf2.status)
- @patch('sql.sql_workflow.Audit.add_log')
- @patch('sql.sql_workflow.Audit.detail_by_workflow_id')
- @patch('sql.sql_workflow.Audit.audit')
+ @patch("sql.sql_workflow.Audit.add_log")
+ @patch("sql.sql_workflow.Audit.detail_by_workflow_id")
+ @patch("sql.sql_workflow.Audit.audit")
# patch view里的can_cancel 而不是原始位置的can_cancel ,因为在调用时, 已经 import 了真的 can_cancel ,会导致mock失效
# 在import 静态函数时需要注意这一点, 动态对象因为每次都会重新生成,也可以 mock 原函数/方法/对象
# 参见 : https://docs.python.org/3/library/unittest.mock.html#where-to-patch
- @patch('sql.sql_workflow.can_cancel')
+ @patch("sql.sql_workflow.can_cancel")
def testWorkflowCancelView(self, _can_cancel, _audit, _detail_by_id, _add_log):
"""测试工单驳回、取消"""
c = Client()
c.force_login(self.u2)
- r = c.post('/cancel/')
- self.assertContains(r, 'workflow_id参数为空.')
- r = c.post('/cancel/', data={'workflow_id': self.wf2.id})
- self.assertContains(r, '终止原因不能为空')
+ r = c.post("/cancel/")
+ self.assertContains(r, "workflow_id参数为空.")
+ r = c.post("/cancel/", data={"workflow_id": self.wf2.id})
+ self.assertContains(r, "终止原因不能为空")
_can_cancel.return_value = False
- r = c.post('/cancel/', data={'workflow_id': self.wf2.id, 'cancel_remark': 'some_reason'})
- self.assertContains(r, '你无权操作当前工单!')
+ r = c.post(
+ "/cancel/",
+ data={"workflow_id": self.wf2.id, "cancel_remark": "some_reason"},
+ )
+ self.assertContains(r, "你无权操作当前工单!")
_can_cancel.return_value = True
_detail_by_id = 123
- r = c.post('/cancel/', data={'workflow_id': self.wf2.id, 'cancel_remark': 'some_reason'})
+ r = c.post(
+ "/cancel/",
+ data={"workflow_id": self.wf2.id, "cancel_remark": "some_reason"},
+ )
self.wf2.refresh_from_db()
- self.assertEqual('workflow_abort', self.wf2.status)
-
- @patch('sql.sql_workflow.async_task')
- @patch('sql.sql_workflow.Audit')
- @patch('sql.sql_workflow.get_engine')
- @patch('sql.sql_workflow.user_instances')
- def test_workflow_auto_review_view(self, mock_user_instances, mock_get_engine, mock_audit, mock_async_task):
+ self.assertEqual("workflow_abort", self.wf2.status)
+
+ @patch("sql.sql_workflow.async_task")
+ @patch("sql.sql_workflow.Audit")
+ @patch("sql.sql_workflow.get_engine")
+ @patch("sql.sql_workflow.user_instances")
+ def test_workflow_auto_review_view(
+ self, mock_user_instances, mock_get_engine, mock_audit, mock_async_task
+ ):
"""测试 autoreview/submit view"""
c = Client()
c.force_login(self.superuser1)
request_data = {
- 'sql_content': "update some_db set some_key=\'some value\';",
- 'workflow_name': 'some_title',
- 'group_name': self.resource_group1.group_name,
- 'group_id': self.resource_group1.group_id,
- 'instance_name': self.master1.instance_name,
- "demand_url": 'test_url',
- 'db_name': 'some_db',
- 'is_backup': True,
- 'notify_users': ''
+ "sql_content": "update some_db set some_key='some value';",
+ "workflow_name": "some_title",
+ "group_name": self.resource_group1.group_name,
+ "group_id": self.resource_group1.group_id,
+ "instance_name": self.master1.instance_name,
+ "demand_url": "test_url",
+ "db_name": "some_db",
+ "is_backup": True,
+ "notify_users": "",
}
mock_user_instances.return_value.get.return_value = None
mock_get_engine.return_value.execute_check.return_value.warning_count = 0
mock_get_engine.return_value.execute_check.return_value.error_count = 0
mock_get_engine.return_value.execute_check.return_value.syntax_type = 0
mock_get_engine.return_value.execute_check.return_value.rows = []
- mock_get_engine.return_value.execute_check.return_value.json.return_value = json.dumps([{
- "id": 1,
- "stage": "CHECKED",
- "errlevel": 0,
- "stagestatus": "Audit completed",
- "errormessage": "None", "sql": "use thirdservice_db", "affected_rows": 0,
- "sequence": "'0_0_0'", "backup_dbname": "None", "execute_time": "0", "sqlsha1": "",
- "actual_affected_rows": None}])
- mock_audit.settings.return_value = 'some_group,another_group'
+ mock_get_engine.return_value.execute_check.return_value.json.return_value = (
+ json.dumps(
+ [
+ {
+ "id": 1,
+ "stage": "CHECKED",
+ "errlevel": 0,
+ "stagestatus": "Audit completed",
+ "errormessage": "None",
+ "sql": "use thirdservice_db",
+ "affected_rows": 0,
+ "sequence": "'0_0_0'",
+ "backup_dbname": "None",
+ "execute_time": "0",
+ "sqlsha1": "",
+ "actual_affected_rows": None,
+ }
+ ]
+ )
+ )
+ mock_audit.settings.return_value = "some_group,another_group"
mock_audit.add.return_value = None
mock_async_task.return_value = None
- r = c.post('/autoreview/', data=request_data, follow=False)
- self.assertIn('detail', r.url)
- workflow_id = int(re.search(r'\/detail\/(\d+)\/', r.url).groups()[0])
- self.assertEqual(request_data['workflow_name'], SqlWorkflow.objects.get(id=workflow_id).workflow_name)
+ r = c.post("/autoreview/", data=request_data, follow=False)
+ self.assertIn("detail", r.url)
+ workflow_id = int(re.search(r"\/detail\/(\d+)\/", r.url).groups()[0])
+ self.assertEqual(
+ request_data["workflow_name"],
+ SqlWorkflow.objects.get(id=workflow_id).workflow_name,
+ )
# 强制备份测试
# 打开备份开关, 对备份不要求
request_data_without_backup = {
- 'sql_content': "update some_db set some_key=\'some value\';",
- 'workflow_name': 'some_title_2',
- 'group_name': self.resource_group1.group_name,
- 'group_id': self.resource_group1.group_id,
- 'instance_name': self.master1.instance_name,
- 'db_name': 'some_db',
- "demand_url": 'test_url',
- 'is_backup': False,
- 'notify_users': ''
+ "sql_content": "update some_db set some_key='some value';",
+ "workflow_name": "some_title_2",
+ "group_name": self.resource_group1.group_name,
+ "group_id": self.resource_group1.group_id,
+ "instance_name": self.master1.instance_name,
+ "db_name": "some_db",
+ "demand_url": "test_url",
+ "is_backup": False,
+ "notify_users": "",
}
archer_config = SysConfig()
- archer_config.set('enable_backup_switch', 'true')
- r = c.post('/autoreview/', data=request_data_without_backup, follow=False)
- self.assertIn('detail', r.url)
- workflow_id = int(re.search(r'\/detail\/(\d+)\/', r.url).groups()[0])
- self.assertEqual(request_data_without_backup['workflow_name'],
- SqlWorkflow.objects.get(id=workflow_id).workflow_name)
+ archer_config.set("enable_backup_switch", "true")
+ r = c.post("/autoreview/", data=request_data_without_backup, follow=False)
+ self.assertIn("detail", r.url)
+ workflow_id = int(re.search(r"\/detail\/(\d+)\/", r.url).groups()[0])
+ self.assertEqual(
+ request_data_without_backup["workflow_name"],
+ SqlWorkflow.objects.get(id=workflow_id).workflow_name,
+ )
# 关闭备份选项, 不允许不备份
- archer_config.set('enable_backup_switch', 'false')
- r = c.post('/autoreview/', data=request_data_without_backup, follow=False)
- self.assertIn('detail', r.url)
- workflow_id = int(re.search(r'\/detail\/(\d+)\/', r.url).groups()[0])
+ archer_config.set("enable_backup_switch", "false")
+ r = c.post("/autoreview/", data=request_data_without_backup, follow=False)
+ self.assertIn("detail", r.url)
+ workflow_id = int(re.search(r"\/detail\/(\d+)\/", r.url).groups()[0])
self.assertEqual(SqlWorkflow.objects.get(id=workflow_id).is_backup, True)
- @patch('sql.sql_workflow.get_engine')
+ @patch("sql.sql_workflow.get_engine")
def test_osc_control(self, _get_engine):
"""测试MySQL工单osc控制"""
c = Client()
c.force_login(self.superuser1)
request_data = {
- 'workflow_id': self.wf1.id,
- 'sqlsha1': 'sqlsha1',
- 'command': 'get',
+ "workflow_id": self.wf1.id,
+ "sqlsha1": "sqlsha1",
+ "command": "get",
}
_get_engine.return_value.osc_control.return_value = ResultSet()
- r = c.post('/inception/osc_control/', data=request_data, follow=False)
- self.assertDictEqual(json.loads(r.content),
- {"total": 0, "rows": [], "msg": None})
+ r = c.post("/inception/osc_control/", data=request_data, follow=False)
+ self.assertDictEqual(
+ json.loads(r.content), {"total": 0, "rows": [], "msg": None}
+ )
- @patch('sql.sql_workflow.get_engine')
+ @patch("sql.sql_workflow.get_engine")
def test_osc_control_exception(self, _get_engine):
"""测试MySQL工单OSC控制异常"""
c = Client()
c.force_login(self.superuser1)
request_data = {
- 'workflow_id': self.wf1.id,
- 'sqlsha1': 'sqlsha1',
- 'command': 'get',
+ "workflow_id": self.wf1.id,
+ "sqlsha1": "sqlsha1",
+ "command": "get",
}
- _get_engine.return_value.osc_control.side_effect = RuntimeError('RuntimeError')
- r = c.post('/inception/osc_control/', data=request_data, follow=False)
- self.assertDictEqual(json.loads(r.content),
- {"total": 0, "rows": [], "msg": "RuntimeError"})
+ _get_engine.return_value.osc_control.side_effect = RuntimeError("RuntimeError")
+ r = c.post("/inception/osc_control/", data=request_data, follow=False)
+ self.assertDictEqual(
+ json.loads(r.content), {"total": 0, "rows": [], "msg": "RuntimeError"}
+ )
class TestOptimize(TestCase):
@@ -1534,14 +1928,18 @@ class TestOptimize(TestCase):
"""
def setUp(self):
- self.superuser = User(username='super', is_superuser=True)
+ self.superuser = User(username="super", is_superuser=True)
self.superuser.save()
# 使用 travis.ci 时实例和测试service保持一致
- self.master = Instance(instance_name='test_instance', type='master', db_type='mysql',
- host=settings.DATABASES['default']['HOST'],
- port=settings.DATABASES['default']['PORT'],
- user=settings.DATABASES['default']['USER'],
- password=settings.DATABASES['default']['PASSWORD'])
+ self.master = Instance(
+ instance_name="test_instance",
+ type="master",
+ db_type="mysql",
+ host=settings.DATABASES["default"]["HOST"],
+ port=settings.DATABASES["default"]["PORT"],
+ user=settings.DATABASES["default"]["USER"],
+ password=settings.DATABASES["default"]["PASSWORD"],
+ )
self.master.save()
self.sys_config = SysConfig()
self.client = Client()
@@ -1557,70 +1955,105 @@ def test_sqladvisor(self):
测试SQLAdvisor报告
:return:
"""
- r = self.client.post(path='/slowquery/optimize_sqladvisor/')
- self.assertEqual(json.loads(r.content), {'status': 1, 'msg': '页面提交参数可能为空', 'data': []})
- r = self.client.post(path='/slowquery/optimize_sqladvisor/',
- data={"sql_content": "select 1;", "instance_name": "test_instance"})
- self.assertEqual(json.loads(r.content), {'status': 1, 'msg': '请配置SQLAdvisor路径!', 'data': []})
- self.sys_config.set('sqladvisor', '/opt/archery/src/plugins/sqladvisor')
+ r = self.client.post(path="/slowquery/optimize_sqladvisor/")
+ self.assertEqual(
+ json.loads(r.content), {"status": 1, "msg": "页面提交参数可能为空", "data": []}
+ )
+ r = self.client.post(
+ path="/slowquery/optimize_sqladvisor/",
+ data={"sql_content": "select 1;", "instance_name": "test_instance"},
+ )
+ self.assertEqual(
+ json.loads(r.content), {"status": 1, "msg": "请配置SQLAdvisor路径!", "data": []}
+ )
+ self.sys_config.set("sqladvisor", "/opt/archery/src/plugins/sqladvisor")
self.sys_config.get_all_config()
- r = self.client.post(path='/slowquery/optimize_sqladvisor/',
- data={"sql_content": "select 1;", "instance_name": "test_instance"})
- self.assertEqual(json.loads(r.content)['status'], 0)
+ r = self.client.post(
+ path="/slowquery/optimize_sqladvisor/",
+ data={"sql_content": "select 1;", "instance_name": "test_instance"},
+ )
+ self.assertEqual(json.loads(r.content)["status"], 0)
def test_soar(self):
"""
测试SOAR报告
:return:
"""
- r = self.client.post(path='/slowquery/optimize_soar/')
- self.assertEqual(json.loads(r.content), {'status': 1, 'msg': '页面提交参数可能为空', 'data': []})
- r = self.client.post(path='/slowquery/optimize_soar/',
- data={"sql": "select 1;", "instance_name": "test_instance", "db_name": "mysql"})
- self.assertEqual(json.loads(r.content), {'status': 1, 'msg': '请配置soar_path和test_dsn!', 'data': []})
- self.sys_config.set('soar', '/opt/archery/src/plugins/soar')
- self.sys_config.set('soar_test_dsn', 'root:@127.0.0.1:3306/information_schema')
+ r = self.client.post(path="/slowquery/optimize_soar/")
+ self.assertEqual(
+ json.loads(r.content), {"status": 1, "msg": "页面提交参数可能为空", "data": []}
+ )
+ r = self.client.post(
+ path="/slowquery/optimize_soar/",
+ data={
+ "sql": "select 1;",
+ "instance_name": "test_instance",
+ "db_name": "mysql",
+ },
+ )
+ self.assertEqual(
+ json.loads(r.content),
+ {"status": 1, "msg": "请配置soar_path和test_dsn!", "data": []},
+ )
+ self.sys_config.set("soar", "/opt/archery/src/plugins/soar")
+ self.sys_config.set("soar_test_dsn", "root:@127.0.0.1:3306/information_schema")
self.sys_config.get_all_config()
- r = self.client.post(path='/slowquery/optimize_soar/',
- data={"sql": "select 1;", "instance_name": "test_instance", "db_name": "mysql"})
- self.assertEqual(json.loads(r.content)['status'], 0)
+ r = self.client.post(
+ path="/slowquery/optimize_soar/",
+ data={
+ "sql": "select 1;",
+ "instance_name": "test_instance",
+ "db_name": "mysql",
+ },
+ )
+ self.assertEqual(json.loads(r.content)["status"], 0)
def test_tuning(self):
"""
测试SQLTuning报告
:return:
"""
- data = {"sql_content": "select * from test_archery.sql_users;",
- "instance_name": "test_instance",
- "db_name": settings.DATABASES['default']['TEST']['NAME']
- }
- data['instance_name'] = 'test_instancex'
- r = self.client.post(path='/slowquery/optimize_sqltuning/', data=data)
- self.assertEqual(json.loads(r.content), {'status': 1, 'msg': '你所在组未关联该实例!', 'data': []})
+ data = {
+ "sql_content": "select * from test_archery.sql_users;",
+ "instance_name": "test_instance",
+ "db_name": settings.DATABASES["default"]["TEST"]["NAME"],
+ }
+ data["instance_name"] = "test_instancex"
+ r = self.client.post(path="/slowquery/optimize_sqltuning/", data=data)
+ self.assertEqual(
+ json.loads(r.content), {"status": 1, "msg": "你所在组未关联该实例!", "data": []}
+ )
# 获取sys_parm
- data['instance_name'] = 'test_instance'
- data['option[]'] = 'sys_parm'
- r = self.client.post(path='/slowquery/optimize_sqltuning/', data=data)
- self.assertListEqual(list(json.loads(r.content)['data'].keys()),
- ['basic_information', 'sys_parameter', 'optimizer_switch', 'sqltext'])
+ data["instance_name"] = "test_instance"
+ data["option[]"] = "sys_parm"
+ r = self.client.post(path="/slowquery/optimize_sqltuning/", data=data)
+ self.assertListEqual(
+ list(json.loads(r.content)["data"].keys()),
+ ["basic_information", "sys_parameter", "optimizer_switch", "sqltext"],
+ )
# 获取sql_plan
- data['option[]'] = 'sql_plan'
- r = self.client.post(path='/slowquery/optimize_sqltuning/', data=data)
- self.assertListEqual(list(json.loads(r.content)['data'].keys()),
- ['optimizer_rewrite_sql', 'plan', 'sqltext'])
+ data["option[]"] = "sql_plan"
+ r = self.client.post(path="/slowquery/optimize_sqltuning/", data=data)
+ self.assertListEqual(
+ list(json.loads(r.content)["data"].keys()),
+ ["optimizer_rewrite_sql", "plan", "sqltext"],
+ )
# 获取obj_stat
- data['option[]'] = 'obj_stat'
- r = self.client.post(path='/slowquery/optimize_sqltuning/', data=data)
- self.assertListEqual(list(json.loads(r.content)['data'].keys()),
- ['object_statistics', 'sqltext'])
+ data["option[]"] = "obj_stat"
+ r = self.client.post(path="/slowquery/optimize_sqltuning/", data=data)
+ self.assertListEqual(
+ list(json.loads(r.content)["data"].keys()), ["object_statistics", "sqltext"]
+ )
# 获取sql_profile
- data['option[]'] = 'sql_profile'
- r = self.client.post(path='/slowquery/optimize_sqltuning/', data=data)
- self.assertListEqual(list(json.loads(r.content)['data'].keys()), ['session_status', 'sqltext'])
+ data["option[]"] = "sql_profile"
+ r = self.client.post(path="/slowquery/optimize_sqltuning/", data=data)
+ self.assertListEqual(
+ list(json.loads(r.content)["data"].keys()), ["session_status", "sqltext"]
+ )
class TestSchemaSync(TestCase):
@@ -1629,14 +2062,18 @@ class TestSchemaSync(TestCase):
"""
def setUp(self):
- self.superuser = User(username='super', is_superuser=True)
+ self.superuser = User(username="super", is_superuser=True)
self.superuser.save()
# 使用 travis.ci 时实例和测试service保持一致
- self.master = Instance(instance_name='test_instance', type='master', db_type='mysql',
- host=settings.DATABASES['default']['HOST'],
- port=settings.DATABASES['default']['PORT'],
- user=settings.DATABASES['default']['USER'],
- password=settings.DATABASES['default']['PASSWORD'])
+ self.master = Instance(
+ instance_name="test_instance",
+ type="master",
+ db_type="mysql",
+ host=settings.DATABASES["default"]["HOST"],
+ port=settings.DATABASES["default"]["PORT"],
+ user=settings.DATABASES["default"]["USER"],
+ password=settings.DATABASES["default"]["PASSWORD"],
+ )
self.master.save()
self.sys_config = SysConfig()
self.client = Client()
@@ -1652,14 +2089,16 @@ def test_schema_sync(self):
测试SchemaSync
:return:
"""
- data = {"instance_name": "test_instance",
- "db_name": "test",
- "target_instance_name": "test_instance",
- "target_db_name": "test",
- "sync_auto_inc": True,
- "sync_comments": False}
- r = self.client.post(path='/instance/schemasync/', data=data)
- self.assertEqual(json.loads(r.content)['status'], 0)
+ data = {
+ "instance_name": "test_instance",
+ "db_name": "test",
+ "target_instance_name": "test_instance",
+ "target_db_name": "test",
+ "sync_auto_inc": True,
+ "sync_comments": False,
+ }
+ r = self.client.post(path="/instance/schemasync/", data=data)
+ self.assertEqual(json.loads(r.content)["status"], 0)
class TestArchiver(TestCase):
@@ -1668,39 +2107,45 @@ class TestArchiver(TestCase):
"""
def setUp(self):
- self.superuser = User.objects.create(username='super', is_superuser=True)
- self.u1 = User.objects.create(username='u1', is_superuser=False)
- self.u2 = User.objects.create(username='u2', is_superuser=False)
- menu_archive = Permission.objects.get(codename='menu_archive')
- archive_review = Permission.objects.get(codename='archive_review')
+ self.superuser = User.objects.create(username="super", is_superuser=True)
+ self.u1 = User.objects.create(username="u1", is_superuser=False)
+ self.u2 = User.objects.create(username="u2", is_superuser=False)
+ menu_archive = Permission.objects.get(codename="menu_archive")
+ archive_review = Permission.objects.get(codename="archive_review")
self.u1.user_permissions.add(menu_archive)
self.u2.user_permissions.add(menu_archive)
self.u2.user_permissions.add(archive_review)
# 使用 travis.ci 时实例和测试service保持一致
- self.ins = Instance.objects.create(instance_name='test_instance', type='master', db_type='mysql',
- host=settings.DATABASES['default']['HOST'],
- port=settings.DATABASES['default']['PORT'],
- user=settings.DATABASES['default']['USER'],
- password=settings.DATABASES['default']['PASSWORD'])
- self.res_group = ResourceGroup.objects.create(group_id=1, group_name='group_name')
+ self.ins = Instance.objects.create(
+ instance_name="test_instance",
+ type="master",
+ db_type="mysql",
+ host=settings.DATABASES["default"]["HOST"],
+ port=settings.DATABASES["default"]["PORT"],
+ user=settings.DATABASES["default"]["USER"],
+ password=settings.DATABASES["default"]["PASSWORD"],
+ )
+ self.res_group = ResourceGroup.objects.create(
+ group_id=1, group_name="group_name"
+ )
self.archive_apply = ArchiveConfig.objects.create(
- title='title',
+ title="title",
resource_group=self.res_group,
- audit_auth_groups='some_audit_group',
+ audit_auth_groups="some_audit_group",
src_instance=self.ins,
- src_db_name='src_db_name',
- src_table_name='src_table_name',
+ src_db_name="src_db_name",
+ src_table_name="src_table_name",
dest_instance=self.ins,
- dest_db_name='src_db_name',
- dest_table_name='src_table_name',
- condition='1=1',
- mode='file',
+ dest_db_name="src_db_name",
+ dest_table_name="src_table_name",
+ condition="1=1",
+ mode="file",
no_delete=True,
sleep=1,
- status=WorkflowDict.workflow_status['audit_wait'],
+ status=WorkflowDict.workflow_status["audit_wait"],
state=False,
- user_name='some_user',
- user_display='display',
+ user_name="some_user",
+ user_display="display",
)
self.sys_config = SysConfig()
self.client = Client()
@@ -1718,11 +2163,9 @@ def test_archive_list_super(self):
测试管理员获取归档申请列表
:return:
"""
- data = {"filter_instance_id": self.ins.id,
- "state": 'false',
- "search": "text"}
+ data = {"filter_instance_id": self.ins.id, "state": "false", "search": "text"}
self.client.force_login(self.superuser)
- r = self.client.get(path='/archive/list/', data=data)
+ r = self.client.get(path="/archive/list/", data=data)
self.assertDictEqual(json.loads(r.content), {"total": 0, "rows": []})
def test_archive_list_own(self):
@@ -1730,11 +2173,9 @@ def test_archive_list_own(self):
测试非管理员和审核人获取归档申请列表
:return:
"""
- data = {"filter_instance_id": self.ins.id,
- "state": 'false',
- "search": "text"}
+ data = {"filter_instance_id": self.ins.id, "state": "false", "search": "text"}
self.client.force_login(self.u1)
- r = self.client.get(path='/archive/list/', data=data)
+ r = self.client.get(path="/archive/list/", data=data)
self.assertDictEqual(json.loads(r.content), {"total": 0, "rows": []})
def test_archive_list_review(self):
@@ -1742,11 +2183,9 @@ def test_archive_list_review(self):
测试审核人获取归档申请列表
:return:
"""
- data = {"filter_instance_id": self.ins.id,
- "state": 'false',
- "search": "text"}
+ data = {"filter_instance_id": self.ins.id, "state": "false", "search": "text"}
self.client.force_login(self.u2)
- r = self.client.get(path='/archive/list/', data=data)
+ r = self.client.get(path="/archive/list/", data=data)
self.assertDictEqual(json.loads(r.content), {"total": 0, "rows": []})
def test_archive_apply_not_param(self):
@@ -1757,19 +2196,21 @@ def test_archive_apply_not_param(self):
data = {
"group_name": self.res_group.group_name,
"src_instance_name": self.ins.instance_name,
- "src_db_name": 'src_db_name',
- "src_table_name": 'src_table_name',
- "mode": 'dest',
+ "src_db_name": "src_db_name",
+ "src_table_name": "src_table_name",
+ "mode": "dest",
"dest_instance_name": self.ins.instance_name,
- "dest_db_name": 'dest_db_name',
- "dest_table_name": 'dest_table_name',
- "condition": '1=1',
- "no_delete": 'true',
- "sleep": 10
+ "dest_db_name": "dest_db_name",
+ "dest_table_name": "dest_table_name",
+ "condition": "1=1",
+ "no_delete": "true",
+ "sleep": 10,
}
self.client.force_login(self.superuser)
- r = self.client.post(path='/archive/apply/', data=data)
- self.assertDictEqual(json.loads(r.content), {'status': 1, 'msg': '请填写完整!', 'data': {}})
+ r = self.client.post(path="/archive/apply/", data=data)
+ self.assertDictEqual(
+ json.loads(r.content), {"status": 1, "msg": "请填写完整!", "data": {}}
+ )
def test_archive_apply_not_dest_param(self):
"""
@@ -1777,85 +2218,99 @@ def test_archive_apply_not_dest_param(self):
:return:
"""
data = {
- "title": 'title',
+ "title": "title",
"group_name": self.res_group.group_name,
"src_instance_name": self.ins.instance_name,
- "src_db_name": 'src_db_name',
- "src_table_name": 'src_table_name',
- "mode": 'dest',
- "condition": '1=1',
- "no_delete": 'true',
- "sleep": 10
+ "src_db_name": "src_db_name",
+ "src_table_name": "src_table_name",
+ "mode": "dest",
+ "condition": "1=1",
+ "no_delete": "true",
+ "sleep": 10,
}
self.client.force_login(self.superuser)
- r = self.client.post(path='/archive/apply/', data=data)
- self.assertDictEqual(json.loads(r.content), {'status': 1, 'msg': '归档到实例时目标实例信息必选!', 'data': {}})
+ r = self.client.post(path="/archive/apply/", data=data)
+ self.assertDictEqual(
+ json.loads(r.content), {"status": 1, "msg": "归档到实例时目标实例信息必选!", "data": {}}
+ )
def test_archive_apply_not_exist_review(self):
"""
测试申请归档实例数据,未配置审批流程
:return:
"""
- data = {"title": 'title',
- "group_name": self.res_group.group_name,
- "src_instance_name": self.ins.instance_name,
- "src_db_name": 'src_db_name',
- "src_table_name": 'src_table_name',
- "mode": 'dest',
- "dest_instance_name": self.ins.instance_name,
- "dest_db_name": 'dest_db_name',
- "dest_table_name": 'dest_table_name',
- "condition": '1=1',
- "no_delete": 'true',
- "sleep": 10
- }
+ data = {
+ "title": "title",
+ "group_name": self.res_group.group_name,
+ "src_instance_name": self.ins.instance_name,
+ "src_db_name": "src_db_name",
+ "src_table_name": "src_table_name",
+ "mode": "dest",
+ "dest_instance_name": self.ins.instance_name,
+ "dest_db_name": "dest_db_name",
+ "dest_table_name": "dest_table_name",
+ "condition": "1=1",
+ "no_delete": "true",
+ "sleep": 10,
+ }
self.client.force_login(self.superuser)
- r = self.client.post(path='/archive/apply/', data=data)
- self.assertDictEqual(json.loads(r.content), {'data': {}, 'msg': '审批流程不能为空,请先配置审批流程', 'status': 1})
+ r = self.client.post(path="/archive/apply/", data=data)
+ self.assertDictEqual(
+ json.loads(r.content), {"data": {}, "msg": "审批流程不能为空,请先配置审批流程", "status": 1}
+ )
- @patch('sql.archiver.async_task')
+ @patch("sql.archiver.async_task")
def test_archive_apply(self, _async_task):
"""
测试申请归档实例数据
:return:
"""
- WorkflowAuditSetting.objects.create(workflow_type=3, group_id=1, audit_auth_groups='1')
- data = {"title": 'title',
- "group_name": self.res_group.group_name,
- "src_instance_name": self.ins.instance_name,
- "src_db_name": 'src_db_name',
- "src_table_name": 'src_table_name',
- "mode": 'dest',
- "dest_instance_name": self.ins.instance_name,
- "dest_db_name": 'dest_db_name',
- "dest_table_name": 'dest_table_name',
- "condition": '1=1',
- "no_delete": 'true',
- "sleep": 10
- }
+ WorkflowAuditSetting.objects.create(
+ workflow_type=3, group_id=1, audit_auth_groups="1"
+ )
+ data = {
+ "title": "title",
+ "group_name": self.res_group.group_name,
+ "src_instance_name": self.ins.instance_name,
+ "src_db_name": "src_db_name",
+ "src_table_name": "src_table_name",
+ "mode": "dest",
+ "dest_instance_name": self.ins.instance_name,
+ "dest_db_name": "dest_db_name",
+ "dest_table_name": "dest_table_name",
+ "condition": "1=1",
+ "no_delete": "true",
+ "sleep": 10,
+ }
self.client.force_login(self.superuser)
- r = self.client.post(path='/archive/apply/', data=data)
- self.assertEqual(json.loads(r.content)['status'], 0)
+ r = self.client.post(path="/archive/apply/", data=data)
+ self.assertEqual(json.loads(r.content)["status"], 0)
- @patch('sql.archiver.Audit')
- @patch('sql.archiver.async_task')
+ @patch("sql.archiver.Audit")
+ @patch("sql.archiver.async_task")
def test_archive_audit(self, _async_task, _audit):
"""
测试审核归档实例数据
:return:
"""
_audit.detail_by_workflow_id.return_value.audit_id = 1
- _audit.audit.return_value = {'status': 0, 'msg': 'ok', 'data': {'workflow_status': 1}}
+ _audit.audit.return_value = {
+ "status": 0,
+ "msg": "ok",
+ "data": {"workflow_status": 1},
+ }
data = {
"archive_id": self.archive_apply.id,
- "audit_status": WorkflowDict.workflow_status['audit_success'],
- "audit_remark": 'xxxx'
+ "audit_status": WorkflowDict.workflow_status["audit_success"],
+ "audit_remark": "xxxx",
}
self.client.force_login(self.superuser)
- r = self.client.post(path='/archive/audit/', data=data)
- self.assertRedirects(r, f'/archive/{self.archive_apply.id}/', fetch_redirect_response=False)
+ r = self.client.post(path="/archive/audit/", data=data)
+ self.assertRedirects(
+ r, f"/archive/{self.archive_apply.id}/", fetch_redirect_response=False
+ )
- @patch('sql.archiver.async_task')
+ @patch("sql.archiver.async_task")
def test_add_archive_task(self, _async_task):
"""
测试添加异步归档任务
@@ -1863,7 +2318,7 @@ def test_add_archive_task(self, _async_task):
"""
add_archive_task()
- @patch('sql.archiver.async_task')
+ @patch("sql.archiver.async_task")
def test_add_archive(self, _async_task):
"""
测试执行归档任务
@@ -1872,7 +2327,7 @@ def test_add_archive(self, _async_task):
with self.assertRaises(Exception):
archive(self.archive_apply.id)
- @patch('sql.archiver.async_task')
+ @patch("sql.archiver.async_task")
def test_archive_log(self, _async_task):
"""
测试获取归档日志
@@ -1882,48 +2337,52 @@ def test_archive_log(self, _async_task):
"archive_id": self.archive_apply.id,
}
self.client.force_login(self.superuser)
- r = self.client.post(path='/archive/log/', data=data)
+ r = self.client.post(path="/archive/log/", data=data)
self.assertDictEqual(json.loads(r.content), {"total": 0, "rows": []})
class TestAsync(TestCase):
-
def setUp(self):
self.now = datetime.now()
- self.u1 = User(username='some_user', display='用户1')
+ self.u1 = User(username="some_user", display="用户1")
self.u1.save()
- self.master1 = Instance(instance_name='test_master_instance', type='master', db_type='mysql',
- host='testhost', port=3306, user='mysql_user', password='mysql_password')
+ self.master1 = Instance(
+ instance_name="test_master_instance",
+ type="master",
+ db_type="mysql",
+ host="testhost",
+ port=3306,
+ user="mysql_user",
+ password="mysql_password",
+ )
self.master1.save()
self.wf1 = SqlWorkflow.objects.create(
- workflow_name='some_name2',
+ workflow_name="some_name2",
group_id=1,
- group_name='g1',
+ group_name="g1",
engineer=self.u1.username,
engineer_display=self.u1.display,
- audit_auth_groups='some_group',
+ audit_auth_groups="some_group",
create_time=self.now - timedelta(days=1),
- status='workflow_executing',
+ status="workflow_executing",
is_backup=True,
instance=self.master1,
- db_name='some_db',
+ db_name="some_db",
syntax_type=1,
)
self.wfc1 = SqlWorkflowContent.objects.create(
- workflow=self.wf1,
- sql_content='some_sql',
- execute_result=''
+ workflow=self.wf1, sql_content="some_sql", execute_result=""
)
# 初始化工单执行返回对象
self.task_result = MagicMock()
self.task_result.args = [self.wf1.id]
self.task_result.success = True
self.task_result.stopped = self.now
- self.task_result.result.json.return_value = json.dumps([{
- 'id': 1,
- 'sql': 'some_content'}])
- self.task_result.result.warning = ''
- self.task_result.result.error = ''
+ self.task_result.result.json.return_value = json.dumps(
+ [{"id": 1, "sql": "some_content"}]
+ )
+ self.task_result.result.warning = ""
+ self.task_result.result.error = ""
def tearDown(self):
self.wf1.delete()
@@ -1931,13 +2390,15 @@ def tearDown(self):
self.task_result = None
self.master1.delete()
- @patch('sql.utils.execute_sql.notify_for_execute')
- @patch('sql.utils.execute_sql.Audit')
+ @patch("sql.utils.execute_sql.notify_for_execute")
+ @patch("sql.utils.execute_sql.Audit")
def test_call_back(self, mock_audit, mock_notify):
mock_audit.detail_by_workflow_id.return_value.audit_id = 123
- mock_audit.add_log.return_value = 'any thing'
+ mock_audit.add_log.return_value = "any thing"
execute_callback(self.task_result)
- mock_audit.detail_by_workflow_id.assert_called_with(workflow_id=self.wf1.id, workflow_type=ANY)
+ mock_audit.detail_by_workflow_id.assert_called_with(
+ workflow_id=self.wf1.id, workflow_type=ANY
+ )
mock_audit.add_log.assert_called_with(
audit_id=123,
operation_type=ANY,
@@ -1955,14 +2416,18 @@ class TestSQLAnalyze(TestCase):
"""
def setUp(self):
- self.superuser = User(username='super', is_superuser=True)
+ self.superuser = User(username="super", is_superuser=True)
self.superuser.save()
# 使用 travis.ci 时实例和测试service保持一致
- self.master = Instance(instance_name='test_instance', type='master', db_type='mysql',
- host=settings.DATABASES['default']['HOST'],
- port=settings.DATABASES['default']['PORT'],
- user=settings.DATABASES['default']['USER'],
- password=settings.DATABASES['default']['PASSWORD'])
+ self.master = Instance(
+ instance_name="test_instance",
+ type="master",
+ db_type="mysql",
+ host=settings.DATABASES["default"]["HOST"],
+ port=settings.DATABASES["default"]["PORT"],
+ user=settings.DATABASES["default"]["USER"],
+ password=settings.DATABASES["default"]["PASSWORD"],
+ )
self.master.save()
self.sys_config = SysConfig()
self.client = Client()
@@ -1978,8 +2443,8 @@ def test_generate_text_None(self):
测试解析SQL,text为空
:return:
"""
- r = self.client.post(path='/sql_analyze/generate/', data={})
- self.assertEqual(json.loads(r.content), {'rows': [], 'total': 0})
+ r = self.client.post(path="/sql_analyze/generate/", data={})
+ self.assertEqual(json.loads(r.content), {"rows": [], "total": 0})
def test_generate_text_not_None(self):
"""
@@ -1987,19 +2452,25 @@ def test_generate_text_not_None(self):
:return:
"""
text = "select * from sql_user;select * from sql_workflow;"
- r = self.client.post(path='/sql_analyze/generate/', data={"text": text})
- self.assertEqual(json.loads(r.content),
- {"total": 2, "rows": [{"sql_id": 1, "sql": "select * from sql_user;"},
- {"sql_id": 2, "sql": "select * from sql_workflow;"}]}
- )
+ r = self.client.post(path="/sql_analyze/generate/", data={"text": text})
+ self.assertEqual(
+ json.loads(r.content),
+ {
+ "total": 2,
+ "rows": [
+ {"sql_id": 1, "sql": "select * from sql_user;"},
+ {"sql_id": 2, "sql": "select * from sql_workflow;"},
+ ],
+ },
+ )
def test_analyze_text_None(self):
"""
测试分析SQL,text为空
:return:
"""
- r = self.client.post(path='/sql_analyze/analyze/', data={})
- self.assertEqual(json.loads(r.content), {'rows': [], 'total': 0})
+ r = self.client.post(path="/sql_analyze/analyze/", data={})
+ self.assertEqual(json.loads(r.content), {"rows": [], "total": 0})
def test_analyze_text_not_None(self):
"""
@@ -2008,10 +2479,14 @@ def test_analyze_text_not_None(self):
"""
text = "select * from sql_user;select * from sql_workflow;"
instance_name = self.master.instance_name
- db_name = settings.DATABASES['default']['TEST']['NAME']
- r = self.client.post(path='/sql_analyze/analyze/',
- data={"text": text, "instance_name": instance_name, "db_name": db_name})
- self.assertListEqual(list(json.loads(r.content)['rows'][0].keys()), ['sql_id', 'sql', 'report'])
+ db_name = settings.DATABASES["default"]["TEST"]["NAME"]
+ r = self.client.post(
+ path="/sql_analyze/analyze/",
+ data={"text": text, "instance_name": instance_name, "db_name": db_name},
+ )
+ self.assertListEqual(
+ list(json.loads(r.content)["rows"][0].keys()), ["sql_id", "sql", "report"]
+ )
class TestBinLog(TestCase):
@@ -2020,14 +2495,18 @@ class TestBinLog(TestCase):
"""
def setUp(self):
- self.superuser = User(username='super', is_superuser=True)
+ self.superuser = User(username="super", is_superuser=True)
self.superuser.save()
# 使用 travis.ci 时实例和测试service保持一致
- self.master = Instance(instance_name='test_instance', type='master', db_type='mysql',
- host=settings.DATABASES['default']['HOST'],
- port=settings.DATABASES['default']['PORT'],
- user=settings.DATABASES['default']['USER'],
- password=settings.DATABASES['default']['PASSWORD'])
+ self.master = Instance(
+ instance_name="test_instance",
+ type="master",
+ db_type="mysql",
+ host=settings.DATABASES["default"]["HOST"],
+ port=settings.DATABASES["default"]["PORT"],
+ user=settings.DATABASES["default"]["USER"],
+ password=settings.DATABASES["default"]["PASSWORD"],
+ )
self.master.save()
self.sys_config = SysConfig()
self.client = Client()
@@ -2043,21 +2522,19 @@ def test_binlog_list_instance_not_exist(self):
测试获取binlog列表,实例不存在
:return:
"""
- data = {
- "instance_name": 'some_instance'
- }
- r = self.client.post(path='/binlog/list/', data=data)
- self.assertEqual(json.loads(r.content), {'status': 1, 'msg': '实例不存在', 'data': []})
+ data = {"instance_name": "some_instance"}
+ r = self.client.post(path="/binlog/list/", data=data)
+ self.assertEqual(
+ json.loads(r.content), {"status": 1, "msg": "实例不存在", "data": []}
+ )
def test_binlog_list_instance(self):
"""
测试获取binlog列表,实例存在
:return:
"""
- data = {
- "instance_name": 'test_instance'
- }
- r = self.client.post(path='/binlog/list/', data=data)
+ data = {"instance_name": "test_instance"}
+ r = self.client.post(path="/binlog/list/", data=data)
# self.assertEqual(json.loads(r.content).get('status'), 1)
def test_my2sql_path_not_exist(self):
@@ -2065,83 +2542,91 @@ def test_my2sql_path_not_exist(self):
测试获取解析binlog,path未设置
:return:
"""
- data = {"instance_name": "test_instance",
- "save_sql": "false",
- "rollback": "2sql",
- "num": "",
- "threads": 1,
- "extra_info": "false",
- "ignore_primary_key": "false",
- "full_columns": "false",
- "no_db_prefix": "false",
- "file_per_table": "false",
- "start_file": "mysql-bin.000045",
- "start_pos": "",
- "end_file": "mysql-bin.000045",
- "end_pos": "",
- "stop_time": "",
- "start_time": "",
- "only_schemas": "",
- "sql_type": ""}
- r = self.client.post(path='/binlog/my2sql/', data=data)
- self.assertEqual(json.loads(r.content), {'status': 1, 'msg': '可执行文件路径不能为空!', 'data': {}})
-
- @patch('sql.plugins.plugin.subprocess')
+ data = {
+ "instance_name": "test_instance",
+ "save_sql": "false",
+ "rollback": "2sql",
+ "num": "",
+ "threads": 1,
+ "extra_info": "false",
+ "ignore_primary_key": "false",
+ "full_columns": "false",
+ "no_db_prefix": "false",
+ "file_per_table": "false",
+ "start_file": "mysql-bin.000045",
+ "start_pos": "",
+ "end_file": "mysql-bin.000045",
+ "end_pos": "",
+ "stop_time": "",
+ "start_time": "",
+ "only_schemas": "",
+ "sql_type": "",
+ }
+ r = self.client.post(path="/binlog/my2sql/", data=data)
+ self.assertEqual(
+ json.loads(r.content), {"status": 1, "msg": "可执行文件路径不能为空!", "data": {}}
+ )
+
+ @patch("sql.plugins.plugin.subprocess")
def test_my2sql(self, _subprocess):
"""
测试获取解析binlog,path设置
:param _subprocess:
:return:
"""
- self.sys_config.set('my2sql', '/opt/my2sql')
+ self.sys_config.set("my2sql", "/opt/my2sql")
self.sys_config.get_all_config()
- data = {"instance_name": "test_instance",
- "save_sql": "1",
- "rollback": "2sql",
- "num": "1",
- "threads": 1,
- "extra_info": "false",
- "ignore_primary_key": "false",
- "full_columns": "false",
- "no_db_prefix": "false",
- "file_per_table": "false",
- "start_file": "mysql-bin.000045",
- "start_pos": "",
- "end_file": "mysql-bin.000046",
- "end_pos": "",
- "stop_time": "",
- "start_time": "",
- "only_schemas": "",
- "sql_type": ""}
- r = self.client.post(path='/binlog/my2sql/', data=data)
+ data = {
+ "instance_name": "test_instance",
+ "save_sql": "1",
+ "rollback": "2sql",
+ "num": "1",
+ "threads": 1,
+ "extra_info": "false",
+ "ignore_primary_key": "false",
+ "full_columns": "false",
+ "no_db_prefix": "false",
+ "file_per_table": "false",
+ "start_file": "mysql-bin.000045",
+ "start_pos": "",
+ "end_file": "mysql-bin.000046",
+ "end_pos": "",
+ "stop_time": "",
+ "start_time": "",
+ "only_schemas": "",
+ "sql_type": "",
+ }
+ r = self.client.post(path="/binlog/my2sql/", data=data)
self.assertEqual(json.loads(r.content), {"status": 0, "msg": "ok", "data": []})
- @patch('builtins.open')
+ @patch("builtins.open")
def test_my2sql_file(self, _open):
"""
测试保存文件
:param _subprocess:
:return:
"""
- args = {"instance_name": "test_instance",
- "save_sql": "1",
- "rollback": "2sql",
- "num": "1",
- "threads": 1,
- "add-extraInfo": "false",
- "ignore-primaryKey-forInsert": "false",
- "full-columns": "false",
- "do-not-add-prifixDb": "false",
- "file-per-table": "false",
- "start-file": "mysql-bin.000045",
- "start-pos": "",
- "stop-file": "mysql-bin.000045",
- "stop-pos": "",
- "stop-datetime": "",
- "start-datetime": "",
- "databases": "",
- "sql": "",
- "instance": self.master}
+ args = {
+ "instance_name": "test_instance",
+ "save_sql": "1",
+ "rollback": "2sql",
+ "num": "1",
+ "threads": 1,
+ "add-extraInfo": "false",
+ "ignore-primaryKey-forInsert": "false",
+ "full-columns": "false",
+ "do-not-add-prifixDb": "false",
+ "file-per-table": "false",
+ "start-file": "mysql-bin.000045",
+ "start-pos": "",
+ "stop-file": "mysql-bin.000045",
+ "stop-pos": "",
+ "stop-datetime": "",
+ "start-datetime": "",
+ "databases": "",
+ "sql": "",
+ "instance": self.master,
+ }
r = my2sql_file(args=args, user=self.superuser)
self.assertEqual(self.superuser, r[0])
@@ -2154,51 +2639,50 @@ def test_del_binlog_instance_not_exist(self):
"instance_id": 0,
"binlog": "mysql-bin.000001",
}
- r = self.client.post(path='/binlog/del_log/', data=data)
- self.assertEqual(json.loads(r.content), {'status': 1, 'msg': '实例不存在', 'data': []})
+ r = self.client.post(path="/binlog/del_log/", data=data)
+ self.assertEqual(
+ json.loads(r.content), {"status": 1, "msg": "实例不存在", "data": []}
+ )
def test_del_binlog_binlog_not_exist(self):
"""
测试删除binlog,实例存在,binlog 不存在
:return:
"""
- data = {
- "instance_id": self.master.id,
- "binlog": ''
- }
- r = self.client.post(path='/binlog/del_log/', data=data)
- self.assertEqual(json.loads(r.content), {'status': 1, 'msg': 'Error:未选择binlog!', 'data': ''})
+ data = {"instance_id": self.master.id, "binlog": ""}
+ r = self.client.post(path="/binlog/del_log/", data=data)
+ self.assertEqual(
+ json.loads(r.content), {"status": 1, "msg": "Error:未选择binlog!", "data": ""}
+ )
- @patch('sql.engines.mysql.MysqlEngine.query')
- @patch('sql.engines.get_engine')
+ @patch("sql.engines.mysql.MysqlEngine.query")
+ @patch("sql.engines.get_engine")
def test_del_binlog(self, _get_engine, _query):
"""
测试删除binlog
:return:
"""
- data = {
- "instance_id": self.master.id,
- "binlog": "mysql-bin.000001"
- }
- _query.return_value = ResultSet(full_sql='select 1')
- r = self.client.post(path='/binlog/del_log/', data=data)
- self.assertEqual(json.loads(r.content), {'status': 0, 'msg': '清理成功', 'data': ''})
+ data = {"instance_id": self.master.id, "binlog": "mysql-bin.000001"}
+ _query.return_value = ResultSet(full_sql="select 1")
+ r = self.client.post(path="/binlog/del_log/", data=data)
+ self.assertEqual(
+ json.loads(r.content), {"status": 0, "msg": "清理成功", "data": ""}
+ )
- @patch('sql.engines.mysql.MysqlEngine.query')
- @patch('sql.engines.get_engine')
+ @patch("sql.engines.mysql.MysqlEngine.query")
+ @patch("sql.engines.get_engine")
def test_del_binlog_wrong(self, _get_engine, _query):
"""
测试删除binlog
:return:
"""
- data = {
- "instance_id": self.master.id,
- "binlog": "mysql-bin.000001"
- }
- _query.return_value = ResultSet(full_sql='select 1')
- _query.return_value.error = '清理失败'
- r = self.client.post(path='/binlog/del_log/', data=data)
- self.assertEqual(json.loads(r.content), {'status': 2, 'msg': '清理失败,Error:清理失败', 'data': ''})
+ data = {"instance_id": self.master.id, "binlog": "mysql-bin.000001"}
+ _query.return_value = ResultSet(full_sql="select 1")
+ _query.return_value.error = "清理失败"
+ r = self.client.post(path="/binlog/del_log/", data=data)
+ self.assertEqual(
+ json.loads(r.content), {"status": 2, "msg": "清理失败,Error:清理失败", "data": ""}
+ )
class TestParam(TestCase):
@@ -2207,14 +2691,18 @@ class TestParam(TestCase):
"""
def setUp(self):
- self.superuser = User(username='super', is_superuser=True)
+ self.superuser = User(username="super", is_superuser=True)
self.superuser.save()
# 使用 travis.ci 时实例和测试service保持一致
- self.master = Instance(instance_name='test_instance', type='master', db_type='mysql',
- host=settings.DATABASES['default']['HOST'],
- port=settings.DATABASES['default']['PORT'],
- user=settings.DATABASES['default']['USER'],
- password=settings.DATABASES['default']['PASSWORD'])
+ self.master = Instance(
+ instance_name="test_instance",
+ type="master",
+ db_type="mysql",
+ host=settings.DATABASES["default"]["HOST"],
+ port=settings.DATABASES["default"]["PORT"],
+ user=settings.DATABASES["default"]["USER"],
+ password=settings.DATABASES["default"]["PASSWORD"],
+ )
self.master.save()
self.client = Client()
self.client.force_login(self.superuser)
@@ -2229,24 +2717,21 @@ def test_param_list_instance_not_exist(self):
测试获取参数列表,实例不存在
:return:
"""
- data = {
- "instance_id": 0
- }
- r = self.client.post(path='/param/list/', data=data)
- self.assertEqual(json.loads(r.content), {'status': 1, 'msg': '实例不存在', 'data': []})
+ data = {"instance_id": 0}
+ r = self.client.post(path="/param/list/", data=data)
+ self.assertEqual(
+ json.loads(r.content), {"status": 1, "msg": "实例不存在", "data": []}
+ )
- @patch('sql.engines.mysql.MysqlEngine.get_variables')
- @patch('sql.engines.get_engine')
+ @patch("sql.engines.mysql.MysqlEngine.get_variables")
+ @patch("sql.engines.get_engine")
def test_param_list_instance_exist(self, _get_engine, _get_variables):
"""
测试获取参数列表,实例存在
:return:
"""
- data = {
- "instance_id": self.master.id,
- "editable": True
- }
- r = self.client.post(path='/param/list/', data=data)
+ data = {"instance_id": self.master.id, "editable": True}
+ r = self.client.post(path="/param/list/", data=data)
self.assertIsInstance(json.loads(r.content), list)
def test_param_history(self):
@@ -2254,92 +2739,124 @@ def test_param_history(self):
测试获取参数修改历史
:return:
"""
- data = {"instance_id": self.master.id,
- "search": "binlog",
- "limit": 14,
- "offset": 0}
- r = self.client.post(path='/param/history/', data=data)
- self.assertEqual(json.loads(r.content), {'rows': [], 'total': 0})
+ data = {
+ "instance_id": self.master.id,
+ "search": "binlog",
+ "limit": 14,
+ "offset": 0,
+ }
+ r = self.client.post(path="/param/history/", data=data)
+ self.assertEqual(json.loads(r.content), {"rows": [], "total": 0})
- @patch('sql.engines.mysql.MysqlEngine.set_variable')
- @patch('sql.engines.mysql.MysqlEngine.get_variables')
- @patch('sql.engines.get_engine')
- def test_param_edit_variable_not_config(self, _get_engine, _get_variables, _set_variable):
+ @patch("sql.engines.mysql.MysqlEngine.set_variable")
+ @patch("sql.engines.mysql.MysqlEngine.get_variables")
+ @patch("sql.engines.get_engine")
+ def test_param_edit_variable_not_config(
+ self, _get_engine, _get_variables, _set_variable
+ ):
"""
测试参数修改,参数未在模板配置
:return:
"""
- data = {"instance_id": self.master.id,
- "variable_name": "1",
- "variable_value": "false"}
- r = self.client.post(path='/param/edit/', data=data)
- self.assertEqual(json.loads(r.content), {'data': [], 'msg': '请先在参数模板中配置该参数!', 'status': 1})
+ data = {
+ "instance_id": self.master.id,
+ "variable_name": "1",
+ "variable_value": "false",
+ }
+ r = self.client.post(path="/param/edit/", data=data)
+ self.assertEqual(
+ json.loads(r.content), {"data": [], "msg": "请先在参数模板中配置该参数!", "status": 1}
+ )
- @patch('sql.engines.mysql.MysqlEngine.set_variable')
- @patch('sql.engines.mysql.MysqlEngine.get_variables')
- @patch('sql.engines.get_engine')
- def test_param_edit_variable_not_change(self, _get_engine, _get_variables, _set_variable):
+ @patch("sql.engines.mysql.MysqlEngine.set_variable")
+ @patch("sql.engines.mysql.MysqlEngine.get_variables")
+ @patch("sql.engines.get_engine")
+ def test_param_edit_variable_not_change(
+ self, _get_engine, _get_variables, _set_variable
+ ):
"""
测试参数修改,已在参数模板配置,但是值无变化
:return:
"""
- _get_variables.return_value.rows = (('binlog_format', 'ROW'),)
+ _get_variables.return_value.rows = (("binlog_format", "ROW"),)
_set_variable.return_value.error = None
_set_variable.return_value.full_sql = "set global binlog_format='STATEMENT';"
- ParamTemplate.objects.create(db_type='mysql',
- variable_name='binlog_format',
- default_value='ROW',
- editable=True)
- data = {"instance_id": self.master.id,
- "variable_name": "binlog_format",
- "runtime_value": "ROW"}
- r = self.client.post(path='/param/edit/', data=data)
- self.assertEqual(json.loads(r.content), {'status': 1, 'msg': '参数值与实际运行值一致,未调整!', 'data': []})
-
- @patch('sql.engines.mysql.MysqlEngine.set_variable')
- @patch('sql.engines.mysql.MysqlEngine.get_variables')
- @patch('sql.engines.get_engine')
- def test_param_edit_variable_change(self, _get_engine, _get_variables, _set_variable):
+ ParamTemplate.objects.create(
+ db_type="mysql",
+ variable_name="binlog_format",
+ default_value="ROW",
+ editable=True,
+ )
+ data = {
+ "instance_id": self.master.id,
+ "variable_name": "binlog_format",
+ "runtime_value": "ROW",
+ }
+ r = self.client.post(path="/param/edit/", data=data)
+ self.assertEqual(
+ json.loads(r.content), {"status": 1, "msg": "参数值与实际运行值一致,未调整!", "data": []}
+ )
+
+ @patch("sql.engines.mysql.MysqlEngine.set_variable")
+ @patch("sql.engines.mysql.MysqlEngine.get_variables")
+ @patch("sql.engines.get_engine")
+ def test_param_edit_variable_change(
+ self, _get_engine, _get_variables, _set_variable
+ ):
"""
测试参数修改,已在参数模板配置,且值有变化
:return:
"""
- _get_variables.return_value.rows = (('binlog_format', 'ROW'),)
+ _get_variables.return_value.rows = (("binlog_format", "ROW"),)
_set_variable.return_value.error = None
_set_variable.return_value.full_sql = "set global binlog_format='STATEMENT';"
- ParamTemplate.objects.create(db_type='mysql',
- variable_name='binlog_format',
- default_value='ROW',
- editable=True)
- data = {"instance_id": self.master.id,
- "variable_name": "binlog_format",
- "runtime_value": "STATEMENT"}
- r = self.client.post(path='/param/edit/', data=data)
- self.assertEqual(json.loads(r.content), {'status': 0, 'msg': '修改成功,请手动持久化到配置文件!', 'data': []})
-
- @patch('sql.engines.mysql.MysqlEngine.set_variable')
- @patch('sql.engines.mysql.MysqlEngine.get_variables')
- @patch('sql.engines.get_engine')
- def test_param_edit_variable_error(self, _get_engine, _get_variables, _set_variable):
+ ParamTemplate.objects.create(
+ db_type="mysql",
+ variable_name="binlog_format",
+ default_value="ROW",
+ editable=True,
+ )
+ data = {
+ "instance_id": self.master.id,
+ "variable_name": "binlog_format",
+ "runtime_value": "STATEMENT",
+ }
+ r = self.client.post(path="/param/edit/", data=data)
+ self.assertEqual(
+ json.loads(r.content), {"status": 0, "msg": "修改成功,请手动持久化到配置文件!", "data": []}
+ )
+
+ @patch("sql.engines.mysql.MysqlEngine.set_variable")
+ @patch("sql.engines.mysql.MysqlEngine.get_variables")
+ @patch("sql.engines.get_engine")
+ def test_param_edit_variable_error(
+ self, _get_engine, _get_variables, _set_variable
+ ):
"""
测试参数修改,已在参数模板配置,修改抛错
:return:
"""
- _get_variables.return_value.rows = (('binlog_format', 'ROW'),)
- _set_variable.return_value.error = '修改报错'
+ _get_variables.return_value.rows = (("binlog_format", "ROW"),)
+ _set_variable.return_value.error = "修改报错"
_set_variable.return_value.full_sql = "set global binlog_format='STATEMENT';"
- ParamTemplate.objects.create(db_type='mysql',
- variable_name='binlog_format',
- default_value='ROW',
- editable=True)
- data = {"instance_id": self.master.id,
- "variable_name": "binlog_format",
- "runtime_value": "STATEMENT"}
- r = self.client.post(path='/param/edit/', data=data)
- self.assertEqual(json.loads(r.content), {'status': 1, 'msg': '设置错误,错误信息:修改报错', 'data': []})
+ ParamTemplate.objects.create(
+ db_type="mysql",
+ variable_name="binlog_format",
+ default_value="ROW",
+ editable=True,
+ )
+ data = {
+ "instance_id": self.master.id,
+ "variable_name": "binlog_format",
+ "runtime_value": "STATEMENT",
+ }
+ r = self.client.post(path="/param/edit/", data=data)
+ self.assertEqual(
+ json.loads(r.content), {"status": 1, "msg": "设置错误,错误信息:修改报错", "data": []}
+ )
class TestNotify(TestCase):
@@ -2349,55 +2866,66 @@ class TestNotify(TestCase):
def setUp(self):
self.sys_config = SysConfig()
- self.user = User.objects.create(username='test_user', display='中文显示', is_active=True)
- self.su = User.objects.create(username='s_user', display='中文显示', is_active=True, is_superuser=True)
+ self.user = User.objects.create(
+ username="test_user", display="中文显示", is_active=True
+ )
+ self.su = User.objects.create(
+ username="s_user", display="中文显示", is_active=True, is_superuser=True
+ )
tomorrow = datetime.today() + timedelta(days=1)
- self.ins = Instance.objects.create(instance_name='some_ins', type='slave', db_type='mysql',
- host='some_host',
- port=3306, user='ins_user', password='some_str')
+ self.ins = Instance.objects.create(
+ instance_name="some_ins",
+ type="slave",
+ db_type="mysql",
+ host="some_host",
+ port=3306,
+ user="ins_user",
+ password="some_str",
+ )
self.wf = SqlWorkflow.objects.create(
- workflow_name='some_name',
+ workflow_name="some_name",
group_id=1,
- group_name='g1',
+ group_name="g1",
engineer=self.user.username,
engineer_display=self.user.display,
- audit_auth_groups='some_audit_group',
+ audit_auth_groups="some_audit_group",
create_time=datetime.now(),
- status='workflow_timingtask',
+ status="workflow_timingtask",
is_backup=True,
instance=self.ins,
- db_name='some_db',
- syntax_type=1
+ db_name="some_db",
+ syntax_type=1,
+ )
+ SqlWorkflowContent.objects.create(
+ workflow=self.wf, sql_content="some_sql", execute_result=""
)
- SqlWorkflowContent.objects.create(workflow=self.wf,
- sql_content='some_sql',
- execute_result='')
self.query_apply_1 = QueryPrivilegesApply.objects.create(
group_id=1,
- group_name='some_name',
- title='some_title1',
- user_name='some_user',
+ group_name="some_name",
+ title="some_title1",
+ user_name="some_user",
instance=self.ins,
- db_list='some_db,some_db2',
+ db_list="some_db,some_db2",
limit_num=100,
valid_date=tomorrow,
priv_type=1,
status=0,
- audit_auth_groups='some_audit_group'
+ audit_auth_groups="some_audit_group",
)
self.audit = WorkflowAudit.objects.create(
group_id=1,
- group_name='some_group',
+ group_name="some_group",
workflow_id=1,
workflow_type=1,
- workflow_title='申请标题',
- workflow_remark='申请备注',
- audit_auth_groups='1,2,3',
- current_audit='1',
- next_audit='2',
- current_status=0)
- self.aug = Group.objects.create(id=1, name='auth_group')
- self.rs = ResourceGroup.objects.create(group_id=1, ding_webhook='url')
+ workflow_title="申请标题",
+ workflow_remark="申请备注",
+ audit_auth_groups="1,2,3",
+ current_audit="1",
+ next_audit="2",
+ current_status=0,
+ )
+ self.aug = Group.objects.create(id=1, name="auth_group")
+ self.rs = ResourceGroup.objects.create(group_id=1, ding_webhook="url")
def tearDown(self):
self.sys_config.purge()
@@ -2413,13 +2941,13 @@ def test_notify_disable(self):
:return:
"""
# 关闭消息通知
- self.sys_config.set('mail', 'false')
- self.sys_config.set('ding', 'false')
+ self.sys_config.set("mail", "false")
+ self.sys_config.set("ding", "false")
r = notify_for_audit(audit_id=self.audit.audit_id)
self.assertIsNone(r)
- @patch('sql.notify.MsgSender')
- @patch('sql.notify.auth_group_users')
+ @patch("sql.notify.MsgSender")
+ @patch("sql.notify.auth_group_users")
def test_notify_for_sqlreview_audit_wait(self, _auth_group_users, _msg_sender):
"""
测试SQL上线申请审核通知
@@ -2428,19 +2956,19 @@ def test_notify_for_sqlreview_audit_wait(self, _auth_group_users, _msg_sender):
# 通知人修改
_auth_group_users.return_value = [self.user]
# 开启消息通知
- self.sys_config.set('mail', 'true')
- self.sys_config.set('ding', 'true')
+ self.sys_config.set("mail", "true")
+ self.sys_config.set("ding", "true")
# 修改工单状态为待审核
- self.audit.workflow_type = WorkflowDict.workflow_type['sqlreview']
+ self.audit.workflow_type = WorkflowDict.workflow_type["sqlreview"]
self.audit.workflow_id = self.wf.id
- self.audit.current_status = WorkflowDict.workflow_status['audit_wait']
+ self.audit.current_status = WorkflowDict.workflow_status["audit_wait"]
self.audit.save()
r = notify_for_audit(audit_id=self.audit.audit_id)
self.assertIsNone(r)
_msg_sender.assert_called_once()
- @patch('sql.notify.MsgSender')
- @patch('sql.notify.auth_group_users')
+ @patch("sql.notify.MsgSender")
+ @patch("sql.notify.auth_group_users")
def test_notify_for_sqlreview_audit_success(self, _auth_group_users, _msg_sender):
"""
测试SQL上线申请审核通过通知
@@ -2449,20 +2977,20 @@ def test_notify_for_sqlreview_audit_success(self, _auth_group_users, _msg_sender
# 通知人修改
_auth_group_users.return_value = [self.user]
# 开启消息通知
- self.sys_config.set('mail', 'true')
- self.sys_config.set('ding', 'true')
+ self.sys_config.set("mail", "true")
+ self.sys_config.set("ding", "true")
# 修改工单状态审核通过
- self.audit.workflow_type = WorkflowDict.workflow_type['sqlreview']
+ self.audit.workflow_type = WorkflowDict.workflow_type["sqlreview"]
self.audit.workflow_id = self.wf.id
- self.audit.current_status = WorkflowDict.workflow_status['audit_success']
+ self.audit.current_status = WorkflowDict.workflow_status["audit_success"]
self.audit.create_user = self.user.username
self.audit.save()
r = notify_for_audit(audit_id=self.audit.audit_id)
self.assertIsNone(r)
_msg_sender.assert_called_once()
- @patch('sql.notify.MsgSender')
- @patch('sql.notify.auth_group_users')
+ @patch("sql.notify.MsgSender")
+ @patch("sql.notify.auth_group_users")
def test_notify_for_sqlreview_audit_reject(self, _auth_group_users, _msg_sender):
"""
测试SQL上线申请审核驳回通知
@@ -2471,20 +2999,20 @@ def test_notify_for_sqlreview_audit_reject(self, _auth_group_users, _msg_sender)
# 通知人修改
_auth_group_users.return_value = [self.user]
# 开启消息通知
- self.sys_config.set('mail', 'true')
- self.sys_config.set('ding', 'true')
+ self.sys_config.set("mail", "true")
+ self.sys_config.set("ding", "true")
# 修改工单状态审核通过
- self.audit.workflow_type = WorkflowDict.workflow_type['sqlreview']
+ self.audit.workflow_type = WorkflowDict.workflow_type["sqlreview"]
self.audit.workflow_id = self.wf.id
- self.audit.current_status = WorkflowDict.workflow_status['audit_reject']
+ self.audit.current_status = WorkflowDict.workflow_status["audit_reject"]
self.audit.create_user = self.user.username
self.audit.save()
r = notify_for_audit(audit_id=self.audit.audit_id)
self.assertIsNone(r)
_msg_sender.assert_called_once()
- @patch('sql.notify.MsgSender')
- @patch('sql.notify.auth_group_users')
+ @patch("sql.notify.MsgSender")
+ @patch("sql.notify.auth_group_users")
def test_notify_for_sqlreview_audit_abort(self, _auth_group_users, _msg_sender):
"""
测试SQL上线申请审核取消通知
@@ -2493,12 +3021,12 @@ def test_notify_for_sqlreview_audit_abort(self, _auth_group_users, _msg_sender):
# 通知人修改
_auth_group_users.return_value = [self.user]
# 开启消息通知
- self.sys_config.set('mail', 'true')
- self.sys_config.set('ding', 'true')
+ self.sys_config.set("mail", "true")
+ self.sys_config.set("ding", "true")
# 修改工单状态审核取消
- self.audit.workflow_type = WorkflowDict.workflow_type['sqlreview']
+ self.audit.workflow_type = WorkflowDict.workflow_type["sqlreview"]
self.audit.workflow_id = self.wf.id
- self.audit.current_status = WorkflowDict.workflow_status['audit_abort']
+ self.audit.current_status = WorkflowDict.workflow_status["audit_abort"]
self.audit.create_user = self.user.username
self.audit.audit_auth_groups = self.aug.id
self.audit.save()
@@ -2506,9 +3034,11 @@ def test_notify_for_sqlreview_audit_abort(self, _auth_group_users, _msg_sender):
self.assertIsNone(r)
_msg_sender.assert_called_once()
- @patch('sql.notify.MsgSender')
- @patch('sql.notify.auth_group_users')
- def test_notify_for_sqlreview_wrong_workflow_type(self, _auth_group_users, _msg_sender):
+ @patch("sql.notify.MsgSender")
+ @patch("sql.notify.auth_group_users")
+ def test_notify_for_sqlreview_wrong_workflow_type(
+ self, _auth_group_users, _msg_sender
+ ):
"""
测试不存在的工单类型
:return:
@@ -2516,17 +3046,19 @@ def test_notify_for_sqlreview_wrong_workflow_type(self, _auth_group_users, _msg_
# 通知人修改
_auth_group_users.return_value = [self.user]
# 开启消息通知
- self.sys_config.set('mail', 'true')
- self.sys_config.set('ding', 'true')
+ self.sys_config.set("mail", "true")
+ self.sys_config.set("ding", "true")
# 修改工单状态审核取消
self.audit.workflow_type = 10
self.audit.save()
- with self.assertRaisesMessage(Exception, '工单类型不正确'):
+ with self.assertRaisesMessage(Exception, "工单类型不正确"):
notify_for_audit(audit_id=self.audit.audit_id)
- @patch('sql.notify.MsgSender')
- @patch('sql.notify.auth_group_users')
- def test_notify_for_query_audit_wait_apply_db_perm(self, _auth_group_users, _msg_sender):
+ @patch("sql.notify.MsgSender")
+ @patch("sql.notify.auth_group_users")
+ def test_notify_for_query_audit_wait_apply_db_perm(
+ self, _auth_group_users, _msg_sender
+ ):
"""
测试查询申请库权限
:return:
@@ -2534,12 +3066,12 @@ def test_notify_for_query_audit_wait_apply_db_perm(self, _auth_group_users, _msg
# 通知人修改
_auth_group_users.return_value = [self.user]
# 开启消息通知
- self.sys_config.set('mail', 'true')
- self.sys_config.set('ding', 'true')
+ self.sys_config.set("mail", "true")
+ self.sys_config.set("ding", "true")
# 修改工单状态为待审核
- self.audit.workflow_type = WorkflowDict.workflow_type['query']
+ self.audit.workflow_type = WorkflowDict.workflow_type["query"]
self.audit.workflow_id = self.query_apply_1.apply_id
- self.audit.current_status = WorkflowDict.workflow_status['audit_wait']
+ self.audit.current_status = WorkflowDict.workflow_status["audit_wait"]
self.audit.save()
# 修改工单为库权限申请
self.query_apply_1.priv_type = 1
@@ -2548,9 +3080,11 @@ def test_notify_for_query_audit_wait_apply_db_perm(self, _auth_group_users, _msg
self.assertIsNone(r)
_msg_sender.assert_called_once()
- @patch('sql.notify.MsgSender')
- @patch('sql.notify.auth_group_users')
- def test_notify_for_query_audit_wait_apply_tb_perm(self, _auth_group_users, _msg_sender):
+ @patch("sql.notify.MsgSender")
+ @patch("sql.notify.auth_group_users")
+ def test_notify_for_query_audit_wait_apply_tb_perm(
+ self, _auth_group_users, _msg_sender
+ ):
"""
测试查询申请表权限
:return:
@@ -2558,12 +3092,12 @@ def test_notify_for_query_audit_wait_apply_tb_perm(self, _auth_group_users, _msg
# 通知人修改
_auth_group_users.return_value = [self.user]
# 开启消息通知
- self.sys_config.set('mail', 'true')
- self.sys_config.set('ding', 'true')
+ self.sys_config.set("mail", "true")
+ self.sys_config.set("ding", "true")
# 修改工单状态为待审核
- self.audit.workflow_type = WorkflowDict.workflow_type['query']
+ self.audit.workflow_type = WorkflowDict.workflow_type["query"]
self.audit.workflow_id = self.query_apply_1.apply_id
- self.audit.current_status = WorkflowDict.workflow_status['audit_wait']
+ self.audit.current_status = WorkflowDict.workflow_status["audit_wait"]
self.audit.save()
# 修改工单为表权限申请
self.query_apply_1.priv_type = 2
@@ -2572,21 +3106,21 @@ def test_notify_for_query_audit_wait_apply_tb_perm(self, _auth_group_users, _msg
self.assertIsNone(r)
_msg_sender.assert_called_once()
- @patch('sql.notify.MsgSender')
+ @patch("sql.notify.MsgSender")
def test_notify_for_execute_disable(self, _msg_sender):
"""
测试执行消息关闭
:return:
"""
# 开启消息通知
- self.sys_config.set('mail', 'false')
- self.sys_config.set('ding', 'false')
+ self.sys_config.set("mail", "false")
+ self.sys_config.set("ding", "false")
r = notify_for_execute(self.wf)
self.assertIsNone(r)
- @patch('sql.notify.auth_group_users')
- @patch('sql.notify.Audit')
- @patch('sql.notify.MsgSender')
+ @patch("sql.notify.auth_group_users")
+ @patch("sql.notify.Audit")
+ @patch("sql.notify.MsgSender")
def test_notify_for_execute(self, _msg_sender, _audit, _auth_group_users):
"""
测试执行消息
@@ -2594,42 +3128,43 @@ def test_notify_for_execute(self, _msg_sender, _audit, _auth_group_users):
"""
_auth_group_users.return_value = [self.user]
# 处理工单信息
- _audit.review_info.return_value = self.audit.audit_auth_groups, self.audit.current_audit
+ _audit.review_info.return_value = (
+ self.audit.audit_auth_groups,
+ self.audit.current_audit,
+ )
# 开启消息通知
- self.sys_config.set('mail', 'true')
- self.sys_config.set('ding', 'true')
- self.sys_config.set('ddl_notify_auth_group', self.aug.name)
+ self.sys_config.set("mail", "true")
+ self.sys_config.set("ding", "true")
+ self.sys_config.set("ddl_notify_auth_group", self.aug.name)
# 修改工单状态为执行结束,修改为DDL工单
- self.wf.status = 'workflow_finish'
+ self.wf.status = "workflow_finish"
self.wf.syntax_type = 1
self.wf.save()
r = notify_for_execute(self.wf)
self.assertIsNone(r)
_msg_sender.assert_called()
-
- @patch('sql.notify.MsgSender')
+ @patch("sql.notify.MsgSender")
def test_notify_for_my2sql_disable(self, _msg_sender):
"""
测试执行消息关闭
:return:
"""
# 开启消息通知
- self.sys_config.set('mail', 'false')
- self.sys_config.set('ding', 'false')
+ self.sys_config.set("mail", "false")
+ self.sys_config.set("ding", "false")
r = notify_for_execute(self.wf)
self.assertIsNone(r)
-
- @patch('django_q.tasks.async_task')
- @patch('sql.notify.MsgSender')
+ @patch("django_q.tasks.async_task")
+ @patch("sql.notify.MsgSender")
def test_notify_for_my2sql(self, _msg_sender, _async_task):
"""
测试执行消息
:return:
"""
# 开启消息通知
- self.sys_config.set('mail', 'true')
+ self.sys_config.set("mail", "true")
# 设置为task成功
_async_task.return_value.success.return_value = True
r = notify_for_my2sql(_async_task)
@@ -2644,17 +3179,23 @@ class TestDataDictionary(TestCase):
def setUp(self):
self.sys_config = SysConfig()
- self.su = User.objects.create(username='s_user', display='中文显示', is_active=True, is_superuser=True)
- self.u1 = User.objects.create(username='user1', display='中文显示', is_active=True)
+ self.su = User.objects.create(
+ username="s_user", display="中文显示", is_active=True, is_superuser=True
+ )
+ self.u1 = User.objects.create(username="user1", display="中文显示", is_active=True)
self.client = Client()
self.client.force_login(self.su)
# 使用 travis.ci 时实例和测试service保持一致
- self.ins = Instance.objects.create(instance_name='test_instance', type='slave', db_type='mysql',
- host=settings.DATABASES['default']['HOST'],
- port=settings.DATABASES['default']['PORT'],
- user=settings.DATABASES['default']['USER'],
- password=settings.DATABASES['default']['PASSWORD'])
- self.db_name = settings.DATABASES['default']['TEST']['NAME']
+ self.ins = Instance.objects.create(
+ instance_name="test_instance",
+ type="slave",
+ db_type="mysql",
+ host=settings.DATABASES["default"]["HOST"],
+ port=settings.DATABASES["default"]["PORT"],
+ user=settings.DATABASES["default"]["USER"],
+ password=settings.DATABASES["default"]["PASSWORD"],
+ )
+ self.db_name = settings.DATABASES["default"]["TEST"]["NAME"]
def tearDown(self):
self.sys_config.purge()
@@ -2666,38 +3207,39 @@ def test_data_dictionary_view(self):
测试访问数据字典页面
:return:
"""
- r = self.client.get(path='/data_dictionary/')
+ r = self.client.get(path="/data_dictionary/")
self.assertEqual(r.status_code, 200)
- @patch('sql.data_dictionary.get_engine')
+ @patch("sql.data_dictionary.get_engine")
def test_table_list(self, _get_engine):
"""
测试获取表清单
:return:
"""
- _get_engine.return_value.get_group_tables_by_db.return_value = {'t': [['test1', '测试表1'], ['test2', '测试表2']]}
+ _get_engine.return_value.get_group_tables_by_db.return_value = {
+ "t": [["test1", "测试表1"], ["test2", "测试表2"]]
+ }
data = {
- 'instance_name': self.ins.instance_name,
- 'db_name': self.db_name,
- 'db_type': 'mysql'
+ "instance_name": self.ins.instance_name,
+ "db_name": self.db_name,
+ "db_type": "mysql",
}
- r = self.client.get(path='/data_dictionary/table_list/', data=data)
+ r = self.client.get(path="/data_dictionary/table_list/", data=data)
self.assertEqual(r.status_code, 200)
- self.assertDictEqual(json.loads(r.content),
- {'data': {'t': [['test1', '测试表1'], ['test2', '测试表2']]}, 'status': 0})
+ self.assertDictEqual(
+ json.loads(r.content),
+ {"data": {"t": [["test1", "测试表1"], ["test2", "测试表2"]]}, "status": 0},
+ )
def test_table_list_not_param(self):
"""
测试获取表清单,参数不完整
:return:
"""
- data = {
- 'instance_name': 'not exist ins',
- 'db_type': 'mysql'
- }
- r = self.client.get(path='/data_dictionary/table_list/', data=data)
+ data = {"instance_name": "not exist ins", "db_type": "mysql"}
+ r = self.client.get(path="/data_dictionary/table_list/", data=data)
self.assertEqual(r.status_code, 200)
- self.assertDictEqual(json.loads(r.content), {'msg': '非法调用!', 'status': 1})
+ self.assertDictEqual(json.loads(r.content), {"msg": "非法调用!", "status": 1})
def test_table_list_instance_does_not_exist(self):
"""
@@ -2705,46 +3247,53 @@ def test_table_list_instance_does_not_exist(self):
:return:
"""
data = {
- 'instance_name': 'not exist ins',
- 'db_name': self.db_name,
- 'db_type': 'mysql'
+ "instance_name": "not exist ins",
+ "db_name": self.db_name,
+ "db_type": "mysql",
}
- r = self.client.get(path='/data_dictionary/table_list/', data=data)
+ r = self.client.get(path="/data_dictionary/table_list/", data=data)
self.assertEqual(r.status_code, 200)
- self.assertDictEqual(json.loads(r.content), {'msg': 'Instance.DoesNotExist', 'status': 1})
+ self.assertDictEqual(
+ json.loads(r.content), {"msg": "Instance.DoesNotExist", "status": 1}
+ )
- @patch('sql.data_dictionary.get_engine')
+ @patch("sql.data_dictionary.get_engine")
def test_table_list_exception(self, _get_engine):
"""
测试获取表清单,异常
:return:
"""
- _get_engine.side_effect = RuntimeError('test error')
+ _get_engine.side_effect = RuntimeError("test error")
data = {
- 'instance_name': self.ins.instance_name,
- 'db_name': self.db_name,
- 'db_type': 'mysql'
+ "instance_name": self.ins.instance_name,
+ "db_name": self.db_name,
+ "db_type": "mysql",
}
- r = self.client.get(path='/data_dictionary/table_list/', data=data)
+ r = self.client.get(path="/data_dictionary/table_list/", data=data)
self.assertEqual(r.status_code, 200)
- self.assertDictEqual(json.loads(r.content), {'msg': 'test error', 'status': 1})
+ self.assertDictEqual(json.loads(r.content), {"msg": "test error", "status": 1})
- @patch('sql.data_dictionary.get_engine')
+ @patch("sql.data_dictionary.get_engine")
def test_table_info(self, _get_engine):
"""
测试获取表信息
:return:
"""
- _get_engine.return_value.query.return_value = ResultSet(rows=(('test1', '测试表1'), ('test2', '测试表2')))
+ _get_engine.return_value.query.return_value = ResultSet(
+ rows=(("test1", "测试表1"), ("test2", "测试表2"))
+ )
data = {
- 'instance_name': self.ins.instance_name,
- 'db_name': self.db_name,
- 'tb_name': 'sql_instance',
- 'db_type': 'mysql'
+ "instance_name": self.ins.instance_name,
+ "db_name": self.db_name,
+ "tb_name": "sql_instance",
+ "db_type": "mysql",
}
- r = self.client.get(path='/data_dictionary/table_info/', data=data)
+ r = self.client.get(path="/data_dictionary/table_info/", data=data)
self.assertEqual(r.status_code, 200)
- self.assertListEqual(list(json.loads(r.content)['data'].keys()), ['meta_data', 'desc', 'index', 'create_sql'])
+ self.assertListEqual(
+ list(json.loads(r.content)["data"].keys()),
+ ["meta_data", "desc", "index", "create_sql"],
+ )
def test_table_info_not_param(self):
"""
@@ -2752,11 +3301,11 @@ def test_table_info_not_param(self):
:return:
"""
data = {
- 'instance_name': 'not exist ins',
+ "instance_name": "not exist ins",
}
- r = self.client.get(path='/data_dictionary/table_info/', data=data)
+ r = self.client.get(path="/data_dictionary/table_info/", data=data)
self.assertEqual(r.status_code, 200)
- self.assertDictEqual(json.loads(r.content), {'msg': '非法调用!', 'status': 1})
+ self.assertDictEqual(json.loads(r.content), {"msg": "非法调用!", "status": 1})
def test_table_info_instance_does_not_exist(self):
"""
@@ -2764,31 +3313,33 @@ def test_table_info_instance_does_not_exist(self):
:return:
"""
data = {
- 'instance_name': 'not exist ins',
- 'db_name': self.db_name,
- 'tb_name': 'sql_instance',
- 'db_type': 'mysql'
+ "instance_name": "not exist ins",
+ "db_name": self.db_name,
+ "tb_name": "sql_instance",
+ "db_type": "mysql",
}
- r = self.client.get(path='/data_dictionary/table_info/', data=data)
+ r = self.client.get(path="/data_dictionary/table_info/", data=data)
self.assertEqual(r.status_code, 200)
- self.assertDictEqual(json.loads(r.content), {'msg': 'Instance.DoesNotExist', 'status': 1})
+ self.assertDictEqual(
+ json.loads(r.content), {"msg": "Instance.DoesNotExist", "status": 1}
+ )
- @patch('sql.data_dictionary.get_engine')
+ @patch("sql.data_dictionary.get_engine")
def test_table_info_exception(self, _get_engine):
"""
测试获取表清单,异常
:return:
"""
- _get_engine.side_effect = RuntimeError('test error')
+ _get_engine.side_effect = RuntimeError("test error")
data = {
- 'instance_name': self.ins.instance_name,
- 'db_name': self.db_name,
- 'tb_name': 'sql_instance',
- 'db_type': 'mysql'
+ "instance_name": self.ins.instance_name,
+ "db_name": self.db_name,
+ "tb_name": "sql_instance",
+ "db_type": "mysql",
}
- r = self.client.get(path='/data_dictionary/table_info/', data=data)
+ r = self.client.get(path="/data_dictionary/table_info/", data=data)
self.assertEqual(r.status_code, 200)
- self.assertDictEqual(json.loads(r.content), {'msg': 'test error', 'status': 1})
+ self.assertDictEqual(json.loads(r.content), {"msg": "test error", "status": 1})
def test_export_instance_does_not_exist(self):
"""
@@ -2796,137 +3347,261 @@ def test_export_instance_does_not_exist(self):
:return:
"""
data = {
- 'instance_name': 'not_exist',
- 'db_name': self.db_name,
- 'db_type': 'mysql'
+ "instance_name": "not_exist",
+ "db_name": self.db_name,
+ "db_type": "mysql",
}
- r = self.client.get(path='/data_dictionary/export/', data=data)
- self.assertDictEqual(json.loads(r.content), {'data': [], 'msg': '你所在组未关联该实例!', 'status': 1})
+ r = self.client.get(path="/data_dictionary/export/", data=data)
+ self.assertDictEqual(
+ json.loads(r.content), {"data": [], "msg": "你所在组未关联该实例!", "status": 1}
+ )
- @patch('sql.data_dictionary.user_instances')
- @patch('sql.data_dictionary.get_engine')
+ @patch("sql.data_dictionary.user_instances")
+ @patch("sql.data_dictionary.get_engine")
def test_export_ins_no_perm(self, _get_engine, _user_instances):
"""
测试导出实例无权限
:return:
"""
self.client.force_login(self.u1)
- data_dictionary_export = Permission.objects.get(codename='data_dictionary_export')
+ data_dictionary_export = Permission.objects.get(
+ codename="data_dictionary_export"
+ )
self.u1.user_permissions.add(data_dictionary_export)
_user_instances.return_value.get.return_value = self.ins
- data = {
- 'instance_name': self.ins.instance_name,
- 'db_type': 'mysql'
-
- }
- r = self.client.get(path='/data_dictionary/export/', data=data)
- self.assertDictEqual(json.loads(r.content),
- {'status': 1, 'msg': f'仅管理员可以导出整个实例的字典信息!', 'data': []})
+ data = {"instance_name": self.ins.instance_name, "db_type": "mysql"}
+ r = self.client.get(path="/data_dictionary/export/", data=data)
+ self.assertDictEqual(
+ json.loads(r.content),
+ {"status": 1, "msg": f"仅管理员可以导出整个实例的字典信息!", "data": []},
+ )
- @patch('sql.data_dictionary.get_engine')
+ @patch("sql.data_dictionary.get_engine")
def test_export_db(self, _get_engine):
"""
测试导出
:return:
"""
- _get_engine.return_value.get_all_databases.return_value.rows.return_value = ResultSet(
- rows=(('test1',), ('test2',)))
- _get_engine.return_value.query.return_value = ResultSet(rows=(
- {'TABLE_CATALOG': 'def', 'TABLE_SCHEMA': 'archer', 'TABLE_NAME': 'aliyun_rds_config',
- 'TABLE_TYPE': 'BASE TABLE', 'ENGINE': 'InnoDB', 'VERSION': 10, 'ROW_FORMAT': 'Dynamic', 'TABLE_ROWS': 0,
- 'AVG_ROW_LENGTH': 0, 'DATA_LENGTH': 16384, 'MAX_DATA_LENGTH': 0, 'INDEX_LENGTH': 32768, 'DATA_FREE': 0,
- 'AUTO_INCREMENT': 1, 'CREATE_TIME': datetime(2019, 5, 28, 9, 25, 41), 'UPDATE_TIME': None,
- 'CHECK_TIME': None, 'TABLE_COLLATION': 'utf8_general_ci', 'CHECKSUM': None, 'CREATE_OPTIONS': '',
- 'TABLE_COMMENT': ''},
- {'TABLE_CATALOG': 'def', 'TABLE_SCHEMA': 'archer', 'TABLE_NAME': 'auth_group', 'TABLE_TYPE': 'BASE TABLE',
- 'ENGINE': 'InnoDB', 'VERSION': 10, 'ROW_FORMAT': 'Dynamic', 'TABLE_ROWS': 8, 'AVG_ROW_LENGTH': 2048,
- 'DATA_LENGTH': 16384, 'MAX_DATA_LENGTH': 0, 'INDEX_LENGTH': 16384, 'DATA_FREE': 0, 'AUTO_INCREMENT': 9,
- 'CREATE_TIME': datetime(2019, 5, 28, 9, 4, 11), 'UPDATE_TIME': None, 'CHECK_TIME': None,
- 'TABLE_COLLATION': 'utf8_general_ci', 'CHECKSUM': None, 'CREATE_OPTIONS': '', 'TABLE_COMMENT': ''}))
+ _get_engine.return_value.get_all_databases.return_value.rows.return_value = (
+ ResultSet(rows=(("test1",), ("test2",)))
+ )
+ _get_engine.return_value.query.return_value = ResultSet(
+ rows=(
+ {
+ "TABLE_CATALOG": "def",
+ "TABLE_SCHEMA": "archer",
+ "TABLE_NAME": "aliyun_rds_config",
+ "TABLE_TYPE": "BASE TABLE",
+ "ENGINE": "InnoDB",
+ "VERSION": 10,
+ "ROW_FORMAT": "Dynamic",
+ "TABLE_ROWS": 0,
+ "AVG_ROW_LENGTH": 0,
+ "DATA_LENGTH": 16384,
+ "MAX_DATA_LENGTH": 0,
+ "INDEX_LENGTH": 32768,
+ "DATA_FREE": 0,
+ "AUTO_INCREMENT": 1,
+ "CREATE_TIME": datetime(2019, 5, 28, 9, 25, 41),
+ "UPDATE_TIME": None,
+ "CHECK_TIME": None,
+ "TABLE_COLLATION": "utf8_general_ci",
+ "CHECKSUM": None,
+ "CREATE_OPTIONS": "",
+ "TABLE_COMMENT": "",
+ },
+ {
+ "TABLE_CATALOG": "def",
+ "TABLE_SCHEMA": "archer",
+ "TABLE_NAME": "auth_group",
+ "TABLE_TYPE": "BASE TABLE",
+ "ENGINE": "InnoDB",
+ "VERSION": 10,
+ "ROW_FORMAT": "Dynamic",
+ "TABLE_ROWS": 8,
+ "AVG_ROW_LENGTH": 2048,
+ "DATA_LENGTH": 16384,
+ "MAX_DATA_LENGTH": 0,
+ "INDEX_LENGTH": 16384,
+ "DATA_FREE": 0,
+ "AUTO_INCREMENT": 9,
+ "CREATE_TIME": datetime(2019, 5, 28, 9, 4, 11),
+ "UPDATE_TIME": None,
+ "CHECK_TIME": None,
+ "TABLE_COLLATION": "utf8_general_ci",
+ "CHECKSUM": None,
+ "CREATE_OPTIONS": "",
+ "TABLE_COMMENT": "",
+ },
+ )
+ )
data = {
- 'instance_name': self.ins.instance_name,
- 'db_name': self.db_name,
- 'db_type': 'mysql'
+ "instance_name": self.ins.instance_name,
+ "db_name": self.db_name,
+ "db_type": "mysql",
}
- r = self.client.get(path='/data_dictionary/export/', data=data)
+ r = self.client.get(path="/data_dictionary/export/", data=data)
self.assertEqual(r.status_code, 200)
self.assertTrue(r.streaming)
-
- @patch('sql.data_dictionary.get_engine')
+ @patch("sql.data_dictionary.get_engine")
def oracle_test_export_db(self, _get_engine):
"""
oracle测试导出
:return:
"""
- _get_engine.return_value.get_all_databases.return_value.rows.return_value = ResultSet(
- rows=(('test1',), ('test2',)))
- _get_engine.return_value.query.return_value = ResultSet(rows=(
- { 'TABLE_NAME': 'aliyun_rds_config', 'TABLE_COMMENTS': 'TABLE', 'COLUMN_NAME':'t1', 'data_type': 'varcher2(20)', 'DATA_DEFAULT': 'Dynamic', 'NULLABLE': 'Y', 'INDEX_NAME': 'SYS_01', 'COMMENTS': 'SYS_01'
- },
- { 'TABLE_NAME': 'auth_group', 'TABLE_COMMENTS': 'TABLE', 'COLUMN_NAME': 't1', 'data_type': 'varcher2(20)', 'DATA_DEFAULT': 'Dynamic', 'NULLABLE': 'N', 'INDEX_NAME': 'SYS_01','COMMENTS': 'SYS_01'
- }))
+ _get_engine.return_value.get_all_databases.return_value.rows.return_value = (
+ ResultSet(rows=(("test1",), ("test2",)))
+ )
+ _get_engine.return_value.query.return_value = ResultSet(
+ rows=(
+ {
+ "TABLE_NAME": "aliyun_rds_config",
+ "TABLE_COMMENTS": "TABLE",
+ "COLUMN_NAME": "t1",
+ "data_type": "varcher2(20)",
+ "DATA_DEFAULT": "Dynamic",
+ "NULLABLE": "Y",
+ "INDEX_NAME": "SYS_01",
+ "COMMENTS": "SYS_01",
+ },
+ {
+ "TABLE_NAME": "auth_group",
+ "TABLE_COMMENTS": "TABLE",
+ "COLUMN_NAME": "t1",
+ "data_type": "varcher2(20)",
+ "DATA_DEFAULT": "Dynamic",
+ "NULLABLE": "N",
+ "INDEX_NAME": "SYS_01",
+ "COMMENTS": "SYS_01",
+ },
+ )
+ )
data = {
- 'instance_name': self.ins.instance_name,
- 'db_name': self.db_name,
- 'db_type': 'oracle'
+ "instance_name": self.ins.instance_name,
+ "db_name": self.db_name,
+ "db_type": "oracle",
}
- r = self.client.get(path='/data_dictionary/export/', data=data)
- print("oracle_test_export_db" )
+ r = self.client.get(path="/data_dictionary/export/", data=data)
+ print("oracle_test_export_db")
self.assertEqual(r.status_code, 200)
self.assertTrue(r.streaming)
- @patch('sql.data_dictionary.get_engine')
+ @patch("sql.data_dictionary.get_engine")
def test_export_instance(self, _get_engine):
"""
测试导出
:return:
"""
- _get_engine.return_value.get_all_databases.return_value.rows.return_value = ResultSet(
- rows=(('test1',), ('test2',)))
- _get_engine.return_value.query.return_value = ResultSet(rows=(
- {'TABLE_CATALOG': 'def', 'TABLE_SCHEMA': 'archer', 'TABLE_NAME': 'aliyun_rds_config',
- 'TABLE_TYPE': 'BASE TABLE', 'ENGINE': 'InnoDB', 'VERSION': 10, 'ROW_FORMAT': 'Dynamic', 'TABLE_ROWS': 0,
- 'AVG_ROW_LENGTH': 0, 'DATA_LENGTH': 16384, 'MAX_DATA_LENGTH': 0, 'INDEX_LENGTH': 32768, 'DATA_FREE': 0,
- 'AUTO_INCREMENT': 1, 'CREATE_TIME': datetime(2019, 5, 28, 9, 25, 41), 'UPDATE_TIME': None,
- 'CHECK_TIME': None, 'TABLE_COLLATION': 'utf8_general_ci', 'CHECKSUM': None, 'CREATE_OPTIONS': '',
- 'TABLE_COMMENT': ''},
- {'TABLE_CATALOG': 'def', 'TABLE_SCHEMA': 'archer', 'TABLE_NAME': 'auth_group', 'TABLE_TYPE': 'BASE TABLE',
- 'ENGINE': 'InnoDB', 'VERSION': 10, 'ROW_FORMAT': 'Dynamic', 'TABLE_ROWS': 8, 'AVG_ROW_LENGTH': 2048,
- 'DATA_LENGTH': 16384, 'MAX_DATA_LENGTH': 0, 'INDEX_LENGTH': 16384, 'DATA_FREE': 0, 'AUTO_INCREMENT': 9,
- 'CREATE_TIME': datetime(2019, 5, 28, 9, 4, 11), 'UPDATE_TIME': None, 'CHECK_TIME': None,
- 'TABLE_COLLATION': 'utf8_general_ci', 'CHECKSUM': None, 'CREATE_OPTIONS': '', 'TABLE_COMMENT': ''}))
- data = {
- 'instance_name': self.ins.instance_name,
- 'db_type':'mysql'
- }
- r = self.client.get(path='/data_dictionary/export/', data=data)
+ _get_engine.return_value.get_all_databases.return_value.rows.return_value = (
+ ResultSet(rows=(("test1",), ("test2",)))
+ )
+ _get_engine.return_value.query.return_value = ResultSet(
+ rows=(
+ {
+ "TABLE_CATALOG": "def",
+ "TABLE_SCHEMA": "archer",
+ "TABLE_NAME": "aliyun_rds_config",
+ "TABLE_TYPE": "BASE TABLE",
+ "ENGINE": "InnoDB",
+ "VERSION": 10,
+ "ROW_FORMAT": "Dynamic",
+ "TABLE_ROWS": 0,
+ "AVG_ROW_LENGTH": 0,
+ "DATA_LENGTH": 16384,
+ "MAX_DATA_LENGTH": 0,
+ "INDEX_LENGTH": 32768,
+ "DATA_FREE": 0,
+ "AUTO_INCREMENT": 1,
+ "CREATE_TIME": datetime(2019, 5, 28, 9, 25, 41),
+ "UPDATE_TIME": None,
+ "CHECK_TIME": None,
+ "TABLE_COLLATION": "utf8_general_ci",
+ "CHECKSUM": None,
+ "CREATE_OPTIONS": "",
+ "TABLE_COMMENT": "",
+ },
+ {
+ "TABLE_CATALOG": "def",
+ "TABLE_SCHEMA": "archer",
+ "TABLE_NAME": "auth_group",
+ "TABLE_TYPE": "BASE TABLE",
+ "ENGINE": "InnoDB",
+ "VERSION": 10,
+ "ROW_FORMAT": "Dynamic",
+ "TABLE_ROWS": 8,
+ "AVG_ROW_LENGTH": 2048,
+ "DATA_LENGTH": 16384,
+ "MAX_DATA_LENGTH": 0,
+ "INDEX_LENGTH": 16384,
+ "DATA_FREE": 0,
+ "AUTO_INCREMENT": 9,
+ "CREATE_TIME": datetime(2019, 5, 28, 9, 4, 11),
+ "UPDATE_TIME": None,
+ "CHECK_TIME": None,
+ "TABLE_COLLATION": "utf8_general_ci",
+ "CHECKSUM": None,
+ "CREATE_OPTIONS": "",
+ "TABLE_COMMENT": "",
+ },
+ )
+ )
+ data = {"instance_name": self.ins.instance_name, "db_type": "mysql"}
+ r = self.client.get(path="/data_dictionary/export/", data=data)
self.assertEqual(r.status_code, 200)
- self.assertDictEqual(json.loads(r.content),
- {'data': [], 'msg': '实例test_instance数据字典导出成功,请到downloads目录下载!', 'status': 0})
+ self.assertDictEqual(
+ json.loads(r.content),
+ {
+ "data": [],
+ "msg": "实例test_instance数据字典导出成功,请到downloads目录下载!",
+ "status": 0,
+ },
+ )
- @patch('sql.data_dictionary.get_engine')
+ @patch("sql.data_dictionary.get_engine")
def oracle_test_export_instance(self, _get_engine):
"""
oracle元数据测试导出
:return:
"""
- _get_engine.return_value.get_all_databases.return_value.rows.return_value = ResultSet(
- rows=(('test1',), ('test2',)))
- _get_engine.return_value.query.return_value = ResultSet(rows=(
- { 'TABLE_NAME': 'aliyun_rds_config', 'TABLE_COMMENTS': 'TABLE', 'COLUMN_NAME':'t1', 'data_type': 'varcher2(20)', 'DATA_DEFAULT': 'Dynamic', 'NULLABLE': 'Y', 'INDEX_NAME': 'SYS_01', 'COMMENTS': 'SYS_01'
- },
- { 'TABLE_NAME': 'auth_group', 'TABLE_COMMENTS': 'TABLE', 'COLUMN_NAME': 't1', 'data_type': 'varcher2(20)', 'DATA_DEFAULT': 'Dynamic', 'NULLABLE': 'N', 'INDEX_NAME': 'SYS_01','COMMENTS': 'SYS_01'
- }))
- data = {
- 'instance_name': self.ins.instance_name,
- 'db_type':'oracle'
- }
- r = self.client.get(path='/data_dictionary/export/', data=data)
+ _get_engine.return_value.get_all_databases.return_value.rows.return_value = (
+ ResultSet(rows=(("test1",), ("test2",)))
+ )
+ _get_engine.return_value.query.return_value = ResultSet(
+ rows=(
+ {
+ "TABLE_NAME": "aliyun_rds_config",
+ "TABLE_COMMENTS": "TABLE",
+ "COLUMN_NAME": "t1",
+ "data_type": "varcher2(20)",
+ "DATA_DEFAULT": "Dynamic",
+ "NULLABLE": "Y",
+ "INDEX_NAME": "SYS_01",
+ "COMMENTS": "SYS_01",
+ },
+ {
+ "TABLE_NAME": "auth_group",
+ "TABLE_COMMENTS": "TABLE",
+ "COLUMN_NAME": "t1",
+ "data_type": "varcher2(20)",
+ "DATA_DEFAULT": "Dynamic",
+ "NULLABLE": "N",
+ "INDEX_NAME": "SYS_01",
+ "COMMENTS": "SYS_01",
+ },
+ )
+ )
+ data = {"instance_name": self.ins.instance_name, "db_type": "oracle"}
+ r = self.client.get(path="/data_dictionary/export/", data=data)
print(r.status_code)
- print("oracle_test_export_instance" )
+ print("oracle_test_export_instance")
self.assertEqual(r.status_code, 200)
- self.assertDictEqual(json.loads(r.content),
- {'data': [], 'msg': '实例test_instance数据字典导出成功,请到downloads目录下载!', 'status': 0})
-
+ self.assertDictEqual(
+ json.loads(r.content),
+ {
+ "data": [],
+ "msg": "实例test_instance数据字典导出成功,请到downloads目录下载!",
+ "status": 0,
+ },
+ )
diff --git a/sql/urls.py b/sql/urls.py
index 71202950ef..084313569f 100644
--- a/sql/urls.py
+++ b/sql/urls.py
@@ -8,157 +8,160 @@
import sql.sql_optimize
from common import auth, config, workflow, dashboard, check
from common.twofa import totp
-from sql import views, sql_workflow, sql_analyze, query, slowlog, instance, instance_account, db_diagnostic, \
- resource_group, binlog, data_dictionary, archiver, audit_log, user
+from sql import (
+ views,
+ sql_workflow,
+ sql_analyze,
+ query,
+ slowlog,
+ instance,
+ instance_account,
+ db_diagnostic,
+ resource_group,
+ binlog,
+ data_dictionary,
+ archiver,
+ audit_log,
+ user,
+)
from sql.utils import tasks
from common.utils import ding_api
urlpatterns = [
- path('', views.index),
- path('jsi18n/', JavaScriptCatalog.as_view(), name='javascript-catalog'),
- path('index/', views.index),
- path('login/', views.login, name='login'),
- path('login/2fa/', views.twofa, name='twofa'),
- path('logout/', auth.sign_out),
- path('signup/', auth.sign_up),
- path('sqlworkflow/', views.sqlworkflow),
- path('submitsql/', views.submit_sql),
- path('editsql/', views.submit_sql),
- path('submitotherinstance/', views.submit_sql),
- path('detail//', views.detail, name='detail'),
- path('autoreview/', sql_workflow.submit),
- path('passed/', sql_workflow.passed),
- path('execute/', sql_workflow.execute),
- path('timingtask/', sql_workflow.timing_task),
- path('alter_run_date/', sql_workflow.alter_run_date),
- path('cancel/', sql_workflow.cancel),
- path('rollback/', views.rollback),
- path('sqlanalyze/', views.sqlanalyze),
- path('sqlquery/', views.sqlquery),
- path('slowquery/', views.slowquery),
- path('sqladvisor/', views.sqladvisor),
- path('slowquery_advisor/', views.sqladvisor),
- path('queryapplylist/', views.queryapplylist),
- path('queryapplydetail//', views.queryapplydetail, name='queryapplydetail'),
- path('queryuserprivileges/', views.queryuserprivileges),
- path('dbdiagnostic/', views.dbdiagnostic),
- path('workflow/', views.workflows),
- path('workflow//', views.workflowsdetail),
- path('dbaprinciples/', views.dbaprinciples),
- path('dashboard/', dashboard.pyecharts),
- path('group/', views.group),
- path('grouprelations//', views.groupmgmt),
- path('instance/', views.instance),
- path('instanceaccount/', views.instanceaccount),
- path('database/', views.database),
- path('instanceparam/', views.instance_param),
- path('my2sql/', views.my2sql),
- path('schemasync/', views.schemasync),
- path('archive/', views.archive),
- path('archive//', views.archive_detail, name='archive_detail'),
- path('config/', views.config),
- path('audit/', views.audit),
- path('audit_sqlquery/', views.audit_sqlquery),
- path('audit_sqlworkflow/', views.audit_sqlworkflow),
-
- path('authenticate/', auth.authenticate_entry),
- path('sqlworkflow_list/', sql_workflow.sql_workflow_list),
- path('sqlworkflow_list_audit/', sql_workflow.sql_workflow_list_audit),
- path('sqlworkflow/detail_content/', sql_workflow.detail_content),
- path('sqlworkflow/backup_sql/', sql_workflow.backup_sql),
- path('simplecheck/', sql_workflow.check),
- path('getWorkflowStatus/', sql_workflow.get_workflow_status),
- path('del_sqlcronjob/', tasks.del_schedule),
- path('inception/osc_control/', sql_workflow.osc_control),
-
- path('sql_analyze/generate/', sql_analyze.generate),
- path('sql_analyze/analyze/', sql_analyze.analyze),
-
- path('workflow/list/', workflow.lists),
- path('workflow/log/', workflow.log),
- path('config/change/', config.change_config),
-
- path('check/go_inception/', check.go_inception),
- path('check/email/', check.email),
- path('check/instance/', check.instance),
-
- path('group/group/', resource_group.group),
- path('group/addrelation/', resource_group.addrelation),
- path('group/relations/', resource_group.associated_objects),
- path('group/instances/', resource_group.instances),
- path('group/unassociated/', resource_group.unassociated_objects),
- path('group/auditors/', resource_group.auditors),
- path('group/changeauditors/', resource_group.changeauditors),
- path('group/user_all_instances/', resource_group.user_all_instances),
-
- path('instance/list/', instance.lists),
-
- path('instance/user/list', instance_account.users),
- path('instance/user/create/', instance_account.create),
- path('instance/user/edit/', instance_account.edit),
- path('instance/user/grant/', instance_account.grant),
- path('instance/user/reset_pwd/', instance_account.reset_pwd),
- path('instance/user/lock/', instance_account.lock),
- path('instance/user/delete/', instance_account.delete),
-
- path('instance/database/list/', sql.instance_database.databases),
- path('instance/database/create/', sql.instance_database.create),
- path('instance/database/edit/', sql.instance_database.edit),
-
- path('instance/schemasync/', instance.schemasync),
- path('instance/instance_resource/', instance.instance_resource),
- path('instance/describetable/', instance.describe),
-
- path('data_dictionary/', views.data_dictionary),
- path('data_dictionary/table_list/', data_dictionary.table_list),
- path('data_dictionary/table_info/', data_dictionary.table_info),
- path('data_dictionary/export/', data_dictionary.export),
-
- path('param/list/', instance.param_list),
- path('param/history/', instance.param_history),
- path('param/edit/', instance.param_edit),
-
- path('query/', query.query),
- path('query/querylog/', query.querylog),
- path('query/querylog_audit/', query.querylog_audit),
- path('query/favorite/', query.favorite),
- path('query/explain/', sql.sql_optimize.explain),
- path('query/applylist/', sql.query_privileges.query_priv_apply_list),
- path('query/userprivileges/', sql.query_privileges.user_query_priv),
- path('query/applyforprivileges/', sql.query_privileges.query_priv_apply),
- path('query/modifyprivileges/', sql.query_privileges.query_priv_modify),
- path('query/privaudit/', sql.query_privileges.query_priv_audit),
-
- path('binlog/list/', binlog.binlog_list),
- path('binlog/my2sql/', binlog.my2sql),
- path('binlog/del_log/', binlog.del_binlog),
-
- path('slowquery/review/', slowlog.slowquery_review),
- path('slowquery/review_history/', slowlog.slowquery_review_history),
- path('slowquery/optimize_sqladvisor/', sql.sql_optimize.optimize_sqladvisor),
- path('slowquery/optimize_sqltuning/', sql.sql_optimize.optimize_sqltuning),
- path('slowquery/optimize_soar/', sql.sql_optimize.optimize_soar),
- path('slowquery/optimize_sqltuningadvisor/', sql.sql_optimize.optimize_sqltuningadvisor),
- path('slowquery/report/', slowlog.report),
-
- path('db_diagnostic/process/', db_diagnostic.process),
- path('db_diagnostic/create_kill_session/', db_diagnostic.create_kill_session),
- path('db_diagnostic/kill_session/', db_diagnostic.kill_session),
- path('db_diagnostic/tablesapce/', db_diagnostic.tablesapce),
- path('db_diagnostic/trxandlocks/', db_diagnostic.trxandlocks),
- path('db_diagnostic/innodb_trx/', db_diagnostic.innodb_trx),
-
- path('archive/list/', archiver.archive_list),
- path('archive/apply/', archiver.archive_apply),
- path('archive/audit/', archiver.archive_audit),
- path('archive/switch/', archiver.archive_switch),
- path('archive/once/', archiver.archive_once),
- path('archive/log/', archiver.archive_log),
-
- path('4admin/sync_ding_user/', ding_api.sync_ding_user),
-
- path('audit/log/', audit_log.audit_log),
- path('audit/input/', audit_log.audit_input),
- path('user/list/', user.lists),
- path('user/qrcode//', totp.generate_qrcode),
+ path("", views.index),
+ path("jsi18n/", JavaScriptCatalog.as_view(), name="javascript-catalog"),
+ path("index/", views.index),
+ path("login/", views.login, name="login"),
+ path("login/2fa/", views.twofa, name="twofa"),
+ path("logout/", auth.sign_out),
+ path("signup/", auth.sign_up),
+ path("sqlworkflow/", views.sqlworkflow),
+ path("submitsql/", views.submit_sql),
+ path("editsql/", views.submit_sql),
+ path("submitotherinstance/", views.submit_sql),
+ path("detail//", views.detail, name="detail"),
+ path("autoreview/", sql_workflow.submit),
+ path("passed/", sql_workflow.passed),
+ path("execute/", sql_workflow.execute),
+ path("timingtask/", sql_workflow.timing_task),
+ path("alter_run_date/", sql_workflow.alter_run_date),
+ path("cancel/", sql_workflow.cancel),
+ path("rollback/", views.rollback),
+ path("sqlanalyze/", views.sqlanalyze),
+ path("sqlquery/", views.sqlquery),
+ path("slowquery/", views.slowquery),
+ path("sqladvisor/", views.sqladvisor),
+ path("slowquery_advisor/", views.sqladvisor),
+ path("queryapplylist/", views.queryapplylist),
+ path(
+ "queryapplydetail//",
+ views.queryapplydetail,
+ name="queryapplydetail",
+ ),
+ path("queryuserprivileges/", views.queryuserprivileges),
+ path("dbdiagnostic/", views.dbdiagnostic),
+ path("workflow/", views.workflows),
+ path("workflow//", views.workflowsdetail),
+ path("dbaprinciples/", views.dbaprinciples),
+ path("dashboard/", dashboard.pyecharts),
+ path("group/", views.group),
+ path("grouprelations//", views.groupmgmt),
+ path("instance/", views.instance),
+ path("instanceaccount/", views.instanceaccount),
+ path("database/", views.database),
+ path("instanceparam/", views.instance_param),
+ path("my2sql/", views.my2sql),
+ path("schemasync/", views.schemasync),
+ path("archive/", views.archive),
+ path("archive//", views.archive_detail, name="archive_detail"),
+ path("config/", views.config),
+ path("audit/", views.audit),
+ path("audit_sqlquery/", views.audit_sqlquery),
+ path("audit_sqlworkflow/", views.audit_sqlworkflow),
+ path("authenticate/", auth.authenticate_entry),
+ path("sqlworkflow_list/", sql_workflow.sql_workflow_list),
+ path("sqlworkflow_list_audit/", sql_workflow.sql_workflow_list_audit),
+ path("sqlworkflow/detail_content/", sql_workflow.detail_content),
+ path("sqlworkflow/backup_sql/", sql_workflow.backup_sql),
+ path("simplecheck/", sql_workflow.check),
+ path("getWorkflowStatus/", sql_workflow.get_workflow_status),
+ path("del_sqlcronjob/", tasks.del_schedule),
+ path("inception/osc_control/", sql_workflow.osc_control),
+ path("sql_analyze/generate/", sql_analyze.generate),
+ path("sql_analyze/analyze/", sql_analyze.analyze),
+ path("workflow/list/", workflow.lists),
+ path("workflow/log/", workflow.log),
+ path("config/change/", config.change_config),
+ path("check/go_inception/", check.go_inception),
+ path("check/email/", check.email),
+ path("check/instance/", check.instance),
+ path("group/group/", resource_group.group),
+ path("group/addrelation/", resource_group.addrelation),
+ path("group/relations/", resource_group.associated_objects),
+ path("group/instances/", resource_group.instances),
+ path("group/unassociated/", resource_group.unassociated_objects),
+ path("group/auditors/", resource_group.auditors),
+ path("group/changeauditors/", resource_group.changeauditors),
+ path("group/user_all_instances/", resource_group.user_all_instances),
+ path("instance/list/", instance.lists),
+ path("instance/user/list", instance_account.users),
+ path("instance/user/create/", instance_account.create),
+ path("instance/user/edit/", instance_account.edit),
+ path("instance/user/grant/", instance_account.grant),
+ path("instance/user/reset_pwd/", instance_account.reset_pwd),
+ path("instance/user/lock/", instance_account.lock),
+ path("instance/user/delete/", instance_account.delete),
+ path("instance/database/list/", sql.instance_database.databases),
+ path("instance/database/create/", sql.instance_database.create),
+ path("instance/database/edit/", sql.instance_database.edit),
+ path("instance/schemasync/", instance.schemasync),
+ path("instance/instance_resource/", instance.instance_resource),
+ path("instance/describetable/", instance.describe),
+ path("data_dictionary/", views.data_dictionary),
+ path("data_dictionary/table_list/", data_dictionary.table_list),
+ path("data_dictionary/table_info/", data_dictionary.table_info),
+ path("data_dictionary/export/", data_dictionary.export),
+ path("param/list/", instance.param_list),
+ path("param/history/", instance.param_history),
+ path("param/edit/", instance.param_edit),
+ path("query/", query.query),
+ path("query/querylog/", query.querylog),
+ path("query/querylog_audit/", query.querylog_audit),
+ path("query/favorite/", query.favorite),
+ path("query/explain/", sql.sql_optimize.explain),
+ path("query/applylist/", sql.query_privileges.query_priv_apply_list),
+ path("query/userprivileges/", sql.query_privileges.user_query_priv),
+ path("query/applyforprivileges/", sql.query_privileges.query_priv_apply),
+ path("query/modifyprivileges/", sql.query_privileges.query_priv_modify),
+ path("query/privaudit/", sql.query_privileges.query_priv_audit),
+ path("binlog/list/", binlog.binlog_list),
+ path("binlog/my2sql/", binlog.my2sql),
+ path("binlog/del_log/", binlog.del_binlog),
+ path("slowquery/review/", slowlog.slowquery_review),
+ path("slowquery/review_history/", slowlog.slowquery_review_history),
+ path("slowquery/optimize_sqladvisor/", sql.sql_optimize.optimize_sqladvisor),
+ path("slowquery/optimize_sqltuning/", sql.sql_optimize.optimize_sqltuning),
+ path("slowquery/optimize_soar/", sql.sql_optimize.optimize_soar),
+ path(
+ "slowquery/optimize_sqltuningadvisor/",
+ sql.sql_optimize.optimize_sqltuningadvisor,
+ ),
+ path("slowquery/report/", slowlog.report),
+ path("db_diagnostic/process/", db_diagnostic.process),
+ path("db_diagnostic/create_kill_session/", db_diagnostic.create_kill_session),
+ path("db_diagnostic/kill_session/", db_diagnostic.kill_session),
+ path("db_diagnostic/tablesapce/", db_diagnostic.tablesapce),
+ path("db_diagnostic/trxandlocks/", db_diagnostic.trxandlocks),
+ path("db_diagnostic/innodb_trx/", db_diagnostic.innodb_trx),
+ path("archive/list/", archiver.archive_list),
+ path("archive/apply/", archiver.archive_apply),
+ path("archive/audit/", archiver.archive_audit),
+ path("archive/switch/", archiver.archive_switch),
+ path("archive/once/", archiver.archive_once),
+ path("archive/log/", archiver.archive_log),
+ path("4admin/sync_ding_user/", ding_api.sync_ding_user),
+ path("audit/log/", audit_log.audit_log),
+ path("audit/input/", audit_log.audit_input),
+ path("user/list/", user.lists),
+ path("user/qrcode//", totp.generate_qrcode),
]
diff --git a/sql/user.py b/sql/user.py
index 8dcfb52d06..b7b8ff6053 100644
--- a/sql/user.py
+++ b/sql/user.py
@@ -9,11 +9,15 @@
@superuser_required
def lists(request):
"""获取用户列表"""
- users = Users.objects.order_by('username')
- users = users.values("id", "username", "display", "is_superuser", "is_staff", "is_active", "email")
+ users = Users.objects.order_by("username")
+ users = users.values(
+ "id", "username", "display", "is_superuser", "is_staff", "is_active", "email"
+ )
rows = [row for row in users]
result = {"status": 0, "msg": "ok", "data": rows}
- return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
- content_type='application/json')
+ return HttpResponse(
+ json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
+ content_type="application/json",
+ )
diff --git a/sql/utils/data_masking.py b/sql/utils/data_masking.py
index d774105fdb..f4c4e4a91b 100644
--- a/sql/utils/data_masking.py
+++ b/sql/utils/data_masking.py
@@ -11,7 +11,7 @@
import re
import traceback
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
def data_masking(instance, db_name, sql, sql_result):
@@ -21,22 +21,28 @@ def data_masking(instance, db_name, sql, sql_result):
# 解析查询语句,判断UNION需要单独处理
p = sqlparse.parse(sql)[0]
for token in p.tokens:
- if token.ttype is Keyword and token.value.upper() in ['UNION', 'UNION ALL']:
- keywords_count['UNION'] = keywords_count.get('UNION', 0) + 1
+ if token.ttype is Keyword and token.value.upper() in ["UNION", "UNION ALL"]:
+ keywords_count["UNION"] = keywords_count.get("UNION", 0) + 1
# 通过goInception获取select list
inception_engine = GoInceptionEngine()
- select_list = inception_engine.query_data_masking(instance=instance, db_name=db_name, sql=sql)
+ select_list = inception_engine.query_data_masking(
+ instance=instance, db_name=db_name, sql=sql
+ )
# 如果UNION存在,那么调用去重函数
- select_list = del_repeat(select_list, keywords_count) if keywords_count else select_list
+ select_list = (
+ del_repeat(select_list, keywords_count) if keywords_count else select_list
+ )
# 分析语法树获取命中脱敏规则的列数据
hit_columns = analyze_query_tree(select_list, instance)
sql_result.mask_rule_hit = True if hit_columns else False
# 对命中规则列hit_columns的数据进行脱敏
- masking_rules = {i.rule_type: model_to_dict(i) for i in DataMaskingRules.objects.all()}
+ masking_rules = {
+ i.rule_type: model_to_dict(i) for i in DataMaskingRules.objects.all()
+ }
if hit_columns and sql_result.rows:
rows = list(sql_result.rows)
for column in hit_columns:
- index, rule_type = column['index'], column['rule_type']
+ index, rule_type = column["index"], column["rule_type"]
masking_rule = masking_rules.get(rule_type)
if not masking_rule:
continue
@@ -47,7 +53,7 @@ def data_masking(instance, db_name, sql, sql_result):
# 脱敏结果
sql_result.is_masked = True
except Exception as msg:
- logger.warning(f'数据脱敏异常,错误信息:{traceback.format_exc()}')
+ logger.warning(f"数据脱敏异常,错误信息:{traceback.format_exc()}")
sql_result.error = str(msg)
sql_result.status = 1
return sql_result
@@ -65,15 +71,15 @@ def del_repeat(select_list, keywords_count):
# 先将query_tree转换成表,方便统计
df = pd.DataFrame(select_list)
- #从原来的库、表、字段去重改为字段
- #result_index = df.groupby(['field', 'table', 'schema']).filter(lambda g: len(g) > 1).to_dict('records')
- result_index = df.groupby(['field']).filter(lambda g: len(g) > 1).to_dict('records')
+ # 从原来的库、表、字段去重改为字段
+ # result_index = df.groupby(['field', 'table', 'schema']).filter(lambda g: len(g) > 1).to_dict('records')
+ result_index = df.groupby(["field"]).filter(lambda g: len(g) > 1).to_dict("records")
# 再统计重复数量
result_len = len(result_index)
-
+
# 再计算取列表前多少的值=重复数量/(union次数+1)
- group_count = int(result_len / (keywords_count['UNION'] + 1))
+ group_count = int(result_len / (keywords_count["UNION"] + 1))
result = result_index[:group_count]
return result
@@ -83,39 +89,49 @@ def analyze_query_tree(select_list, instance):
"""解析select list, 返回命中脱敏规则的列信息"""
# 获取实例全部激活的脱敏字段信息,减少循环查询,提升效率
masking_columns = {
- f"{i.instance}-{i.table_schema}-{i.table_name}-{i.column_name}": model_to_dict(i) for i in
- DataMaskingColumns.objects.filter(instance=instance, active=True)
+ f"{i.instance}-{i.table_schema}-{i.table_name}-{i.column_name}": model_to_dict(
+ i
+ )
+ for i in DataMaskingColumns.objects.filter(instance=instance, active=True)
}
# 遍历select_list 格式化命中的列信息
hit_columns = []
for column in select_list:
- table_schema, table, field = column.get('schema'), column.get('table'), column.get('field')
- masking_column = masking_columns.get(f"{instance}-{table_schema}-{table}-{field}")
+ table_schema, table, field = (
+ column.get("schema"),
+ column.get("table"),
+ column.get("field"),
+ )
+ masking_column = masking_columns.get(
+ f"{instance}-{table_schema}-{table}-{field}"
+ )
if masking_column:
- hit_columns.append({
- "instance_name": instance.instance_name,
- "table_schema": table_schema,
- "table_name": table,
- "column_name": field,
- "rule_type": masking_column['rule_type'],
- "is_hit": True,
- "index": column['index']
- })
+ hit_columns.append(
+ {
+ "instance_name": instance.instance_name,
+ "table_schema": table_schema,
+ "table_name": table,
+ "column_name": field,
+ "rule_type": masking_column["rule_type"],
+ "is_hit": True,
+ "index": column["index"],
+ }
+ )
return hit_columns
def regex(masking_rule, value):
"""利用正则表达式脱敏数据"""
- rule_regex = masking_rule['rule_regex']
- hide_group = masking_rule['hide_group']
+ rule_regex = masking_rule["rule_regex"]
+ hide_group = masking_rule["hide_group"]
# 正则匹配必须分组,隐藏的组会使用****代替
try:
p = re.compile(rule_regex, re.I)
m = p.search(str(value))
- masking_str = ''
+ masking_str = ""
for i in range(m.lastindex):
if i == hide_group - 1:
- group = '****'
+ group = "****"
else:
group = m.group(i + 1)
masking_str = masking_str + group
@@ -132,7 +148,11 @@ def brute_mask(instance, sql_result):
返回同样结构的sql_result , error 中写入脱敏时产生的错误.
"""
# 读取所有关联实例的脱敏规则,去重后应用到结果集,不会按照具体配置的字段匹配
- rule_types = DataMaskingColumns.objects.filter(instance=instance).values_list('rule_type', flat=True).distinct()
+ rule_types = (
+ DataMaskingColumns.objects.filter(instance=instance)
+ .values_list("rule_type", flat=True)
+ .distinct()
+ )
masking_rules = DataMaskingRules.objects.filter(rule_type__in=rule_types)
for reg in masking_rules:
compiled_r = re.compile(reg.rule_regex, re.I)
@@ -147,7 +167,9 @@ def brute_mask(instance, sql_result):
temp_value_list = []
for j in range(len(sql_result.rows[i])):
# 进行正则替换
- temp_value_list += [compiled_r.sub(replace_pattern, str(sql_result.rows[i][j]))]
+ temp_value_list += [
+ compiled_r.sub(replace_pattern, str(sql_result.rows[i][j]))
+ ]
rows[i] = tuple(temp_value_list)
sql_result.rows = rows
return sql_result
@@ -172,22 +194,30 @@ def simple_column_mask(instance, sql_result):
# 脱敏规则字段索引信息
_masking_column_index = []
if column_name in sql_result_column_list:
- _masking_column_index.append(sql_result_column_list.index(column_name))
+ _masking_column_index.append(
+ sql_result_column_list.index(column_name)
+ )
# 别名字段脱敏处理
try:
for _c in sql_result_column_list:
- alias_column_regex = r'"?([^\s"]+)"?\s+(as\s+)?"?({})[",\s+]?'.format(re.escape(_c))
+ alias_column_regex = (
+ r'"?([^\s"]+)"?\s+(as\s+)?"?({})[",\s+]?'.format(
+ re.escape(_c)
+ )
+ )
alias_column_r = re.compile(alias_column_regex, re.I)
# 解析原SQL查询别名字段
search_data = re.search(alias_column_r, sql_result.full_sql)
# 字段名
_column_name = search_data.group(1).lower()
- s_column_name = re.sub(r'^"?\w+"?\."?|\.|"$', '', _column_name)
+ s_column_name = re.sub(r'^"?\w+"?\."?|\.|"$', "", _column_name)
# 别名
alias_name = search_data.group(3).lower()
# 如果字段名匹配脱敏配置字段,对此字段进行脱敏处理
if s_column_name == column_name:
- _masking_column_index.append(sql_result_column_list.index(alias_name))
+ _masking_column_index.append(
+ sql_result_column_list.index(alias_name)
+ )
except:
pass
@@ -209,7 +239,9 @@ def simple_column_mask(instance, sql_result):
for j in range(len(sql_result.rows[i])):
column_data = sql_result.rows[i][j]
if j == masking_column_index:
- column_data = compiled_r.sub(replace_pattern, str(sql_result.rows[i][j]))
+ column_data = compiled_r.sub(
+ replace_pattern, str(sql_result.rows[i][j])
+ )
temp_value_list += [column_data]
rows[i] = tuple(temp_value_list)
sql_result.rows = rows
diff --git a/sql/utils/execute_sql.py b/sql/utils/execute_sql.py
index 0ba93436c2..4e2c084a53 100644
--- a/sql/utils/execute_sql.py
+++ b/sql/utils/execute_sql.py
@@ -12,7 +12,7 @@
from sql.utils.workflow_audit import Audit
from sql.engines import get_engine
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
def execute(workflow_id, user=None):
@@ -21,21 +21,25 @@ def execute(workflow_id, user=None):
with transaction.atomic():
workflow_detail = SqlWorkflow.objects.select_for_update().get(id=workflow_id)
# 只有排队中和定时执行的数据才可以继续执行,否则直接抛错
- if workflow_detail.status not in ['workflow_queuing', 'workflow_timingtask']:
- raise Exception('工单状态不正确,禁止执行!')
+ if workflow_detail.status not in ["workflow_queuing", "workflow_timingtask"]:
+ raise Exception("工单状态不正确,禁止执行!")
# 将工单状态修改为执行中
else:
- SqlWorkflow(id=workflow_id, status='workflow_executing').save(update_fields=['status'])
+ SqlWorkflow(id=workflow_id, status="workflow_executing").save(
+ update_fields=["status"]
+ )
# 增加执行日志
- audit_id = Audit.detail_by_workflow_id(workflow_id=workflow_id,
- workflow_type=WorkflowDict.workflow_type['sqlreview']).audit_id
- Audit.add_log(audit_id=audit_id,
- operation_type=5,
- operation_type_desc='执行工单',
- operation_info='工单开始执行' if user else '系统定时执行工单',
- operator=user.username if user else '',
- operator_display=user.display if user else '系统'
- )
+ audit_id = Audit.detail_by_workflow_id(
+ workflow_id=workflow_id, workflow_type=WorkflowDict.workflow_type["sqlreview"]
+ ).audit_id
+ Audit.add_log(
+ audit_id=audit_id,
+ operation_type=5,
+ operation_type_desc="执行工单",
+ operation_info="工单开始执行" if user else "系统定时执行工单",
+ operator=user.username if user else "",
+ operator_display=user.display if user else "系统",
+ )
execute_engine = get_engine(instance=workflow_detail.instance)
return execute_engine.execute_workflow(workflow=workflow_detail)
@@ -52,60 +56,68 @@ def execute_callback(task):
# 判断工单状态,如果不是执行中的,不允许更新信息,直接抛错记录日志
with transaction.atomic():
workflow = SqlWorkflow.objects.get(id=workflow_id)
- if workflow.status != 'workflow_executing':
- raise Exception(f'工单{workflow.id}状态不正确,禁止重复更新执行结果!')
+ if workflow.status != "workflow_executing":
+ raise Exception(f"工单{workflow.id}状态不正确,禁止重复更新执行结果!")
workflow.finish_time = task.stopped
if not task.success:
# 不成功会返回错误堆栈信息,构造一个错误信息
- workflow.status = 'workflow_exception'
+ workflow.status = "workflow_exception"
execute_result = ReviewSet(full_sql=workflow.sqlworkflowcontent.sql_content)
- execute_result.rows = [ReviewResult(
- stage='Execute failed',
- errlevel=2,
- stagestatus='异常终止',
- errormessage=task.result,
- sql=workflow.sqlworkflowcontent.sql_content)]
+ execute_result.rows = [
+ ReviewResult(
+ stage="Execute failed",
+ errlevel=2,
+ stagestatus="异常终止",
+ errormessage=task.result,
+ sql=workflow.sqlworkflowcontent.sql_content,
+ )
+ ]
elif task.result.warning or task.result.error:
execute_result = task.result
- workflow.status = 'workflow_exception'
+ workflow.status = "workflow_exception"
else:
execute_result = task.result
- workflow.status = 'workflow_finish'
+ workflow.status = "workflow_finish"
try:
# 保存执行结果
workflow.sqlworkflowcontent.execute_result = execute_result.json()
workflow.sqlworkflowcontent.save()
workflow.save()
except Exception as e:
- logger.error(f'SQL工单回调异常: {workflow_id} {traceback.format_exc()}')
+ logger.error(f"SQL工单回调异常: {workflow_id} {traceback.format_exc()}")
SqlWorkflow.objects.filter(id=workflow_id).update(
finish_time=task.stopped,
- status='workflow_exception',
+ status="workflow_exception",
)
- workflow.sqlworkflowcontent.execute_result = {f'{e}'}
+ workflow.sqlworkflowcontent.execute_result = {f"{e}"}
workflow.sqlworkflowcontent.save()
# 增加工单日志
- audit_id = Audit.detail_by_workflow_id(workflow_id=workflow_id,
- workflow_type=WorkflowDict.workflow_type['sqlreview']).audit_id
- Audit.add_log(audit_id=audit_id,
- operation_type=6,
- operation_type_desc='执行结束',
- operation_info='执行结果:{}'.format(workflow.get_status_display()),
- operator='',
- operator_display='系统'
- )
+ audit_id = Audit.detail_by_workflow_id(
+ workflow_id=workflow_id, workflow_type=WorkflowDict.workflow_type["sqlreview"]
+ ).audit_id
+ Audit.add_log(
+ audit_id=audit_id,
+ operation_type=6,
+ operation_type_desc="执行结束",
+ operation_info="执行结果:{}".format(workflow.get_status_display()),
+ operator="",
+ operator_display="系统",
+ )
# DDL工单结束后清空实例资源缓存
if workflow.syntax_type == 1:
r = get_redis_connection("default")
- for key in r.scan_iter(match='*insRes*', count=2000):
+ for key in r.scan_iter(match="*insRes*", count=2000):
r.delete(key)
# 开启了Execute阶段通知参数才发送消息通知
sys_config = SysConfig()
- is_notified = 'Execute' in sys_config.get('notify_phase_control').split(',') \
- if sys_config.get('notify_phase_control') else True
+ is_notified = (
+ "Execute" in sys_config.get("notify_phase_control").split(",")
+ if sys_config.get("notify_phase_control")
+ else True
+ )
if is_notified:
notify_for_execute(workflow)
diff --git a/sql/utils/extract_tables.py b/sql/utils/extract_tables.py
index 73aed5d522..9ba7972d3f 100644
--- a/sql/utils/extract_tables.py
+++ b/sql/utils/extract_tables.py
@@ -53,9 +53,12 @@ def is_subselect(parsed):
if not parsed.is_group:
return False
for item in parsed.tokens:
- if (
- item.ttype is DML
- and item.value.upper() in ("SELECT", "INSERT", "UPDATE", "CREATE", "DELETE")
+ if item.ttype is DML and item.value.upper() in (
+ "SELECT",
+ "INSERT",
+ "UPDATE",
+ "CREATE",
+ "DELETE",
):
return True
return False
@@ -82,18 +85,23 @@ def extract_from_part(parsed, stop_at_punctuation=True):
# Also 'SELECT * FROM abc JOIN def' will trigger this elif
# condition. So we need to ignore the keyword JOIN and its variants
# INNER JOIN, FULL OUTER JOIN, etc.
- elif item.ttype is Keyword and (not item.value.upper() == "FROM") and (
- not item.value.upper().endswith("JOIN")
+ elif (
+ item.ttype is Keyword
+ and (not item.value.upper() == "FROM")
+ and (not item.value.upper().endswith("JOIN"))
):
tbl_prefix_seen = False
else:
yield item
elif item.ttype is Keyword or item.ttype is Keyword.DML:
item_val = item.value.upper()
- if (
- item_val in ("COPY", "FROM", "INTO", "UPDATE", "TABLE")
- or item_val.endswith("JOIN")
- ):
+ if item_val in (
+ "COPY",
+ "FROM",
+ "INTO",
+ "UPDATE",
+ "TABLE",
+ ) or item_val.endswith("JOIN"):
tbl_prefix_seen = True
# 'SELECT a, FROM abc' will detect FROM as part of the column list.
# So this check here is necessary.
@@ -139,8 +147,8 @@ def parse_identifier(item):
try:
schema_name = identifier.get_parent_name()
real_name = identifier.get_real_name()
- is_function = (
- allow_functions and _identifier_is_function(identifier)
+ is_function = allow_functions and _identifier_is_function(
+ identifier
)
except AttributeError:
continue
diff --git a/sql/utils/resource_group.py b/sql/utils/resource_group.py
index f3349753c5..2ccf46a04c 100644
--- a/sql/utils/resource_group.py
+++ b/sql/utils/resource_group.py
@@ -12,7 +12,12 @@ def user_groups(user):
if user.is_superuser:
group_list = [group for group in ResourceGroup.objects.filter(is_deleted=0)]
else:
- group_list = [group for group in Users.objects.get(id=user.id).resource_group.filter(is_deleted=0)]
+ group_list = [
+ group
+ for group in Users.objects.get(id=user.id).resource_group.filter(
+ is_deleted=0
+ )
+ ]
return group_list
@@ -26,7 +31,7 @@ def user_instances(user, type=None, db_type=None, tag_codes=None):
:return:
"""
# 拥有所有实例权限的用户
- if user.has_perm('sql.query_all_instances'):
+ if user.has_perm("sql.query_all_instances"):
instances = Instance.objects.all()
else:
# 先获取用户关联的资源组
@@ -44,7 +49,9 @@ def user_instances(user, type=None, db_type=None, tag_codes=None):
# 过滤tag
if tag_codes:
for tag_code in tag_codes:
- instances = instances.filter(instance_tag__tag_code=tag_code, instance_tag__active=True)
+ instances = instances.filter(
+ instance_tag__tag_code=tag_code, instance_tag__active=True
+ )
return instances.distinct()
diff --git a/sql/utils/sql_review.py b/sql/utils/sql_review.py
index 7cd38eb59a..771747f30d 100644
--- a/sql/utils/sql_review.py
+++ b/sql/utils/sql_review.py
@@ -18,14 +18,19 @@ def is_auto_review(workflow_id):
"""
workflow = SqlWorkflow.objects.get(id=workflow_id)
- auto_review_tags = SysConfig().get('auto_review_tag', '').split(',')
- auto_review_db_type = SysConfig().get('auto_review_db_type', '').split(',')
+ auto_review_tags = SysConfig().get("auto_review_tag", "").split(",")
+ auto_review_db_type = SysConfig().get("auto_review_db_type", "").split(",")
# TODO 这里也可以放到engine中实现,但是配置项可能会相对复杂
- if workflow.instance.db_type in auto_review_db_type and workflow.instance.instance_tag.filter(
- tag_code__in=auto_review_tags).exists():
+ if (
+ workflow.instance.db_type in auto_review_db_type
+ and workflow.instance.instance_tag.filter(
+ tag_code__in=auto_review_tags
+ ).exists()
+ ):
# 获取正则表达式
auto_review_regex = SysConfig().get(
- 'auto_review_regex', '^alter|^create|^drop|^truncate|^rename|^delete')
+ "auto_review_regex", "^alter|^create|^drop|^truncate|^rename|^delete"
+ )
p = re.compile(auto_review_regex, re.I)
# 判断是否匹配到需要手动审核的语句
@@ -35,14 +40,14 @@ def is_auto_review(workflow_id):
for review_row in json.loads(review_content):
review_result = ReviewResult(**review_row)
# 去除SQL注释 https://github.com/hhyo/Archery/issues/949
- sql = remove_comments(review_result.sql).replace("\n","").replace("\r", "")
+ sql = remove_comments(review_result.sql).replace("\n", "").replace("\r", "")
# 正则匹配
if p.match(sql):
auto_review = False
break
# 影响行数加测, 总语句影响行数超过指定数量则需要人工审核
all_affected_rows += int(review_result.affected_rows)
- if all_affected_rows > int(SysConfig().get('auto_review_max_update_rows', 50)):
+ if all_affected_rows > int(SysConfig().get("auto_review_max_update_rows", 50)):
auto_review = False
else:
auto_review = False
@@ -63,14 +68,19 @@ def can_execute(user, workflow_id):
with transaction.atomic():
workflow_detail = SqlWorkflow.objects.select_for_update().get(id=workflow_id)
# 只有审核通过和定时执行的数据才可以立即执行
- if workflow_detail.status not in ['workflow_review_pass', 'workflow_timingtask']:
+ if workflow_detail.status not in [
+ "workflow_review_pass",
+ "workflow_timingtask",
+ ]:
return False
# 当前登录用户有资源组粒度执行权限,并且为组内用户
group_ids = [group.group_id for group in user_groups(user)]
- if workflow_detail.group_id in group_ids and user.has_perm('sql.sql_execute_for_resource_group'):
+ if workflow_detail.group_id in group_ids and user.has_perm(
+ "sql.sql_execute_for_resource_group"
+ ):
result = True
# 当前登录用户为提交人,并且有执行权限
- if workflow_detail.engineer == user.username and user.has_perm('sql.sql_execute'):
+ if workflow_detail.engineer == user.username and user.has_perm("sql.sql_execute"):
result = True
return result
@@ -104,13 +114,17 @@ def can_timingtask(user, workflow_id):
workflow_detail = SqlWorkflow.objects.get(id=workflow_id)
result = False
# 只有审核通过和定时执行的数据才可以执行
- if workflow_detail.status in ['workflow_review_pass', 'workflow_timingtask']:
+ if workflow_detail.status in ["workflow_review_pass", "workflow_timingtask"]:
# 当前登录用户有资源组粒度执行权限,并且为组内用户
group_ids = [group.group_id for group in user_groups(user)]
- if workflow_detail.group_id in group_ids and user.has_perm('sql.sql_execute_for_resource_group'):
+ if workflow_detail.group_id in group_ids and user.has_perm(
+ "sql.sql_execute_for_resource_group"
+ ):
result = True
# 当前登录用户为提交人,并且有执行权限
- if workflow_detail.engineer == user.username and user.has_perm('sql.sql_execute'):
+ if workflow_detail.engineer == user.username and user.has_perm(
+ "sql.sql_execute"
+ ):
result = True
return result
@@ -126,11 +140,19 @@ def can_cancel(user, workflow_id):
workflow_detail = SqlWorkflow.objects.get(id=workflow_id)
result = False
# 审核中的工单,审核人和提交人可终止
- if workflow_detail.status == 'workflow_manreviewing':
+ if workflow_detail.status == "workflow_manreviewing":
from sql.utils.workflow_audit import Audit
- return any([Audit.can_review(user, workflow_id, 2), user.username == workflow_detail.engineer])
- elif workflow_detail.status in ['workflow_review_pass', 'workflow_timingtask']:
- return any([can_execute(user, workflow_id), user.username == workflow_detail.engineer])
+
+ return any(
+ [
+ Audit.can_review(user, workflow_id, 2),
+ user.username == workflow_detail.engineer,
+ ]
+ )
+ elif workflow_detail.status in ["workflow_review_pass", "workflow_timingtask"]:
+ return any(
+ [can_execute(user, workflow_id), user.username == workflow_detail.engineer]
+ )
return result
@@ -147,7 +169,9 @@ def can_view(user, workflow_id):
if user.is_superuser:
result = True
# 非管理员,拥有审核权限、资源组粒度执行权限的,可以查看组内所有工单
- elif user.has_perm('sql.sql_review') or user.has_perm('sql.sql_execute_for_resource_group'):
+ elif user.has_perm("sql.sql_review") or user.has_perm(
+ "sql.sql_execute_for_resource_group"
+ ):
# 先获取用户所在资源组列表
group_list = user_groups(user)
group_ids = [group.group_id for group in group_list]
@@ -171,6 +195,9 @@ def can_rollback(user, workflow_id):
workflow_detail = SqlWorkflow.objects.get(id=workflow_id)
result = False
# 执行结束并且开启备份的工单可以查看回滚信息
- if workflow_detail.is_backup and workflow_detail.status in ('workflow_finish', 'workflow_exception'):
+ if workflow_detail.is_backup and workflow_detail.status in (
+ "workflow_finish",
+ "workflow_exception",
+ ):
return can_view(user, workflow_id)
return result
diff --git a/sql/utils/sql_utils.py b/sql/utils/sql_utils.py
index 55e32cea3b..201f0c20f0 100644
--- a/sql/utils/sql_utils.py
+++ b/sql/utils/sql_utils.py
@@ -13,10 +13,10 @@
from sql.engines.models import SqlItem
from sql.utils.extract_tables import extract_tables as extract_tables_by_sql_parse
-__author__ = 'hhyo'
+__author__ = "hhyo"
-def get_syntax_type(sql, parser=True, db_type='mysql'):
+def get_syntax_type(sql, parser=True, db_type="mysql"):
"""
返回SQL语句类型,仅判断DDL和DML
:param sql:
@@ -29,32 +29,32 @@ def get_syntax_type(sql, parser=True, db_type='mysql'):
try:
statement = sqlparse.parse(sql)[0]
syntax_type = statement.token_first(skip_cm=True).ttype.__str__()
- if syntax_type == 'Token.Keyword.DDL':
- syntax_type = 'DDL'
- elif syntax_type == 'Token.Keyword.DML':
- syntax_type = 'DML'
+ if syntax_type == "Token.Keyword.DDL":
+ syntax_type = "DDL"
+ elif syntax_type == "Token.Keyword.DML":
+ syntax_type = "DML"
except Exception:
syntax_type = None
else:
- if db_type == 'mysql':
+ if db_type == "mysql":
ddl_re = r"^alter|^create|^drop|^rename|^truncate"
dml_re = r"^call|^delete|^do|^handler|^insert|^load\s+data|^load\s+xml|^replace|^select|^update"
- elif db_type == 'oracle':
+ elif db_type == "oracle":
ddl_re = r"^alter|^create|^drop|^rename|^truncate"
dml_re = r"^delete|^exec|^insert|^select|^update|^with|^merge"
else:
# TODO 其他数据库的解析正则
return None
if re.match(ddl_re, sql, re.I):
- syntax_type = 'DDL'
+ syntax_type = "DDL"
elif re.match(dml_re, sql, re.I):
- syntax_type = 'DML'
+ syntax_type = "DML"
else:
syntax_type = None
return syntax_type
-def remove_comments(sql, db_type='mysql'):
+def remove_comments(sql, db_type="mysql"):
"""
去除SQL语句中的注释信息
来源:https://stackoverflow.com/questions/35647841/parse-sql-file-with-comments-into-sqlite-with-python
@@ -63,10 +63,8 @@ def remove_comments(sql, db_type='mysql'):
:return:
"""
sql_comments_re = {
- 'oracle':
- [r'(?:--)[^\n]*\n', r'(?:\W|^)(?:remark|rem)\s+[^\n]*\n'],
- 'mysql':
- [r'(?:#|--\s)[^\n]*\n']
+ "oracle": [r"(?:--)[^\n]*\n", r"(?:\W|^)(?:remark|rem)\s+[^\n]*\n"],
+ "mysql": [r"(?:#|--\s)[^\n]*\n"],
}
specific_comment_re = sql_comments_re[db_type]
additional_patterns = "|"
@@ -94,10 +92,12 @@ def extract_tables(sql):
"""
tables = list()
for i in extract_tables_by_sql_parse(sql):
- tables.append({
- "schema": i.schema,
- "name": i.name,
- })
+ tables.append(
+ {
+ "schema": i.schema,
+ "name": i.name,
+ }
+ )
return tables
@@ -110,7 +110,7 @@ def generate_sql(text):
# 尝试XML解析
try:
mapper, xml_raw_text = mybatis_mapper2sql.create_mapper(xml_raw_text=text)
- statements = mybatis_mapper2sql.get_statement(mapper, result_type='list')
+ statements = mybatis_mapper2sql.get_statement(mapper, result_type="list")
rows = []
# 压缩SQL语句,方便展示
for statement in statements:
@@ -131,13 +131,15 @@ def generate_sql(text):
def get_base_sqlitem_list(full_sql):
- ''' 把参数 full_sql 转变为 SqlItem列表
+ """把参数 full_sql 转变为 SqlItem列表
:param full_sql: 完整sql字符串, 每个SQL以分号;间隔, 不包含plsql执行块和plsql对象定义块
:return: SqlItem对象列表
- '''
+ """
list = []
for statement in sqlparse.split(full_sql):
- statement = sqlparse.format(statement, strip_comments=True, reindent=True, keyword_case='lower')
+ statement = sqlparse.format(
+ statement, strip_comments=True, reindent=True, keyword_case="lower"
+ )
if len(statement) <= 0:
continue
item = SqlItem(statement=statement)
@@ -146,15 +148,15 @@ def get_base_sqlitem_list(full_sql):
def get_full_sqlitem_list(full_sql, db_name):
- ''' 获取Sql对应的SqlItem列表, 包括PLSQL部分
+ """获取Sql对应的SqlItem列表, 包括PLSQL部分
PLSQL语句块由delimiter $$作为开始间隔符,以$$作为结束间隔符
:param full_sql: 全部sql内容
:return: SqlItem 列表
- '''
+ """
list = []
# 定义开始分隔符,两端用括号,是为了re.split()返回列表包含分隔符
- regex_delimiter = r'(delimiter\s*\$\$)'
+ regex_delimiter = r"(delimiter\s*\$\$)"
# 注意:必须把package body置于package之前,否则将永远匹配不上package body
regex_objdefine = r'create\s+or\s+replace\s+(function|procedure|trigger|package\s+body|package|view)\s+("?\w+"?\.)?"?\w+"?[\s+|\(]'
# 对象命名,两端有双引号
@@ -192,7 +194,7 @@ def get_full_sqlitem_list(full_sql, db_name):
plsql_block = sql[0:pos].strip()
# 如果plsql_area字符串最后一个字符为/,则把/给去掉
while True:
- if plsql_block[-1:] == '/':
+ if plsql_block[-1:] == "/":
plsql_block = plsql_block[:-1].strip()
else:
break
@@ -211,11 +213,11 @@ def get_full_sqlitem_list(full_sql, db_name):
str_plsql_type = search_result.groups()[0]
idx = str_plsql_match.index(str_plsql_type)
- nm_str = str_plsql_match[idx + len(str_plsql_type):].strip()
+ nm_str = str_plsql_match[idx + len(str_plsql_type) :].strip()
- if nm_str[-1:] == '(':
+ if nm_str[-1:] == "(":
nm_str = nm_str[:-1]
- nm_list = nm_str.split('.')
+ nm_list = nm_str.split(".")
if len(nm_list) > 1:
# 带有属主的对象名, 形如object_owner.object_name
@@ -246,28 +248,32 @@ def get_full_sqlitem_list(full_sql, db_name):
object_name = nm_list[0].upper().strip()
tmp_object_type = str_plsql_type.upper()
- tmp_stmt_type = 'PLSQL'
- if tmp_object_type == 'VIEW':
- tmp_stmt_type = 'SQL'
-
- item = SqlItem(statement=plsql_block,
- stmt_type=tmp_stmt_type,
- object_owner=object_owner,
- object_type=tmp_object_type,
- object_name=object_name)
+ tmp_stmt_type = "PLSQL"
+ if tmp_object_type == "VIEW":
+ tmp_stmt_type = "SQL"
+
+ item = SqlItem(
+ statement=plsql_block,
+ stmt_type=tmp_stmt_type,
+ object_owner=object_owner,
+ object_type=tmp_object_type,
+ object_name=object_name,
+ )
list.append(item)
else:
# 未检索到关键字, 属于情况2, 匿名可执行块 it's ANONYMOUS
- item = SqlItem(statement=plsql_block.strip(),
- stmt_type='PLSQL',
- object_owner=db_name,
- object_type='ANONYMOUS',
- object_name='ANONYMOUS')
+ item = SqlItem(
+ statement=plsql_block.strip(),
+ stmt_type="PLSQL",
+ object_owner=db_name,
+ object_type="ANONYMOUS",
+ object_name="ANONYMOUS",
+ )
list.append(item)
if length > pos + 2:
# 处理$$之后的那些语句, 默认为单条可执行SQL的集合
- sql_area = sql[pos + 2:].strip()
+ sql_area = sql[pos + 2 :].strip()
if len(sql_area) > 0:
tmp_list = get_base_sqlitem_list(sql_area)
list.extend(tmp_list)
@@ -287,18 +293,22 @@ def get_full_sqlitem_list(full_sql, db_name):
def get_exec_sqlitem_list(reviewResult, db_name):
- """ 根据审核结果生成新的SQL列表
+ """根据审核结果生成新的SQL列表
:param reviewResult: SQL审核结果列表
:param db_name:
:return:
"""
list = []
- list.append(SqlItem(statement=f" ALTER SESSION SET CURRENT_SCHEMA = \"{db_name}\" "))
+ list.append(SqlItem(statement=f' ALTER SESSION SET CURRENT_SCHEMA = "{db_name}" '))
for item in reviewResult:
- list.append(SqlItem(statement=item['sql'],
- stmt_type=item['stmt_type'],
- object_owner=item['object_owner'],
- object_type=item['object_type'],
- object_name=item['object_name']))
+ list.append(
+ SqlItem(
+ statement=item["sql"],
+ stmt_type=item["stmt_type"],
+ object_owner=item["object_owner"],
+ object_type=item["object_type"],
+ object_name=item["object_name"],
+ )
+ )
return list
diff --git a/sql/utils/ssh_tunnel.py b/sql/utils/ssh_tunnel.py
index e40c9861e7..2946eb5f1f 100644
--- a/sql/utils/ssh_tunnel.py
+++ b/sql/utils/ssh_tunnel.py
@@ -14,7 +14,18 @@ class SSHConnection(object):
"""
ssh隧道连接类,用于映射ssh隧道端口到本地,连接结束时需要清理
"""
- def __init__(self, host, port, tun_host, tun_port, tun_user, tun_password, pkey, pkey_password):
+
+ def __init__(
+ self,
+ host,
+ port,
+ tun_host,
+ tun_port,
+ tun_user,
+ tun_password,
+ pkey,
+ pkey_password,
+ ):
self.host = host
self.port = int(port)
self.tun_host = tun_host
@@ -26,7 +37,9 @@ def __init__(self, host, port, tun_host, tun_port, tun_user, tun_password, pkey,
private_key_file_obj = io.StringIO()
private_key_file_obj.write(pkey)
private_key_file_obj.seek(0)
- self.private_key = RSAKey.from_private_key(private_key_file_obj, password=pkey_password)
+ self.private_key = RSAKey.from_private_key(
+ private_key_file_obj, password=pkey_password
+ )
self.server = SSHTunnelForwarder(
ssh_address_or_host=(self.tun_host, self.tun_port),
ssh_username=self.tun_user,
diff --git a/sql/utils/tasks.py b/sql/utils/tasks.py
index 97de47f108..793feced50 100644
--- a/sql/utils/tasks.py
+++ b/sql/utils/tasks.py
@@ -4,30 +4,50 @@
import logging
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
def add_sql_schedule(name, run_date, workflow_id):
"""添加/修改sql定时任务"""
del_schedule(name)
- schedule('sql.utils.execute_sql.execute', workflow_id,
- hook='sql.utils.execute_sql.execute_callback',
- name=name, schedule_type='O', next_run=run_date, repeats=1, timeout=-1)
+ schedule(
+ "sql.utils.execute_sql.execute",
+ workflow_id,
+ hook="sql.utils.execute_sql.execute_callback",
+ name=name,
+ schedule_type="O",
+ next_run=run_date,
+ repeats=1,
+ timeout=-1,
+ )
logger.debug(f"添加SQL定时执行任务:{name} 执行时间:{run_date}")
def add_kill_conn_schedule(name, run_date, instance_id, thread_id):
"""添加/修改终止数据库连接的定时任务"""
del_schedule(name)
- schedule('sql.query.kill_query_conn', instance_id, thread_id,
- name=name, schedule_type='O', next_run=run_date, repeats=1, timeout=-1)
+ schedule(
+ "sql.query.kill_query_conn",
+ instance_id,
+ thread_id,
+ name=name,
+ schedule_type="O",
+ next_run=run_date,
+ repeats=1,
+ timeout=-1,
+ )
def add_sync_ding_user_schedule():
"""添加钉钉同步用户定时任务"""
- del_schedule(name='同步钉钉用户ID')
- schedule('common.utils.ding_api.sync_ding_user_id',
- name='同步钉钉用户ID', schedule_type='D', repeats=-1, timeout=-1)
+ del_schedule(name="同步钉钉用户ID")
+ schedule(
+ "common.utils.ding_api.sync_ding_user_id",
+ name="同步钉钉用户ID",
+ schedule_type="D",
+ repeats=-1,
+ timeout=-1,
+ )
def del_schedule(name):
@@ -35,7 +55,7 @@ def del_schedule(name):
try:
sql_schedule = Schedule.objects.get(name=name)
Schedule.delete(sql_schedule)
- logger.debug(f'删除schedule:{name}')
+ logger.debug(f"删除schedule:{name}")
except Schedule.DoesNotExist:
pass
diff --git a/sql/utils/tests.py b/sql/utils/tests.py
index 185224a41b..404fb4c356 100644
--- a/sql/utils/tests.py
+++ b/sql/utils/tests.py
@@ -23,11 +23,30 @@
from common.config import SysConfig
from common.utils.const import WorkflowDict
from sql.engines.models import ReviewResult, ReviewSet
-from sql.models import Users, SqlWorkflow, SqlWorkflowContent, Instance, ResourceGroup, \
- WorkflowLog, WorkflowAudit, WorkflowAuditDetail, WorkflowAuditSetting, \
- QueryPrivilegesApply, DataMaskingRules, DataMaskingColumns, InstanceTag, ArchiveConfig
+from sql.models import (
+ Users,
+ SqlWorkflow,
+ SqlWorkflowContent,
+ Instance,
+ ResourceGroup,
+ WorkflowLog,
+ WorkflowAudit,
+ WorkflowAuditDetail,
+ WorkflowAuditSetting,
+ QueryPrivilegesApply,
+ DataMaskingRules,
+ DataMaskingColumns,
+ InstanceTag,
+ ArchiveConfig,
+)
from sql.utils.resource_group import user_groups, user_instances, auth_group_users
-from sql.utils.sql_review import is_auto_review, can_execute, can_timingtask, can_cancel, on_correct_time_period
+from sql.utils.sql_review import (
+ is_auto_review,
+ can_execute,
+ can_timingtask,
+ can_cancel,
+ on_correct_time_period,
+)
from sql.utils.sql_utils import *
from sql.utils.execute_sql import execute, execute_callback
from sql.utils.tasks import add_sql_schedule, del_schedule, task_info
@@ -35,7 +54,7 @@
from sql.utils.data_masking import data_masking, brute_mask, simple_column_mask
User = Users
-__author__ = 'hhyo'
+__author__ = "hhyo"
class TestSQLUtils(TestCase):
@@ -46,8 +65,8 @@ def test_get_syntax_type(self):
"""
dml_sql = "select * from users;"
ddl_sql = "alter table users add id not null default 0 comment 'id' "
- self.assertEqual(get_syntax_type(dml_sql), 'DML')
- self.assertEqual(get_syntax_type(ddl_sql), 'DDL')
+ self.assertEqual(get_syntax_type(dml_sql), "DML")
+ self.assertEqual(get_syntax_type(ddl_sql), "DDL")
def test_get_syntax_type_by_re(self):
"""
@@ -56,10 +75,10 @@ def test_get_syntax_type_by_re(self):
"""
dml_sql = "select * from users;"
ddl_sql = "alter table users add id int not null default 0 comment 'id' "
- other_sql = 'show engine innodb status'
- self.assertEqual(get_syntax_type(dml_sql, parser=False, db_type='mysql'), 'DML')
- self.assertEqual(get_syntax_type(ddl_sql, parser=False, db_type='mysql'), 'DDL')
- self.assertIsNone(get_syntax_type(other_sql, parser=False, db_type='mysql'))
+ other_sql = "show engine innodb status"
+ self.assertEqual(get_syntax_type(dml_sql, parser=False, db_type="mysql"), "DML")
+ self.assertEqual(get_syntax_type(ddl_sql, parser=False, db_type="mysql"), "DDL")
+ self.assertIsNone(get_syntax_type(other_sql, parser=False, db_type="mysql"))
def test_remove_comments(self):
"""
@@ -72,12 +91,15 @@ def test_remove_comments(self):
SELECT 1+1; -- This comment continues to the end of line"""
sql3 = """/* this is an in-line comment */
SELECT 1 /* this is an in-line comment */ + 1;/* this is an in-line comment */"""
- self.assertEqual(remove_comments(sql1, db_type='mysql'),
- 'SELECT 1+1; # This comment continues to the end of line')
- self.assertEqual(remove_comments(sql2, db_type='mysql'),
- 'SELECT 1+1; -- This comment continues to the end of line')
- self.assertEqual(remove_comments(sql3, db_type='mysql'),
- 'SELECT 1 + 1;')
+ self.assertEqual(
+ remove_comments(sql1, db_type="mysql"),
+ "SELECT 1+1; # This comment continues to the end of line",
+ )
+ self.assertEqual(
+ remove_comments(sql2, db_type="mysql"),
+ "SELECT 1+1; -- This comment continues to the end of line",
+ )
+ self.assertEqual(remove_comments(sql3, db_type="mysql"), "SELECT 1 + 1;")
def test_extract_tables_by_sql_parse(self):
"""
@@ -85,7 +107,10 @@ def test_extract_tables_by_sql_parse(self):
:return:
"""
sql = "select * from user.users a join logs.log b on a.id=b.id;"
- self.assertEqual(extract_tables(sql), [{'name': 'users', 'schema': 'user'}, {'name': 'log', 'schema': 'logs'}])
+ self.assertEqual(
+ extract_tables(sql),
+ [{"name": "users", "schema": "user"}, {"name": "log", "schema": "logs"}],
+ )
def test_generate_sql_from_sql(self):
"""
@@ -94,9 +119,13 @@ def test_generate_sql_from_sql(self):
"""
text = "select * from sql_user;select * from sql_workflow;"
rows = generate_sql(text)
- self.assertListEqual(rows, [{'sql_id': 1, 'sql': 'select * from sql_user;'},
- {'sql_id': 2, 'sql': 'select * from sql_workflow;'}]
- )
+ self.assertListEqual(
+ rows,
+ [
+ {"sql_id": 1, "sql": "select * from sql_user;"},
+ {"sql_id": 2, "sql": "select * from sql_workflow;"},
+ ],
+ )
def test_generate_sql_from_xml(self):
"""
@@ -120,9 +149,15 @@ def test_generate_sql_from_xml(self):
"""
rows = generate_sql(text)
- self.assertEqual(rows, [{'sql_id': 'testParameters',
- 'sql': '\nSELECT name,\n category,\n price\nFROM fruits\nWHERE category = ?\n AND price > ?'}]
- )
+ self.assertEqual(
+ rows,
+ [
+ {
+ "sql_id": "testParameters",
+ "sql": "\nSELECT name,\n category,\n price\nFROM fruits\nWHERE category = ?\n AND price > ?",
+ }
+ ],
+ )
class TestSQLReview(TestCase):
@@ -131,36 +166,38 @@ class TestSQLReview(TestCase):
"""
def setUp(self):
- self.superuser = User.objects.create(username='super', is_superuser=True)
- self.user = User.objects.create(username='user')
+ self.superuser = User.objects.create(username="super", is_superuser=True)
+ self.user = User.objects.create(username="user")
# 使用 travis.ci 时实例和测试service保持一致
- self.master = Instance(instance_name='test_instance', type='master', db_type='mysql',
- host=settings.DATABASES['default']['HOST'],
- port=settings.DATABASES['default']['PORT'],
- user=settings.DATABASES['default']['USER'],
- password=settings.DATABASES['default']['PASSWORD'])
+ self.master = Instance(
+ instance_name="test_instance",
+ type="master",
+ db_type="mysql",
+ host=settings.DATABASES["default"]["HOST"],
+ port=settings.DATABASES["default"]["PORT"],
+ user=settings.DATABASES["default"]["USER"],
+ password=settings.DATABASES["default"]["PASSWORD"],
+ )
self.master.save()
self.sys_config = SysConfig()
self.client = Client()
- self.group = ResourceGroup.objects.create(group_id=1, group_name='group_name')
+ self.group = ResourceGroup.objects.create(group_id=1, group_name="group_name")
self.wf1 = SqlWorkflow.objects.create(
- workflow_name='workflow_name',
+ workflow_name="workflow_name",
group_id=self.group.group_id,
group_name=self.group.group_name,
engineer=self.superuser.username,
engineer_display=self.superuser.display,
- audit_auth_groups='audit_auth_groups',
+ audit_auth_groups="audit_auth_groups",
create_time=datetime.datetime.now(),
- status='workflow_review_pass',
+ status="workflow_review_pass",
is_backup=True,
instance=self.master,
- db_name='db_name',
+ db_name="db_name",
syntax_type=1,
)
self.wfc1 = SqlWorkflowContent.objects.create(
- workflow=self.wf1,
- sql_content='some_sql',
- execute_result=''
+ workflow=self.wf1, sql_content="some_sql", execute_result=""
)
def tearDown(self):
@@ -171,275 +208,370 @@ def tearDown(self):
self.master.delete()
self.sys_config.replace(json.dumps({}))
- @patch('sql.engines.get_engine')
- def test_auto_review_hit_review_regex(self, _get_engine, ):
+ @patch("sql.engines.get_engine")
+ def test_auto_review_hit_review_regex(
+ self,
+ _get_engine,
+ ):
"""
测试自动审批通过的判定条件,命中判断正则
:return:
"""
# 开启自动审批设置
- self.sys_config.set('auto_review', 'true')
- self.sys_config.set('auto_review_db_type', 'mysql')
- self.sys_config.set('auto_review_regex', '^drop') # drop语句需要审批
- self.sys_config.set('auto_review_max_update_rows', '50') # update影响行数大于50需要审批
+ self.sys_config.set("auto_review", "true")
+ self.sys_config.set("auto_review_db_type", "mysql")
+ self.sys_config.set("auto_review_regex", "^drop") # drop语句需要审批
+ self.sys_config.set("auto_review_max_update_rows", "50") # update影响行数大于50需要审批
self.sys_config.get_all_config()
# 修改工单为drop
self.wfc1.sql_content = "drop table users;"
- self.wfc1.save(update_fields=('sql_content',))
+ self.wfc1.save(update_fields=("sql_content",))
r = is_auto_review(self.wfc1.workflow_id)
self.assertFalse(r)
- @patch('sql.engines.mysql.MysqlEngine.execute_check')
- @patch('sql.engines.get_engine')
+ @patch("sql.engines.mysql.MysqlEngine.execute_check")
+ @patch("sql.engines.get_engine")
def test_auto_review_gt_max_update_rows(self, _get_engine, _execute_check):
"""
测试自动审批通过的判定条件,影响行数大于auto_review_max_update_rows
:return:
"""
# 开启自动审批设置
- self.sys_config.set('auto_review', 'true')
- self.sys_config.set('auto_review_db_type', 'mysql')
- self.sys_config.set('auto_review_regex', '^drop') # drop语句需要审批
- self.sys_config.set('auto_review_max_update_rows', '2') # update影响行数大于2需要审批
+ self.sys_config.set("auto_review", "true")
+ self.sys_config.set("auto_review_db_type", "mysql")
+ self.sys_config.set("auto_review_regex", "^drop") # drop语句需要审批
+ self.sys_config.set("auto_review_max_update_rows", "2") # update影响行数大于2需要审批
self.sys_config.get_all_config()
# 修改工单为update
self.wfc1.sql_content = "update table users set email='';"
- self.wfc1.save(update_fields=('sql_content',))
+ self.wfc1.save(update_fields=("sql_content",))
# mock返回值,update影响行数=3
_execute_check.return_value.to_dict.return_value = [
- {"id": 1, "stage": "CHECKED", "errlevel": 0, "stagestatus": "Audit completed", "errormessage": "None",
- "sql": "use archer_test", "affected_rows": 0, "sequence": "'0_0_0'", "backup_dbname": "None",
- "execute_time": "0", "sqlsha1": "", "actual_affected_rows": 'null'},
- {"id": 2, "stage": "CHECKED", "errlevel": 0, "stagestatus": "Audit completed", "errormessage": "None",
- "sql": "update table users set email=''", "affected_rows": 3, "sequence": "'0_0_1'",
- "backup_dbname": "mysql_3306_archer_test", "execute_time": "0", "sqlsha1": "",
- "actual_affected_rows": 'null'}]
+ {
+ "id": 1,
+ "stage": "CHECKED",
+ "errlevel": 0,
+ "stagestatus": "Audit completed",
+ "errormessage": "None",
+ "sql": "use archer_test",
+ "affected_rows": 0,
+ "sequence": "'0_0_0'",
+ "backup_dbname": "None",
+ "execute_time": "0",
+ "sqlsha1": "",
+ "actual_affected_rows": "null",
+ },
+ {
+ "id": 2,
+ "stage": "CHECKED",
+ "errlevel": 0,
+ "stagestatus": "Audit completed",
+ "errormessage": "None",
+ "sql": "update table users set email=''",
+ "affected_rows": 3,
+ "sequence": "'0_0_1'",
+ "backup_dbname": "mysql_3306_archer_test",
+ "execute_time": "0",
+ "sqlsha1": "",
+ "actual_affected_rows": "null",
+ },
+ ]
r = is_auto_review(self.wfc1.workflow_id)
self.assertFalse(r)
- @patch('sql.engines.get_engine')
+ @patch("sql.engines.get_engine")
def test_auto_review_true(self, _get_engine):
"""
测试自动审批通过的判定条件,
:return:
"""
# 开启自动审批设置
- self.sys_config.set('auto_review', 'true')
- self.sys_config.set('auto_review_db_type', 'mysql')
- self.sys_config.set('auto_review_regex', '^drop') # drop语句需要审批
- self.sys_config.set('auto_review_max_update_rows', '2') # update影响行数大于2需要审批
- self.sys_config.set('auto_review_tag', 'GA') # 仅GA开启自动审批
+ self.sys_config.set("auto_review", "true")
+ self.sys_config.set("auto_review_db_type", "mysql")
+ self.sys_config.set("auto_review_regex", "^drop") # drop语句需要审批
+ self.sys_config.set("auto_review_max_update_rows", "2") # update影响行数大于2需要审批
+ self.sys_config.set("auto_review_tag", "GA") # 仅GA开启自动审批
self.sys_config.get_all_config()
# 修改工单为update,mock返回值,update影响行数=3
self.wfc1.sql_content = "update table users set email='';"
- self.wfc1.review_content = json.dumps([
- {"id": 1, "stage": "CHECKED", "errlevel": 0, "stagestatus": "Audit completed", "errormessage": "None",
- "sql": "use archer_test", "affected_rows": 0, "sequence": "'0_0_0'", "backup_dbname": "None",
- "execute_time": "0", "sqlsha1": "", "actual_affected_rows": 'null'},
- {"id": 2, "stage": "CHECKED", "errlevel": 0, "stagestatus": "Audit completed", "errormessage": "None",
- "sql": "update table users set email=''", "affected_rows": 1, "sequence": "'0_0_1'",
- "backup_dbname": "mysql_3306_archer_test", "execute_time": "0", "sqlsha1": "",
- "actual_affected_rows": 'null'}])
- self.wfc1.save(update_fields=('sql_content', 'review_content'))
+ self.wfc1.review_content = json.dumps(
+ [
+ {
+ "id": 1,
+ "stage": "CHECKED",
+ "errlevel": 0,
+ "stagestatus": "Audit completed",
+ "errormessage": "None",
+ "sql": "use archer_test",
+ "affected_rows": 0,
+ "sequence": "'0_0_0'",
+ "backup_dbname": "None",
+ "execute_time": "0",
+ "sqlsha1": "",
+ "actual_affected_rows": "null",
+ },
+ {
+ "id": 2,
+ "stage": "CHECKED",
+ "errlevel": 0,
+ "stagestatus": "Audit completed",
+ "errormessage": "None",
+ "sql": "update table users set email=''",
+ "affected_rows": 1,
+ "sequence": "'0_0_1'",
+ "backup_dbname": "mysql_3306_archer_test",
+ "execute_time": "0",
+ "sqlsha1": "",
+ "actual_affected_rows": "null",
+ },
+ ]
+ )
+ self.wfc1.save(update_fields=("sql_content", "review_content"))
# 修改工单实例标签
- tag, is_created = InstanceTag.objects.get_or_create(tag_code='GA',
- defaults={'tag_name': '生产环境', 'active': True})
+ tag, is_created = InstanceTag.objects.get_or_create(
+ tag_code="GA", defaults={"tag_name": "生产环境", "active": True}
+ )
self.wf1.instance.instance_tag.add(tag)
r = is_auto_review(self.wfc1.workflow_id)
self.assertTrue(r)
- @patch('sql.engines.get_engine')
+ @patch("sql.engines.get_engine")
def test_auto_review_false(self, _get_engine):
"""
测试自动审批通过的判定条件,
:return:
"""
# 开启自动审批设置
- self.sys_config.set('auto_review', 'true')
- self.sys_config.set('auto_review_db_type', '') # 未配置auto_review_db_type需要审批
- self.sys_config.set('auto_review_regex', '^drop') # drop语句需要审批
- self.sys_config.set('auto_review_max_update_rows', '2') # update影响行数大于2需要审批
- self.sys_config.set('auto_review_tag', 'GA') # 仅GA开启自动审批
+ self.sys_config.set("auto_review", "true")
+ self.sys_config.set("auto_review_db_type", "") # 未配置auto_review_db_type需要审批
+ self.sys_config.set("auto_review_regex", "^drop") # drop语句需要审批
+ self.sys_config.set("auto_review_max_update_rows", "2") # update影响行数大于2需要审批
+ self.sys_config.set("auto_review_tag", "GA") # 仅GA开启自动审批
self.sys_config.get_all_config()
# 修改工单为update,mock返回值,update影响行数=3
self.wfc1.sql_content = "update table users set email='';"
- self.wfc1.review_content = json.dumps([
- {"id": 1, "stage": "CHECKED", "errlevel": 0, "stagestatus": "Audit completed", "errormessage": "None",
- "sql": "use archer_test", "affected_rows": 0, "sequence": "'0_0_0'", "backup_dbname": "None",
- "execute_time": "0", "sqlsha1": "", "actual_affected_rows": 'null'},
- {"id": 2, "stage": "CHECKED", "errlevel": 0, "stagestatus": "Audit completed", "errormessage": "None",
- "sql": "update table users set email=''", "affected_rows": 1, "sequence": "'0_0_1'",
- "backup_dbname": "mysql_3306_archer_test", "execute_time": "0", "sqlsha1": "",
- "actual_affected_rows": 'null'}])
- self.wfc1.save(update_fields=('sql_content', 'review_content'))
+ self.wfc1.review_content = json.dumps(
+ [
+ {
+ "id": 1,
+ "stage": "CHECKED",
+ "errlevel": 0,
+ "stagestatus": "Audit completed",
+ "errormessage": "None",
+ "sql": "use archer_test",
+ "affected_rows": 0,
+ "sequence": "'0_0_0'",
+ "backup_dbname": "None",
+ "execute_time": "0",
+ "sqlsha1": "",
+ "actual_affected_rows": "null",
+ },
+ {
+ "id": 2,
+ "stage": "CHECKED",
+ "errlevel": 0,
+ "stagestatus": "Audit completed",
+ "errormessage": "None",
+ "sql": "update table users set email=''",
+ "affected_rows": 1,
+ "sequence": "'0_0_1'",
+ "backup_dbname": "mysql_3306_archer_test",
+ "execute_time": "0",
+ "sqlsha1": "",
+ "actual_affected_rows": "null",
+ },
+ ]
+ )
+ self.wfc1.save(update_fields=("sql_content", "review_content"))
# 修改工单实例标签
- tag, is_created = InstanceTag.objects.get_or_create(tag_code='GA',
- defaults={'tag_name': '生产环境', 'active': True})
+ tag, is_created = InstanceTag.objects.get_or_create(
+ tag_code="GA", defaults={"tag_name": "生产环境", "active": True}
+ )
self.wf1.instance.instance_tag.add(tag)
r = is_auto_review(self.wfc1.workflow_id)
self.assertFalse(r)
- def test_can_execute_for_resource_group(self, ):
+ def test_can_execute_for_resource_group(
+ self,
+ ):
"""
测试是否能执行的判定条件,登录用户有资源组粒度执行权限,并且为组内用户
:return:
"""
# 修改工单为workflow_review_pass,登录用户有资源组粒度执行权限,并且为组内用户
- self.wf1.status = 'workflow_review_pass'
- self.wf1.save(update_fields=('status',))
- sql_execute_for_resource_group = Permission.objects.get(codename='sql_execute_for_resource_group')
+ self.wf1.status = "workflow_review_pass"
+ self.wf1.save(update_fields=("status",))
+ sql_execute_for_resource_group = Permission.objects.get(
+ codename="sql_execute_for_resource_group"
+ )
self.user.user_permissions.add(sql_execute_for_resource_group)
self.user.resource_group.add(self.group)
r = can_execute(user=self.user, workflow_id=self.wfc1.workflow_id)
self.assertTrue(r)
- def test_can_execute_true(self, ):
+ def test_can_execute_true(
+ self,
+ ):
"""
测试是否能执行的判定条件,当前登录用户为提交人,并且有执行权限,工单状态为审核通过
:return:
"""
# 修改工单为workflow_review_pass,当前登录用户为提交人,并且有执行权限
- self.wf1.status = 'workflow_review_pass'
+ self.wf1.status = "workflow_review_pass"
self.wf1.engineer = self.user.username
- self.wf1.save(update_fields=('status', 'engineer'))
- sql_execute = Permission.objects.get(codename='sql_execute')
+ self.wf1.save(update_fields=("status", "engineer"))
+ sql_execute = Permission.objects.get(codename="sql_execute")
self.user.user_permissions.add(sql_execute)
r = can_execute(user=self.user, workflow_id=self.wfc1.workflow_id)
self.assertTrue(r)
- def test_can_execute_workflow_timing_task(self, ):
+ def test_can_execute_workflow_timing_task(
+ self,
+ ):
"""
测试是否能执行的判定条件,当前登录用户为提交人,并且有执行权限,工单状态为定时执行
:return:
"""
# 修改工单为workflow_review_pass,当前登录用户为提交人,并且有执行权限
- self.wf1.status = 'workflow_timingtask'
+ self.wf1.status = "workflow_timingtask"
self.wf1.engineer = self.user.username
- self.wf1.save(update_fields=('status', 'engineer'))
- sql_execute = Permission.objects.get(codename='sql_execute')
+ self.wf1.save(update_fields=("status", "engineer"))
+ sql_execute = Permission.objects.get(codename="sql_execute")
self.user.user_permissions.add(sql_execute)
r = can_execute(user=self.user, workflow_id=self.wfc1.workflow_id)
self.assertTrue(r)
- def test_can_execute_false_no_permission(self, ):
+ def test_can_execute_false_no_permission(
+ self,
+ ):
"""
当前登录用户为提交人,但是没有执行权限
:return:
"""
# 修改工单为workflow_review_pass,当前登录用户为提交人,并且有执行权限
- self.wf1.status = 'workflow_timingtask'
+ self.wf1.status = "workflow_timingtask"
self.wf1.engineer = self.user.username
- self.wf1.save(update_fields=('status', 'engineer'))
+ self.wf1.save(update_fields=("status", "engineer"))
r = can_execute(user=self.user, workflow_id=self.wfc1.workflow_id)
self.assertFalse(r)
- def test_can_execute_false_not_in_group(self, ):
+ def test_can_execute_false_not_in_group(
+ self,
+ ):
"""
当前登录用户为提交人,有资源组粒度执行权限,但是不是组内用户
:return:
"""
# 修改工单为workflow_review_pass,有资源组粒度执行权限,但是不是组内用户
- self.wf1.status = 'workflow_review_pass'
- self.wf1.save(update_fields=('status',))
- sql_execute_for_resource_group = Permission.objects.get(codename='sql_execute_for_resource_group')
+ self.wf1.status = "workflow_review_pass"
+ self.wf1.save(update_fields=("status",))
+ sql_execute_for_resource_group = Permission.objects.get(
+ codename="sql_execute_for_resource_group"
+ )
self.user.user_permissions.add(sql_execute_for_resource_group)
r = can_execute(user=self.user, workflow_id=self.wfc1.workflow_id)
self.assertFalse(r)
- def test_can_execute_false_wrong_status(self, ):
+ def test_can_execute_false_wrong_status(
+ self,
+ ):
"""
当前登录用户为提交人,前登录用户为提交人,并且有执行权限,但是工单状态为待审核
:return:
"""
# 修改工单为workflow_manreviewing,当前登录用户为提交人,并且有执行权限, 但是工单状态为待审核
- self.wf1.status = 'workflow_manreviewing'
+ self.wf1.status = "workflow_manreviewing"
self.wf1.engineer = self.user.username
- self.wf1.save(update_fields=('status', 'engineer'))
- sql_execute = Permission.objects.get(codename='sql_execute')
+ self.wf1.save(update_fields=("status", "engineer"))
+ sql_execute = Permission.objects.get(codename="sql_execute")
self.user.user_permissions.add(sql_execute)
r = can_execute(user=self.user, workflow_id=self.wfc1.workflow_id)
self.assertFalse(r)
- def test_can_timingtask_true(self, ):
+ def test_can_timingtask_true(
+ self,
+ ):
"""
测试是否能定时执行的判定条件,当前登录用户为提交人,并且有执行权限,工单状态为审核通过
:return:
"""
# 修改工单为workflow_review_pass,当前登录用户为提交人,并且有执行权限
- self.wf1.status = 'workflow_review_pass'
+ self.wf1.status = "workflow_review_pass"
self.wf1.engineer = self.user.username
- self.wf1.save(update_fields=('status', 'engineer'))
- sql_execute = Permission.objects.get(codename='sql_execute')
+ self.wf1.save(update_fields=("status", "engineer"))
+ sql_execute = Permission.objects.get(codename="sql_execute")
self.user.user_permissions.add(sql_execute)
r = can_timingtask(user=self.user, workflow_id=self.wfc1.workflow_id)
self.assertTrue(r)
- def test_can_timingtask_false(self, ):
+ def test_can_timingtask_false(
+ self,
+ ):
"""
测试是否能定时执行的判定条件,当前登录有执行权限,工单状态为审核通过,但用户不是提交人
:return:
"""
# 修改工单为workflow_review_pass,当前登录用户为提交人,并且有执行权限
- self.wf1.status = 'workflow_review_pass'
+ self.wf1.status = "workflow_review_pass"
self.wf1.engineer = self.superuser.username
- self.wf1.save(update_fields=('status', 'engineer'))
- sql_execute = Permission.objects.get(codename='sql_execute')
+ self.wf1.save(update_fields=("status", "engineer"))
+ sql_execute = Permission.objects.get(codename="sql_execute")
self.user.user_permissions.add(sql_execute)
r = can_timingtask(user=self.user, workflow_id=self.wfc1.workflow_id)
self.assertFalse(r)
- @patch('sql.utils.workflow_audit.Audit.can_review')
+ @patch("sql.utils.workflow_audit.Audit.can_review")
def test_can_cancel_true_for_apply_user(self, _can_review):
"""
测试是否能取消,审核中的工单,提交人可终止
:return:
"""
# 修改工单为workflow_review_pass,当前登录用户为提交人
- self.wf1.status = 'workflow_manreviewing'
+ self.wf1.status = "workflow_manreviewing"
self.wf1.engineer = self.user.username
- self.wf1.save(update_fields=('status', 'engineer'))
+ self.wf1.save(update_fields=("status", "engineer"))
_can_review.return_value = False
r = can_cancel(user=self.user, workflow_id=self.wfc1.workflow_id)
self.assertTrue(r)
- @patch('sql.utils.workflow_audit.Audit.can_review')
+ @patch("sql.utils.workflow_audit.Audit.can_review")
def test_can_cancel_true_for_audit_user(self, _can_review):
"""
测试是否能取消,审核中的工单,审核人可终止
:return:
"""
# 修改工单为workflow_review_pass,当前登录用户为提交人
- self.wf1.status = 'workflow_manreviewing'
+ self.wf1.status = "workflow_manreviewing"
self.wf1.engineer = self.superuser.username
- self.wf1.save(update_fields=('status', 'engineer'))
+ self.wf1.save(update_fields=("status", "engineer"))
_can_review.return_value = True
r = can_cancel(user=self.user, workflow_id=self.wfc1.workflow_id)
self.assertTrue(r)
- @patch('sql.utils.sql_review.can_execute')
+ @patch("sql.utils.sql_review.can_execute")
def test_can_cancel_true_for_execute_user(self, _can_execute):
"""
测试是否能取消,审核通过但未执行的工单,有执行权限的用户终止
:return:
"""
# 修改工单为workflow_review_pass,当前登录用户为提交人
- self.wf1.status = 'workflow_review_pass'
+ self.wf1.status = "workflow_review_pass"
self.wf1.engineer = self.user.username
- self.wf1.save(update_fields=('status', 'engineer'))
+ self.wf1.save(update_fields=("status", "engineer"))
_can_execute.return_value = True
r = can_cancel(user=self.user, workflow_id=self.wfc1.workflow_id)
self.assertTrue(r)
- @patch('sql.utils.sql_review.can_execute')
+ @patch("sql.utils.sql_review.can_execute")
def test_can_cancel_true_for_submit_user(self, _can_execute):
"""
测试是否能取消,审核通过但未执行的工单,提交人可终止
:return:
"""
# 修改工单为workflow_review_pass,当前登录用户为提交人
- self.wf1.status = 'workflow_review_pass'
+ self.wf1.status = "workflow_review_pass"
self.wf1.engineer = self.user.username
- self.wf1.save(update_fields=('status', 'engineer'))
+ self.wf1.save(update_fields=("status", "engineer"))
_can_execute.return_value = True
r = can_cancel(user=self.user, workflow_id=self.wfc1.workflow_id)
self.assertTrue(r)
@@ -450,10 +582,12 @@ def test_on_correct_time_period(self):
:return:
"""
# 修改工单为workflow_review_pass,当前登录用户为提交人
- self.wf1.run_date_start = '2019-06-15 11:10:00'
- self.wf1.run_date_end = '2019-06-15 11:30:00'
- self.wf1.save(update_fields=('run_date_start', 'run_date_end'))
- run_date = datetime.datetime.strptime('2019-06-15 11:15:00', "%Y-%m-%d %H:%M:%S")
+ self.wf1.run_date_start = "2019-06-15 11:10:00"
+ self.wf1.run_date_end = "2019-06-15 11:30:00"
+ self.wf1.save(update_fields=("run_date_start", "run_date_end"))
+ run_date = datetime.datetime.strptime(
+ "2019-06-15 11:15:00", "%Y-%m-%d %H:%M:%S"
+ )
r = on_correct_time_period(self.wf1.id, run_date=run_date)
self.assertTrue(r)
@@ -463,77 +597,94 @@ def test_not_in_correct_time_period(self):
:return:
"""
# 修改工单为workflow_review_pass,当前登录用户为提交人
- self.wf1.run_date_start = '2019-06-15 11:10:00'
- self.wf1.run_date_end = '2019-06-15 11:30:00'
- self.wf1.save(update_fields=('run_date_start', 'run_date_end'))
- run_date = datetime.datetime.strptime('2019-06-15 11:45:00', "%Y-%m-%d %H:%M:%S")
+ self.wf1.run_date_start = "2019-06-15 11:10:00"
+ self.wf1.run_date_end = "2019-06-15 11:30:00"
+ self.wf1.save(update_fields=("run_date_start", "run_date_end"))
+ run_date = datetime.datetime.strptime(
+ "2019-06-15 11:45:00", "%Y-%m-%d %H:%M:%S"
+ )
r = on_correct_time_period(self.wf1.id, run_date=run_date)
self.assertFalse(r)
- @patch('sql.utils.sql_review.datetime')
+ @patch("sql.utils.sql_review.datetime")
def test_now_on_correct_time_period(self, _datetime):
"""
测试当前时间在可执行时间内
:return:
"""
# 修改工单为workflow_review_pass,当前登录用户为提交人
- self.wf1.run_date_start = '2019-06-15 11:10:00'
- self.wf1.run_date_end = '2019-06-15 11:30:00'
- self.wf1.save(update_fields=('run_date_start', 'run_date_end'))
+ self.wf1.run_date_start = "2019-06-15 11:10:00"
+ self.wf1.run_date_end = "2019-06-15 11:30:00"
+ self.wf1.save(update_fields=("run_date_start", "run_date_end"))
_datetime.datetime.now.return_value = datetime.datetime.strptime(
- '2019-06-15 11:15:00', "%Y-%m-%d %H:%M:%S")
+ "2019-06-15 11:15:00", "%Y-%m-%d %H:%M:%S"
+ )
r = on_correct_time_period(self.wf1.id)
self.assertTrue(r)
- @patch('sql.utils.sql_review.datetime')
+ @patch("sql.utils.sql_review.datetime")
def test_now_not_in_correct_time_period(self, _datetime):
"""
测试当前时间不在可执行时间内
:return:
"""
# 修改工单为workflow_review_pass,当前登录用户为提交人
- self.wf1.run_date_start = '2019-06-15 11:10:00'
- self.wf1.run_date_end = '2019-06-15 11:30:00'
- self.wf1.save(update_fields=('run_date_start', 'run_date_end'))
+ self.wf1.run_date_start = "2019-06-15 11:10:00"
+ self.wf1.run_date_end = "2019-06-15 11:30:00"
+ self.wf1.save(update_fields=("run_date_start", "run_date_end"))
_datetime.datetime.now.return_value = datetime.datetime.strptime(
- '2019-06-15 11:55:00', "%Y-%m-%d %H:%M:%S")
+ "2019-06-15 11:55:00", "%Y-%m-%d %H:%M:%S"
+ )
r = on_correct_time_period(self.wf1.id)
self.assertFalse(r)
class TestExecuteSql(TestCase):
def setUp(self):
- self.ins = Instance.objects.create(instance_name='some_ins', type='slave', db_type='mysql',
- host='some_host',
- port=3306, user='ins_user', password='some_str')
+ self.ins = Instance.objects.create(
+ instance_name="some_ins",
+ type="slave",
+ db_type="mysql",
+ host="some_host",
+ port=3306,
+ user="ins_user",
+ password="some_str",
+ )
self.wf = SqlWorkflow.objects.create(
- workflow_name='some_name',
+ workflow_name="some_name",
group_id=1,
- group_name='g1',
- engineer_display='',
- audit_auth_groups='some_group',
+ group_name="g1",
+ engineer_display="",
+ audit_auth_groups="some_group",
create_time=datetime.datetime.now(),
- status='workflow_timingtask',
+ status="workflow_timingtask",
is_backup=True,
instance=self.ins,
- db_name='some_db',
- syntax_type=1
- )
- SqlWorkflowContent.objects.create(workflow=self.wf,
- sql_content='some_sql',
- execute_result=ReviewSet(rows=[ReviewResult(
- id=0,
- stage='Execute failed',
- errlevel=2,
- stagestatus='异常终止',
- errormessage='',
- sql='执行异常信息',
- affected_rows=0,
- actual_affected_rows=0,
- sequence='0_0_0',
- backup_dbname=None,
- execute_time=0,
- sqlsha1='')]).json())
+ db_name="some_db",
+ syntax_type=1,
+ )
+ SqlWorkflowContent.objects.create(
+ workflow=self.wf,
+ sql_content="some_sql",
+ execute_result=ReviewSet(
+ rows=[
+ ReviewResult(
+ id=0,
+ stage="Execute failed",
+ errlevel=2,
+ stagestatus="异常终止",
+ errormessage="",
+ sql="执行异常信息",
+ affected_rows=0,
+ actual_affected_rows=0,
+ sequence="0_0_0",
+ backup_dbname=None,
+ execute_time=0,
+ sqlsha1="",
+ )
+ ]
+ ).json(),
+ )
def tearDown(self):
self.ins.delete()
@@ -541,9 +692,9 @@ def tearDown(self):
SqlWorkflowContent.objects.all().delete()
WorkflowLog.objects.all().delete()
- @patch('sql.utils.execute_sql.Audit')
- @patch('sql.engines.mysql.MysqlEngine.execute_workflow')
- @patch('sql.engines.get_engine')
+ @patch("sql.utils.execute_sql.Audit")
+ @patch("sql.engines.mysql.MysqlEngine.execute_workflow")
+ @patch("sql.engines.get_engine")
def test_execute(self, _get_engine, _execute_workflow, _audit):
_audit.detail_by_workflow_id.return_value.audit_id = 1
execute(self.wf.id)
@@ -551,194 +702,216 @@ def test_execute(self, _get_engine, _execute_workflow, _audit):
_audit.add_log.assert_called_with(
audit_id=1,
operation_type=5,
- operation_type_desc='执行工单',
- operation_info='系统定时执行工单',
- operator='',
- operator_display='系统',
+ operation_type_desc="执行工单",
+ operation_info="系统定时执行工单",
+ operator="",
+ operator_display="系统",
)
- @patch('sql.utils.execute_sql.notify_for_execute')
- @patch('sql.utils.execute_sql.Audit')
+ @patch("sql.utils.execute_sql.notify_for_execute")
+ @patch("sql.utils.execute_sql.Audit")
def test_execute_callback_success(self, _audit, _notify):
# 初始化工单执行返回对象
self.task_result = MagicMock()
self.task_result.args = [self.wf.id]
self.task_result.success = True
self.task_result.stopped = datetime.datetime.now()
- self.task_result.result.json.return_value = json.dumps([{'id': 1, 'sql': 'some_content'}])
- self.task_result.result.warning = ''
- self.task_result.result.error = ''
+ self.task_result.result.json.return_value = json.dumps(
+ [{"id": 1, "sql": "some_content"}]
+ )
+ self.task_result.result.warning = ""
+ self.task_result.result.error = ""
_audit.detail_by_workflow_id.return_value.audit_id = 123
- _audit.add_log.return_value = 'any thing'
+ _audit.add_log.return_value = "any thing"
# 先处理为执行中
- self.wf.status = 'workflow_executing'
- self.wf.save(update_fields=['status'])
+ self.wf.status = "workflow_executing"
+ self.wf.save(update_fields=["status"])
execute_callback(self.task_result)
- _audit.detail_by_workflow_id.assert_called_with(workflow_id=self.wf.id, workflow_type=2)
+ _audit.detail_by_workflow_id.assert_called_with(
+ workflow_id=self.wf.id, workflow_type=2
+ )
_audit.add_log.assert_called_with(
audit_id=123,
operation_type=6,
- operation_type_desc='执行结束',
+ operation_type_desc="执行结束",
operation_info="执行结果:已正常结束",
- operator='',
- operator_display='系统',
+ operator="",
+ operator_display="系统",
)
_notify.assert_called_once()
- @patch('sql.utils.execute_sql.notify_for_execute')
- @patch('sql.utils.execute_sql.Audit')
+ @patch("sql.utils.execute_sql.notify_for_execute")
+ @patch("sql.utils.execute_sql.Audit")
def test_execute_callback_failure(self, _audit, _notify):
# 初始化工单执行返回对象
self.task_result = MagicMock()
self.task_result.args = [self.wf.id]
self.task_result.success = False
self.task_result.stopped = datetime.datetime.now()
- self.task_result.result = '执行异常'
+ self.task_result.result = "执行异常"
_audit.detail_by_workflow_id.return_value.audit_id = 123
- _audit.add_log.return_value = 'any thing'
+ _audit.add_log.return_value = "any thing"
# 处理状态为执行中
- self.wf.status = 'workflow_executing'
- self.wf.save(update_fields=['status'])
+ self.wf.status = "workflow_executing"
+ self.wf.save(update_fields=["status"])
execute_callback(self.task_result)
- _audit.detail_by_workflow_id.assert_called_with(workflow_id=self.wf.id, workflow_type=2)
+ _audit.detail_by_workflow_id.assert_called_with(
+ workflow_id=self.wf.id, workflow_type=2
+ )
_audit.add_log.assert_called_with(
audit_id=123,
operation_type=6,
- operation_type_desc='执行结束',
+ operation_type_desc="执行结束",
operation_info="执行结果:执行有异常",
- operator='',
- operator_display='系统',
+ operator="",
+ operator_display="系统",
)
_notify.assert_called_once()
- @patch('sql.utils.execute_sql.notify_for_execute')
- @patch('sql.utils.execute_sql.Audit')
+ @patch("sql.utils.execute_sql.notify_for_execute")
+ @patch("sql.utils.execute_sql.Audit")
def test_execute_callback_failure_no_execute_result(self, _audit, _notify):
# 初始化工单执行返回对象
self.task_result = MagicMock()
self.task_result.args = [self.wf.id]
self.task_result.success = False
self.task_result.stopped = datetime.datetime.now()
- self.task_result.result = '执行异常'
+ self.task_result.result = "执行异常"
_audit.detail_by_workflow_id.return_value.audit_id = 123
- _audit.add_log.return_value = 'any thing'
+ _audit.add_log.return_value = "any thing"
# 删除execute_result、处理为执行中
- self.wf.sqlworkflowcontent.execute_result = ''
+ self.wf.sqlworkflowcontent.execute_result = ""
self.wf.sqlworkflowcontent.save()
- self.wf.status = 'workflow_executing'
- self.wf.save(update_fields=['status'])
+ self.wf.status = "workflow_executing"
+ self.wf.save(update_fields=["status"])
execute_callback(self.task_result)
- _audit.detail_by_workflow_id.assert_called_with(workflow_id=self.wf.id, workflow_type=2)
+ _audit.detail_by_workflow_id.assert_called_with(
+ workflow_id=self.wf.id, workflow_type=2
+ )
_audit.add_log.assert_called_with(
audit_id=123,
operation_type=6,
- operation_type_desc='执行结束',
+ operation_type_desc="执行结束",
operation_info="执行结果:执行有异常",
- operator='',
- operator_display='系统',
+ operator="",
+ operator_display="系统",
)
_notify.assert_called_once()
class TestTasks(TestCase):
def setUp(self):
- self.Schedule = Schedule.objects.create(name='some_name')
+ self.Schedule = Schedule.objects.create(name="some_name")
def tearDown(self):
Schedule.objects.all().delete()
- @patch('sql.utils.tasks.schedule')
+ @patch("sql.utils.tasks.schedule")
def test_add_sql_schedule(self, _schedule):
- add_sql_schedule('test', datetime.datetime.now(), 1)
+ add_sql_schedule("test", datetime.datetime.now(), 1)
_schedule.assert_called_once()
def test_del_schedule(self):
- del_schedule('some_name')
+ del_schedule("some_name")
with self.assertRaises(Schedule.DoesNotExist):
- Schedule.objects.get(name='some_name')
+ Schedule.objects.get(name="some_name")
def test_del_schedule_not_exists(self):
- del_schedule('some_name1')
+ del_schedule("some_name1")
def test_task_info(self):
- task_info('some_name')
+ task_info("some_name")
def test_task_info_not_exists(self):
with self.assertRaises(Schedule.DoesNotExist):
- Schedule.objects.get(name='some_name1')
+ Schedule.objects.get(name="some_name1")
class TestAudit(TestCase):
def setUp(self):
self.sys_config = SysConfig()
- self.user = User.objects.create(username='test_user', display='中文显示', is_active=True)
- self.su = User.objects.create(username='s_user', display='中文显示', is_active=True, is_superuser=True)
+ self.user = User.objects.create(
+ username="test_user", display="中文显示", is_active=True
+ )
+ self.su = User.objects.create(
+ username="s_user", display="中文显示", is_active=True, is_superuser=True
+ )
tomorrow = datetime.datetime.today() + datetime.timedelta(days=1)
- self.ins = Instance.objects.create(instance_name='some_ins', type='slave', db_type='mysql',
- host='some_host',
- port=3306, user='ins_user', password='some_str')
- self.res_group = ResourceGroup.objects.create(group_id=1, group_name='group_name')
+ self.ins = Instance.objects.create(
+ instance_name="some_ins",
+ type="slave",
+ db_type="mysql",
+ host="some_host",
+ port=3306,
+ user="ins_user",
+ password="some_str",
+ )
+ self.res_group = ResourceGroup.objects.create(
+ group_id=1, group_name="group_name"
+ )
self.wf = SqlWorkflow.objects.create(
- workflow_name='some_name',
+ workflow_name="some_name",
group_id=1,
- group_name='g1',
- engineer_display='',
- audit_auth_groups='some_audit_group',
+ group_name="g1",
+ engineer_display="",
+ audit_auth_groups="some_audit_group",
create_time=datetime.datetime.now(),
- status='workflow_timingtask',
+ status="workflow_timingtask",
is_backup=True,
instance=self.ins,
- db_name='some_db',
- syntax_type=1
+ db_name="some_db",
+ syntax_type=1,
+ )
+ SqlWorkflowContent.objects.create(
+ workflow=self.wf, sql_content="some_sql", execute_result=""
)
- SqlWorkflowContent.objects.create(workflow=self.wf,
- sql_content='some_sql',
- execute_result='')
self.query_apply_1 = QueryPrivilegesApply.objects.create(
group_id=1,
- group_name='some_name',
- title='some_title1',
- user_name='some_user',
+ group_name="some_name",
+ title="some_title1",
+ user_name="some_user",
instance=self.ins,
- db_list='some_db,some_db2',
+ db_list="some_db,some_db2",
limit_num=100,
valid_date=tomorrow,
priv_type=1,
status=0,
- audit_auth_groups='some_audit_group'
+ audit_auth_groups="some_audit_group",
)
self.archive_apply_1 = ArchiveConfig.objects.create(
- title='title',
+ title="title",
resource_group=self.res_group,
- audit_auth_groups='some_audit_group',
+ audit_auth_groups="some_audit_group",
src_instance=self.ins,
- src_db_name='src_db_name',
- src_table_name='src_table_name',
+ src_db_name="src_db_name",
+ src_table_name="src_table_name",
dest_instance=self.ins,
- dest_db_name='src_db_name',
- dest_table_name='src_table_name',
- condition='1=1',
- mode='file',
+ dest_db_name="src_db_name",
+ dest_table_name="src_table_name",
+ condition="1=1",
+ mode="file",
no_delete=True,
sleep=1,
- status=WorkflowDict.workflow_status['audit_wait'],
+ status=WorkflowDict.workflow_status["audit_wait"],
state=False,
- user_name='some_user',
- user_display='display',
+ user_name="some_user",
+ user_display="display",
)
self.audit = WorkflowAudit.objects.create(
group_id=1,
- group_name='some_group',
+ group_name="some_group",
workflow_id=1,
workflow_type=1,
- workflow_title='申请标题',
- workflow_remark='申请备注',
- audit_auth_groups='1,2,3',
- current_audit='1',
- next_audit='2',
- current_status=0)
- self.wl = WorkflowLog.objects.create(audit_id=self.audit.audit_id,
- operation_type=1)
+ workflow_title="申请标题",
+ workflow_remark="申请备注",
+ audit_auth_groups="1,2,3",
+ current_audit="1",
+ next_audit="2",
+ current_status=0,
+ )
+ self.wl = WorkflowLog.objects.create(
+ audit_id=self.audit.audit_id, operation_type=1
+ )
def tearDown(self):
self.sys_config.purge()
@@ -754,224 +927,247 @@ def tearDown(self):
ArchiveConfig.objects.all().delete()
def test_audit_add_query(self):
- """ 测试添加查询审核工单"""
+ """测试添加查询审核工单"""
result = Audit.add(1, self.query_apply_1.apply_id)
- audit_id = result['data']['audit_id']
- workflow_status = result['data']['workflow_status']
- self.assertEqual(workflow_status, WorkflowDict.workflow_status['audit_wait'])
+ audit_id = result["data"]["audit_id"]
+ workflow_status = result["data"]["workflow_status"]
+ self.assertEqual(workflow_status, WorkflowDict.workflow_status["audit_wait"])
audit_detail = WorkflowAudit.objects.get(audit_id=audit_id)
# 当前审批
- self.assertEqual(audit_detail.current_audit, 'some_audit_group')
+ self.assertEqual(audit_detail.current_audit, "some_audit_group")
# 无下级审批
- self.assertEqual(audit_detail.next_audit, '-1')
+ self.assertEqual(audit_detail.next_audit, "-1")
# 验证日志
log_info = WorkflowLog.objects.filter(audit_id=audit_id).first()
self.assertEqual(log_info.operation_type, 0)
- self.assertEqual(log_info.operation_type_desc, '提交')
- self.assertIn('等待审批,审批流程:', log_info.operation_info)
+ self.assertEqual(log_info.operation_type_desc, "提交")
+ self.assertIn("等待审批,审批流程:", log_info.operation_info)
def test_audit_add_sqlreview(self):
- """ 测试添加上线审核工单"""
+ """测试添加上线审核工单"""
result = Audit.add(2, self.wf.id)
- audit_id = result['data']['audit_id']
- workflow_status = result['data']['workflow_status']
- self.assertEqual(workflow_status, WorkflowDict.workflow_status['audit_wait'])
+ audit_id = result["data"]["audit_id"]
+ workflow_status = result["data"]["workflow_status"]
+ self.assertEqual(workflow_status, WorkflowDict.workflow_status["audit_wait"])
audit_detail = WorkflowAudit.objects.get(audit_id=audit_id)
# 当前审批
- self.assertEqual(audit_detail.current_audit, 'some_audit_group')
+ self.assertEqual(audit_detail.current_audit, "some_audit_group")
# 无下级审批
- self.assertEqual(audit_detail.next_audit, '-1')
+ self.assertEqual(audit_detail.next_audit, "-1")
# 验证日志
log_info = WorkflowLog.objects.filter(audit_id=audit_id).first()
self.assertEqual(log_info.operation_type, 0)
- self.assertEqual(log_info.operation_type_desc, '提交')
- self.assertIn('等待审批,审批流程:', log_info.operation_info)
+ self.assertEqual(log_info.operation_type_desc, "提交")
+ self.assertIn("等待审批,审批流程:", log_info.operation_info)
def test_audit_add_archive_review(self):
- """ 测试添加数据归档工单"""
+ """测试添加数据归档工单"""
result = Audit.add(3, self.archive_apply_1.id)
- audit_id = result['data']['audit_id']
- workflow_status = result['data']['workflow_status']
- self.assertEqual(workflow_status, WorkflowDict.workflow_status['audit_wait'])
+ audit_id = result["data"]["audit_id"]
+ workflow_status = result["data"]["workflow_status"]
+ self.assertEqual(workflow_status, WorkflowDict.workflow_status["audit_wait"])
audit_detail = WorkflowAudit.objects.get(audit_id=audit_id)
# 当前审批
- self.assertEqual(audit_detail.current_audit, 'some_audit_group')
+ self.assertEqual(audit_detail.current_audit, "some_audit_group")
# 无下级审批
- self.assertEqual(audit_detail.next_audit, '-1')
+ self.assertEqual(audit_detail.next_audit, "-1")
# 验证日志
log_info = WorkflowLog.objects.filter(audit_id=audit_id).first()
self.assertEqual(log_info.operation_type, 0)
- self.assertEqual(log_info.operation_type_desc, '提交')
- self.assertIn('等待审批,审批流程:', log_info.operation_info)
+ self.assertEqual(log_info.operation_type_desc, "提交")
+ self.assertIn("等待审批,审批流程:", log_info.operation_info)
def test_audit_add_wrong_type(self):
- """ 测试添加不存在的类型"""
- with self.assertRaisesMessage(Exception, '工单类型不存在'):
+ """测试添加不存在的类型"""
+ with self.assertRaisesMessage(Exception, "工单类型不存在"):
Audit.add(4, 1)
def test_audit_add_settings_not_exists(self):
- """ 测试审批流程未配置"""
- self.wf.audit_auth_groups = ''
+ """测试审批流程未配置"""
+ self.wf.audit_auth_groups = ""
self.wf.save()
- with self.assertRaisesMessage(Exception, '审批流程不能为空,请先配置审批流程'):
+ with self.assertRaisesMessage(Exception, "审批流程不能为空,请先配置审批流程"):
Audit.add(2, self.wf.id)
def test_audit_add_duplicate(self):
"""测试重复提交"""
Audit.add(2, self.wf.id)
- with self.assertRaisesMessage(Exception, '该工单当前状态为待审核,请勿重复提交'):
+ with self.assertRaisesMessage(Exception, "该工单当前状态为待审核,请勿重复提交"):
Audit.add(2, self.wf.id)
- @patch('sql.utils.workflow_audit.is_auto_review', return_value=True)
+ @patch("sql.utils.workflow_audit.is_auto_review", return_value=True)
def test_audit_add_auto_review(self, _is_auto_review):
"""测试提交自动审核通过"""
- self.sys_config.set('auto_review', 'true')
+ self.sys_config.set("auto_review", "true")
result = Audit.add(2, self.wf.id)
- audit_id = result['data']['audit_id']
- workflow_status = result['data']['workflow_status']
- self.assertEqual(workflow_status, WorkflowDict.workflow_status['audit_success'])
+ audit_id = result["data"]["audit_id"]
+ workflow_status = result["data"]["workflow_status"]
+ self.assertEqual(workflow_status, WorkflowDict.workflow_status["audit_success"])
audit_detail = WorkflowAudit.objects.get(audit_id=audit_id)
# 无下级审批
- self.assertEqual(audit_detail.next_audit, '-1')
+ self.assertEqual(audit_detail.next_audit, "-1")
# 验证日志
log_info = WorkflowLog.objects.filter(audit_id=audit_id).first()
self.assertEqual(log_info.operation_type, 0)
- self.assertEqual(log_info.operation_type_desc, '提交')
- self.assertEqual(log_info.operation_info, '无需审批,系统直接审核通过')
+ self.assertEqual(log_info.operation_type_desc, "提交")
+ self.assertEqual(log_info.operation_info, "无需审批,系统直接审核通过")
def test_audit_add_multiple_audit(self):
"""测试提交多级审核"""
- self.wf.audit_auth_groups = '1,2,3'
+ self.wf.audit_auth_groups = "1,2,3"
self.wf.save()
result = Audit.add(2, self.wf.id)
- audit_id = result['data']['audit_id']
- workflow_status = result['data']['workflow_status']
+ audit_id = result["data"]["audit_id"]
+ workflow_status = result["data"]["workflow_status"]
audit_detail = WorkflowAudit.objects.get(audit_id=audit_id)
- self.assertEqual(workflow_status, WorkflowDict.workflow_status['audit_wait'])
+ self.assertEqual(workflow_status, WorkflowDict.workflow_status["audit_wait"])
# 存在下级审批
- self.assertEqual(audit_detail.current_audit, '1')
- self.assertEqual(audit_detail.next_audit, '2')
+ self.assertEqual(audit_detail.current_audit, "1")
+ self.assertEqual(audit_detail.next_audit, "2")
# 验证日志
log_info = WorkflowLog.objects.filter(audit_id=audit_id).first()
self.assertEqual(log_info.operation_type, 0)
- self.assertEqual(log_info.operation_type_desc, '提交')
- self.assertIn('等待审批,审批流程:', log_info.operation_info)
+ self.assertEqual(log_info.operation_type_desc, "提交")
+ self.assertIn("等待审批,审批流程:", log_info.operation_info)
def test_audit_success_not_exists_next(self):
"""测试审核通过、无下一级"""
- self.audit.current_audit = '3'
- self.audit.next_audit = '-1'
+ self.audit.current_audit = "3"
+ self.audit.next_audit = "-1"
self.audit.save()
- result = Audit.audit(self.audit.audit_id,
- WorkflowDict.workflow_status['audit_success'],
- self.user.username,
- '通过')
+ result = Audit.audit(
+ self.audit.audit_id,
+ WorkflowDict.workflow_status["audit_success"],
+ self.user.username,
+ "通过",
+ )
audit_id = self.audit.audit_id
- workflow_status = result['data']['workflow_status']
+ workflow_status = result["data"]["workflow_status"]
audit_detail = WorkflowAudit.objects.get(audit_id=audit_id)
- self.assertEqual(workflow_status, WorkflowDict.workflow_status['audit_success'])
+ self.assertEqual(workflow_status, WorkflowDict.workflow_status["audit_success"])
# 不存在下级审批
- self.assertEqual(audit_detail.next_audit, '-1')
+ self.assertEqual(audit_detail.next_audit, "-1")
# 验证日志
- log_info = WorkflowLog.objects.filter(audit_id=audit_id).order_by('-id').first()
+ log_info = WorkflowLog.objects.filter(audit_id=audit_id).order_by("-id").first()
self.assertEqual(log_info.operator, self.user.username)
self.assertEqual(log_info.operator_display, self.user.display)
self.assertEqual(log_info.operation_type, 1)
- self.assertEqual(log_info.operation_type_desc, '审批通过')
- self.assertEqual(log_info.operation_info, f'审批备注:通过,下级审批:None')
+ self.assertEqual(log_info.operation_type_desc, "审批通过")
+ self.assertEqual(log_info.operation_info, f"审批备注:通过,下级审批:None")
def test_audit_success_exists_next(self):
"""测试审核通过、存在下一级"""
- self.audit.current_audit = '1'
- self.audit.next_audit = '2'
+ self.audit.current_audit = "1"
+ self.audit.next_audit = "2"
self.audit.save()
- result = Audit.audit(self.audit.audit_id,
- WorkflowDict.workflow_status['audit_success'],
- self.user.username,
- '通过')
+ result = Audit.audit(
+ self.audit.audit_id,
+ WorkflowDict.workflow_status["audit_success"],
+ self.user.username,
+ "通过",
+ )
audit_id = self.audit.audit_id
- workflow_status = result['data']['workflow_status']
+ workflow_status = result["data"]["workflow_status"]
audit_detail = WorkflowAudit.objects.get(audit_id=audit_id)
- self.assertEqual(workflow_status, WorkflowDict.workflow_status['audit_wait'])
+ self.assertEqual(workflow_status, WorkflowDict.workflow_status["audit_wait"])
# 存在下级审批
- self.assertEqual(audit_detail.next_audit, '3')
+ self.assertEqual(audit_detail.next_audit, "3")
# 验证日志
- log_info = WorkflowLog.objects.filter(audit_id=audit_id).order_by('-id').first()
+ log_info = WorkflowLog.objects.filter(audit_id=audit_id).order_by("-id").first()
self.assertEqual(log_info.operator, self.user.username)
self.assertEqual(log_info.operator_display, self.user.display)
self.assertEqual(log_info.operation_type, 1)
- self.assertEqual(log_info.operation_type_desc, '审批通过')
- self.assertEqual(log_info.operation_info, f'审批备注:通过,下级审批:2')
+ self.assertEqual(log_info.operation_type_desc, "审批通过")
+ self.assertEqual(log_info.operation_info, f"审批备注:通过,下级审批:2")
def test_audit_reject(self):
"""测试审核不通过"""
- result = Audit.audit(self.audit.audit_id,
- WorkflowDict.workflow_status['audit_reject'],
- self.user.username,
- '不通过')
+ result = Audit.audit(
+ self.audit.audit_id,
+ WorkflowDict.workflow_status["audit_reject"],
+ self.user.username,
+ "不通过",
+ )
audit_id = self.audit.audit_id
- workflow_status = result['data']['workflow_status']
+ workflow_status = result["data"]["workflow_status"]
audit_detail = WorkflowAudit.objects.get(audit_id=audit_id)
- self.assertEqual(workflow_status, WorkflowDict.workflow_status['audit_reject'])
+ self.assertEqual(workflow_status, WorkflowDict.workflow_status["audit_reject"])
# 不存在下级审批
- self.assertEqual(audit_detail.next_audit, '-1')
+ self.assertEqual(audit_detail.next_audit, "-1")
# 验证日志
- log_info = WorkflowLog.objects.filter(audit_id=audit_id).order_by('-id').first()
+ log_info = WorkflowLog.objects.filter(audit_id=audit_id).order_by("-id").first()
self.assertEqual(log_info.operator, self.user.username)
self.assertEqual(log_info.operator_display, self.user.display)
self.assertEqual(log_info.operation_type, 2)
- self.assertEqual(log_info.operation_type_desc, '审批不通过')
- self.assertEqual(log_info.operation_info, f'审批备注:不通过')
+ self.assertEqual(log_info.operation_type_desc, "审批不通过")
+ self.assertEqual(log_info.operation_info, f"审批备注:不通过")
def test_audit_abort(self):
"""测试取消审批"""
self.audit.create_user = self.user.username
self.audit.save()
- result = Audit.audit(self.audit.audit_id,
- WorkflowDict.workflow_status['audit_abort'],
- self.user.username,
- '取消')
+ result = Audit.audit(
+ self.audit.audit_id,
+ WorkflowDict.workflow_status["audit_abort"],
+ self.user.username,
+ "取消",
+ )
audit_id = self.audit.audit_id
- workflow_status = result['data']['workflow_status']
+ workflow_status = result["data"]["workflow_status"]
audit_detail = WorkflowAudit.objects.get(audit_id=audit_id)
- self.assertEqual(workflow_status, WorkflowDict.workflow_status['audit_abort'])
+ self.assertEqual(workflow_status, WorkflowDict.workflow_status["audit_abort"])
# 不存在下级审批
- self.assertEqual(audit_detail.next_audit, '-1')
+ self.assertEqual(audit_detail.next_audit, "-1")
# 验证日志
- log_info = WorkflowLog.objects.filter(audit_id=audit_id).order_by('-id').first()
+ log_info = WorkflowLog.objects.filter(audit_id=audit_id).order_by("-id").first()
self.assertEqual(log_info.operator, self.user.username)
self.assertEqual(log_info.operator_display, self.user.display)
self.assertEqual(log_info.operation_type, 3)
- self.assertEqual(log_info.operation_type_desc, '审批取消')
- self.assertEqual(log_info.operation_info, f'取消原因:取消')
+ self.assertEqual(log_info.operation_type_desc, "审批取消")
+ self.assertEqual(log_info.operation_info, f"取消原因:取消")
def test_audit_wrong_exception(self):
"""测试审核异常的状态"""
- with self.assertRaisesMessage(Exception, '审核异常'):
- Audit.audit(self.audit.audit_id, 10, self.user.username, '')
+ with self.assertRaisesMessage(Exception, "审核异常"):
+ Audit.audit(self.audit.audit_id, 10, self.user.username, "")
def test_audit_success_wrong_status(self):
"""测试审核通过,当前状态不是待审核"""
self.audit.current_status = 1
self.audit.save()
- with self.assertRaisesMessage(Exception, '工单不是待审核状态,请返回刷新'):
- Audit.audit(self.audit.audit_id, WorkflowDict.workflow_status['audit_success'], self.user.username, '')
+ with self.assertRaisesMessage(Exception, "工单不是待审核状态,请返回刷新"):
+ Audit.audit(
+ self.audit.audit_id,
+ WorkflowDict.workflow_status["audit_success"],
+ self.user.username,
+ "",
+ )
def test_audit_reject_wrong_status(self):
"""测试审核不通过,当前状态不是待审核"""
self.audit.current_status = 1
self.audit.save()
- with self.assertRaisesMessage(Exception, '工单不是待审核状态,请返回刷新'):
- Audit.audit(self.audit.audit_id, WorkflowDict.workflow_status['audit_reject'], self.user.username, '')
+ with self.assertRaisesMessage(Exception, "工单不是待审核状态,请返回刷新"):
+ Audit.audit(
+ self.audit.audit_id,
+ WorkflowDict.workflow_status["audit_reject"],
+ self.user.username,
+ "",
+ )
def test_audit_abort_wrong_status(self):
"""测试审核不通过,当前状态不是待审核"""
self.audit.current_status = 2
self.audit.save()
- with self.assertRaisesMessage(Exception, '工单不是待审核态/审核通过状态,请返回刷新'):
- Audit.audit(self.audit.audit_id, WorkflowDict.workflow_status['audit_abort'], self.user.username, '')
-
- @patch('sql.utils.workflow_audit.user_groups', return_value=[])
+ with self.assertRaisesMessage(Exception, "工单不是待审核态/审核通过状态,请返回刷新"):
+ Audit.audit(
+ self.audit.audit_id,
+ WorkflowDict.workflow_status["audit_abort"],
+ self.user.username,
+ "",
+ )
+
+ @patch("sql.utils.workflow_audit.user_groups", return_value=[])
def test_todo(self, _user_groups):
"""TODO 测试todo数量,未断言"""
Audit.todo(self.user)
@@ -986,140 +1182,164 @@ def test_detail(self):
def test_detail_by_workflow_id(self):
"""测试通过业务id获取审核信息"""
- self.audit.workflow_type = WorkflowDict.workflow_type['sqlreview']
+ self.audit.workflow_type = WorkflowDict.workflow_type["sqlreview"]
self.audit.workflow_id = self.wf.id
self.audit.save()
- result = Audit.detail_by_workflow_id(self.wf.id, WorkflowDict.workflow_type['sqlreview'])
+ result = Audit.detail_by_workflow_id(
+ self.wf.id, WorkflowDict.workflow_type["sqlreview"]
+ )
self.assertEqual(result, self.audit)
result = Audit.detail_by_workflow_id(0, 0)
self.assertEqual(result, None)
def test_settings(self):
"""测试通过组和审核类型,获取审核配置信息"""
- WorkflowAuditSetting.objects.create(workflow_type=1, group_id=1, audit_auth_groups='1,2,3')
+ WorkflowAuditSetting.objects.create(
+ workflow_type=1, group_id=1, audit_auth_groups="1,2,3"
+ )
result = Audit.settings(workflow_type=1, group_id=1)
- self.assertEqual(result, '1,2,3')
+ self.assertEqual(result, "1,2,3")
result = Audit.settings(0, 0)
self.assertEqual(result, None)
def test_change_settings_edit(self):
"""修改配置信息"""
- ws = WorkflowAuditSetting.objects.create(workflow_type=1, group_id=1, audit_auth_groups='1,2,3')
- Audit.change_settings(workflow_type=1, group_id=1, audit_auth_groups='1,2')
+ ws = WorkflowAuditSetting.objects.create(
+ workflow_type=1, group_id=1, audit_auth_groups="1,2,3"
+ )
+ Audit.change_settings(workflow_type=1, group_id=1, audit_auth_groups="1,2")
ws = WorkflowAuditSetting.objects.get(audit_setting_id=ws.audit_setting_id)
- self.assertEqual(ws.audit_auth_groups, '1,2')
+ self.assertEqual(ws.audit_auth_groups, "1,2")
def test_change_settings_add(self):
"""添加配置信息"""
- Audit.change_settings(workflow_type=1, group_id=1, audit_auth_groups='1,2')
+ Audit.change_settings(workflow_type=1, group_id=1, audit_auth_groups="1,2")
ws = WorkflowAuditSetting.objects.get(workflow_type=1, group_id=1)
- self.assertEqual(ws.audit_auth_groups, '1,2')
+ self.assertEqual(ws.audit_auth_groups, "1,2")
- @patch('sql.utils.workflow_audit.auth_group_users')
- @patch('sql.utils.workflow_audit.Audit.detail_by_workflow_id')
+ @patch("sql.utils.workflow_audit.auth_group_users")
+ @patch("sql.utils.workflow_audit.Audit.detail_by_workflow_id")
def test_can_review_sql_review(self, _detail_by_workflow_id, _auth_group_users):
"""测试判断用户当前是否是可审核上线工单,非管理员拥有权限"""
- sql_review = Permission.objects.get(codename='sql_review')
+ sql_review = Permission.objects.get(codename="sql_review")
self.user.user_permissions.add(sql_review)
- aug = Group.objects.create(name='auth_group')
+ aug = Group.objects.create(name="auth_group")
_detail_by_workflow_id.return_value.current_audit = aug.id
_auth_group_users.return_value.filter.exists = True
- self.audit.workflow_type = WorkflowDict.workflow_type['sqlreview']
+ self.audit.workflow_type = WorkflowDict.workflow_type["sqlreview"]
self.audit.workflow_id = self.wf.id
self.audit.save()
- r = Audit.can_review(self.user, self.audit.workflow_id, self.audit.workflow_type)
+ r = Audit.can_review(
+ self.user, self.audit.workflow_id, self.audit.workflow_type
+ )
self.assertEqual(r, True)
- @patch('sql.utils.workflow_audit.auth_group_users')
- @patch('sql.utils.workflow_audit.Audit.detail_by_workflow_id')
+ @patch("sql.utils.workflow_audit.auth_group_users")
+ @patch("sql.utils.workflow_audit.Audit.detail_by_workflow_id")
def test_can_review_query_review(self, _detail_by_workflow_id, _auth_group_users):
"""测试判断用户当前是否是可审核查询工单,非管理员拥有权限"""
- query_review = Permission.objects.get(codename='query_review')
+ query_review = Permission.objects.get(codename="query_review")
self.user.user_permissions.add(query_review)
- aug = Group.objects.create(name='auth_group')
+ aug = Group.objects.create(name="auth_group")
_detail_by_workflow_id.return_value.current_audit = aug.id
_auth_group_users.return_value.filter.exists = True
- self.audit.workflow_type = WorkflowDict.workflow_type['query']
+ self.audit.workflow_type = WorkflowDict.workflow_type["query"]
self.audit.workflow_id = self.query_apply_1.apply_id
self.audit.save()
- r = Audit.can_review(self.user, self.audit.workflow_id, self.audit.workflow_type)
+ r = Audit.can_review(
+ self.user, self.audit.workflow_id, self.audit.workflow_type
+ )
self.assertEqual(r, True)
- @patch('sql.utils.workflow_audit.auth_group_users')
- @patch('sql.utils.workflow_audit.Audit.detail_by_workflow_id')
- def test_can_review_sql_review_super(self, _detail_by_workflow_id, _auth_group_users):
+ @patch("sql.utils.workflow_audit.auth_group_users")
+ @patch("sql.utils.workflow_audit.Audit.detail_by_workflow_id")
+ def test_can_review_sql_review_super(
+ self, _detail_by_workflow_id, _auth_group_users
+ ):
"""测试判断用户当前是否是可审核查询工单,用户是管理员"""
- aug = Group.objects.create(name='auth_group')
+ aug = Group.objects.create(name="auth_group")
_detail_by_workflow_id.return_value.current_audit = aug.id
_auth_group_users.return_value.filter.exists = True
- self.audit.workflow_type = WorkflowDict.workflow_type['sqlreview']
+ self.audit.workflow_type = WorkflowDict.workflow_type["sqlreview"]
self.audit.workflow_id = self.wf.id
self.audit.save()
r = Audit.can_review(self.su, self.audit.workflow_id, self.audit.workflow_type)
self.assertEqual(r, True)
- @patch('sql.utils.workflow_audit.auth_group_users')
- @patch('sql.utils.workflow_audit.Audit.detail_by_workflow_id')
+ @patch("sql.utils.workflow_audit.auth_group_users")
+ @patch("sql.utils.workflow_audit.Audit.detail_by_workflow_id")
def test_can_review_wrong_status(self, _detail_by_workflow_id, _auth_group_users):
"""测试判断用户当前是否是可审核,非待审核工单"""
- aug = Group.objects.create(name='auth_group')
+ aug = Group.objects.create(name="auth_group")
_detail_by_workflow_id.return_value.current_audit = aug.id
_auth_group_users.return_value.filter.exists = True
- self.audit.workflow_type = WorkflowDict.workflow_type['sqlreview']
+ self.audit.workflow_type = WorkflowDict.workflow_type["sqlreview"]
self.audit.workflow_id = self.wf.id
- self.audit.current_status = WorkflowDict.workflow_status['audit_success']
+ self.audit.current_status = WorkflowDict.workflow_status["audit_success"]
self.audit.save()
- r = Audit.can_review(self.user, self.audit.workflow_id, self.audit.workflow_type)
+ r = Audit.can_review(
+ self.user, self.audit.workflow_id, self.audit.workflow_type
+ )
self.assertEqual(r, False)
- @patch('sql.utils.workflow_audit.auth_group_users')
- @patch('sql.utils.workflow_audit.Audit.detail_by_workflow_id')
+ @patch("sql.utils.workflow_audit.auth_group_users")
+ @patch("sql.utils.workflow_audit.Audit.detail_by_workflow_id")
def test_can_review_no_prem(self, _detail_by_workflow_id, _auth_group_users):
"""测试判断用户当前是否是可审核,普通用户无权限"""
- aug = Group.objects.create(name='auth_group')
+ aug = Group.objects.create(name="auth_group")
_detail_by_workflow_id.return_value.current_audit = aug.id
_auth_group_users.return_value.filter.exists = True
- self.audit.workflow_type = WorkflowDict.workflow_type['sqlreview']
+ self.audit.workflow_type = WorkflowDict.workflow_type["sqlreview"]
self.audit.workflow_id = self.wf.id
self.audit.save()
- r = Audit.can_review(self.user, self.audit.workflow_id, self.audit.workflow_type)
+ r = Audit.can_review(
+ self.user, self.audit.workflow_id, self.audit.workflow_type
+ )
self.assertEqual(r, False)
- @patch('sql.utils.workflow_audit.auth_group_users')
- @patch('sql.utils.workflow_audit.Audit.detail_by_workflow_id')
- def test_can_review_no_prem_exception(self, _detail_by_workflow_id, _auth_group_users):
+ @patch("sql.utils.workflow_audit.auth_group_users")
+ @patch("sql.utils.workflow_audit.Audit.detail_by_workflow_id")
+ def test_can_review_no_prem_exception(
+ self, _detail_by_workflow_id, _auth_group_users
+ ):
"""测试判断用户当前是否是可审核,权限组不存在"""
- Group.objects.create(name='auth_group')
+ Group.objects.create(name="auth_group")
_detail_by_workflow_id.side_effect = RuntimeError()
_auth_group_users.return_value.filter.exists = True
- self.audit.workflow_type = WorkflowDict.workflow_type['sqlreview']
+ self.audit.workflow_type = WorkflowDict.workflow_type["sqlreview"]
self.audit.workflow_id = self.wf.id
self.audit.save()
- with self.assertRaisesMessage(Exception, '当前审批auth_group_id不存在,请检查并清洗历史数据'):
- Audit.can_review(self.user, self.audit.workflow_id, self.audit.workflow_type)
+ with self.assertRaisesMessage(Exception, "当前审批auth_group_id不存在,请检查并清洗历史数据"):
+ Audit.can_review(
+ self.user, self.audit.workflow_id, self.audit.workflow_type
+ )
def test_review_info_no_review(self):
"""测试获取当前工单审批流程和当前审核组,无需审批"""
- self.audit.workflow_type = WorkflowDict.workflow_type['sqlreview']
+ self.audit.workflow_type = WorkflowDict.workflow_type["sqlreview"]
self.audit.workflow_id = self.wf.id
- self.audit.audit_auth_groups = ''
- self.audit.current_audit = '-1'
+ self.audit.audit_auth_groups = ""
+ self.audit.current_audit = "-1"
self.audit.save()
- audit_auth_group, current_audit_auth_group = Audit.review_info(self.audit.workflow_id, self.audit.workflow_type)
- self.assertEqual(audit_auth_group, '无需审批')
+ audit_auth_group, current_audit_auth_group = Audit.review_info(
+ self.audit.workflow_id, self.audit.workflow_type
+ )
+ self.assertEqual(audit_auth_group, "无需审批")
self.assertEqual(current_audit_auth_group, None)
def test_review_info(self):
"""测试获取当前工单审批流程和当前审核组,无需审批"""
- aug = Group.objects.create(name='DBA')
- self.audit.workflow_type = WorkflowDict.workflow_type['sqlreview']
+ aug = Group.objects.create(name="DBA")
+ self.audit.workflow_type = WorkflowDict.workflow_type["sqlreview"]
self.audit.workflow_id = self.wf.id
self.audit.audit_auth_groups = str(aug.id)
self.audit.current_audit = str(aug.id)
self.audit.save()
- audit_auth_group, current_audit_auth_group = Audit.review_info(self.audit.workflow_id, self.audit.workflow_type)
- self.assertEqual(audit_auth_group, 'DBA')
- self.assertEqual(current_audit_auth_group, 'DBA')
+ audit_auth_group, current_audit_auth_group = Audit.review_info(
+ self.audit.workflow_id, self.audit.workflow_type
+ )
+ self.assertEqual(audit_auth_group, "DBA")
+ self.assertEqual(current_audit_auth_group, "DBA")
def test_logs(self):
"""测试获取工单日志"""
@@ -1129,37 +1349,43 @@ def test_logs(self):
class TestDataMasking(TestCase):
def setUp(self):
- self.superuser = User.objects.create(username='super', is_superuser=True)
- self.user = User.objects.create(username='user')
- self.ins = Instance.objects.create(instance_name='some_ins', type='slave', db_type='mysql',
- host='some_host',
- port=3306, user='ins_user', password='some_str')
+ self.superuser = User.objects.create(username="super", is_superuser=True)
+ self.user = User.objects.create(username="user")
+ self.ins = Instance.objects.create(
+ instance_name="some_ins",
+ type="slave",
+ db_type="mysql",
+ host="some_host",
+ port=3306,
+ user="ins_user",
+ password="some_str",
+ )
self.sys_config = SysConfig()
self.wf1 = SqlWorkflow.objects.create(
- workflow_name='workflow_name',
+ workflow_name="workflow_name",
group_id=1,
- group_name='group_name',
+ group_name="group_name",
engineer=self.superuser.username,
engineer_display=self.superuser.display,
- audit_auth_groups='audit_auth_groups',
+ audit_auth_groups="audit_auth_groups",
create_time=datetime.datetime.now(),
- status='workflow_review_pass',
+ status="workflow_review_pass",
is_backup=True,
instance=self.ins,
- db_name='db_name',
+ db_name="db_name",
syntax_type=1,
)
DataMaskingRules.objects.create(
- rule_type=1,
- rule_regex='(.{3})(.*)(.{4})',
- hide_group=2)
+ rule_type=1, rule_regex="(.{3})(.*)(.{4})", hide_group=2
+ )
DataMaskingColumns.objects.create(
rule_type=1,
active=True,
instance=self.ins,
- table_schema='archer_test',
- table_name='users',
- column_name='phone')
+ table_schema="archer_test",
+ table_name="users",
+ column_name="phone",
+ )
def tearDown(self):
User.objects.all().delete()
@@ -1168,213 +1394,464 @@ def tearDown(self):
DataMaskingColumns.objects.all().delete()
DataMaskingRules.objects.all().delete()
- @patch('sql.utils.data_masking.GoInceptionEngine')
+ @patch("sql.utils.data_masking.GoInceptionEngine")
def test_data_masking_not_hit_rules(self, _inception):
DataMaskingColumns.objects.all().delete()
DataMaskingRules.objects.all().delete()
_inception.return_value.query_data_masking.return_value = [
- {"index": 0, "field": "phone", "type": "varchar(80)", "table": "users", "schema": "archer_test",
- "alias": "phone"}
+ {
+ "index": 0,
+ "field": "phone",
+ "type": "varchar(80)",
+ "table": "users",
+ "schema": "archer_test",
+ "alias": "phone",
+ }
]
sql = """select phone from users;"""
- rows = (('18888888888',), ('18888888889',), ('18888888810',))
- query_result = ReviewSet(column_list=['phone'], rows=rows, full_sql=sql)
- r = data_masking(self.ins, 'archery', sql, query_result)
+ rows = (("18888888888",), ("18888888889",), ("18888888810",))
+ query_result = ReviewSet(column_list=["phone"], rows=rows, full_sql=sql)
+ r = data_masking(self.ins, "archery", sql, query_result)
print("test_data_masking_not_hit_rules:", r.rows)
self.assertEqual(r, query_result)
- @patch('sql.utils.data_masking.GoInceptionEngine')
+ @patch("sql.utils.data_masking.GoInceptionEngine")
def test_data_masking_hit_rules_not_exists_star(self, _inception):
_inception.return_value.query_data_masking.return_value = [
- {"index": 0, "field": "phone", "type": "varchar(80)", "table": "users", "schema": "archer_test",
- "alias": "phone"}
+ {
+ "index": 0,
+ "field": "phone",
+ "type": "varchar(80)",
+ "table": "users",
+ "schema": "archer_test",
+ "alias": "phone",
+ }
]
sql = """select phone from users;"""
- rows = (('18888888888',), ('18888888889',), ('18888888810',))
- query_result = ReviewSet(column_list=['phone'], rows=rows, full_sql=sql)
- r = data_masking(self.ins, 'archery', sql, query_result)
+ rows = (("18888888888",), ("18888888889",), ("18888888810",))
+ query_result = ReviewSet(column_list=["phone"], rows=rows, full_sql=sql)
+ r = data_masking(self.ins, "archery", sql, query_result)
print("test_data_masking_hit_rules_not_exists_star:", r.rows)
- mask_result_rows = [['188****8888', ], ['188****8889', ], ['188****8810', ]]
+ mask_result_rows = [
+ [
+ "188****8888",
+ ],
+ [
+ "188****8889",
+ ],
+ [
+ "188****8810",
+ ],
+ ]
self.assertEqual(r.rows, mask_result_rows)
- @patch('sql.utils.data_masking.GoInceptionEngine')
+ @patch("sql.utils.data_masking.GoInceptionEngine")
def test_data_masking_hit_rules_exists_star(self, _inception):
"""[*]"""
_inception.return_value.query_data_masking.return_value = [
- {"index": 0, "field": "phone", "type": "varchar(80)", "table": "users", "schema": "archer_test",
- "alias": "phone"}
+ {
+ "index": 0,
+ "field": "phone",
+ "type": "varchar(80)",
+ "table": "users",
+ "schema": "archer_test",
+ "alias": "phone",
+ }
]
sql = """select * from users;"""
- rows = (('18888888888',), ('18888888889',), ('18888888810',))
- query_result = ReviewSet(column_list=['phone'], rows=rows, full_sql=sql)
- r = data_masking(self.ins, 'archery', sql, query_result)
+ rows = (("18888888888",), ("18888888889",), ("18888888810",))
+ query_result = ReviewSet(column_list=["phone"], rows=rows, full_sql=sql)
+ r = data_masking(self.ins, "archery", sql, query_result)
print("test_data_masking_hit_rules_exists_star:", r.rows)
- mask_result_rows = [['188****8888', ], ['188****8889', ], ['188****8810', ]]
+ mask_result_rows = [
+ [
+ "188****8888",
+ ],
+ [
+ "188****8889",
+ ],
+ [
+ "188****8810",
+ ],
+ ]
self.assertEqual(r.rows, mask_result_rows)
- @patch('sql.utils.data_masking.GoInceptionEngine')
+ @patch("sql.utils.data_masking.GoInceptionEngine")
def test_data_masking_hit_rules_star_and_column(self, _inception):
"""[*,column_a]"""
_inception.return_value.query_data_masking.return_value = [
- {"index": 0, "field": "phone", "type": "varchar(80)", "table": "users", "schema": "archer_test",
- "alias": "phone"},
- {"index": 1, "field": "phone", "type": "varchar(80)", "table": "users", "schema": "archer_test",
- "alias": "phone"},
+ {
+ "index": 0,
+ "field": "phone",
+ "type": "varchar(80)",
+ "table": "users",
+ "schema": "archer_test",
+ "alias": "phone",
+ },
+ {
+ "index": 1,
+ "field": "phone",
+ "type": "varchar(80)",
+ "table": "users",
+ "schema": "archer_test",
+ "alias": "phone",
+ },
]
sql = """select *,phone from users;"""
- rows = (('18888888888', '18888888888',),
- ('18888888889', '18888888889',),)
- query_result = ReviewSet(column_list=['phone', 'phone'], rows=rows, full_sql=sql)
- r = data_masking(self.ins, 'archery', sql, query_result)
+ rows = (
+ (
+ "18888888888",
+ "18888888888",
+ ),
+ (
+ "18888888889",
+ "18888888889",
+ ),
+ )
+ query_result = ReviewSet(
+ column_list=["phone", "phone"], rows=rows, full_sql=sql
+ )
+ r = data_masking(self.ins, "archery", sql, query_result)
print("test_data_masking_hit_rules_star_and_column", r.rows)
- mask_result_rows = [['188****8888', '188****8888', ],
- ['188****8889', '188****8889', ]]
+ mask_result_rows = [
+ [
+ "188****8888",
+ "188****8888",
+ ],
+ [
+ "188****8889",
+ "188****8889",
+ ],
+ ]
self.assertEqual(r.rows, mask_result_rows)
- @patch('sql.utils.data_masking.GoInceptionEngine')
+ @patch("sql.utils.data_masking.GoInceptionEngine")
def test_data_masking_hit_rules_column_and_star(self, _inception):
"""[column_a, *]"""
_inception.return_value.query_data_masking.return_value = [
- {"index": 0, "field": "phone", "type": "varchar(80)", "table": "users", "schema": "archer_test",
- "alias": "phone"},
- {"index": 1, "field": "phone", "type": "varchar(80)", "table": "users", "schema": "archer_test",
- "alias": "phone"}
+ {
+ "index": 0,
+ "field": "phone",
+ "type": "varchar(80)",
+ "table": "users",
+ "schema": "archer_test",
+ "alias": "phone",
+ },
+ {
+ "index": 1,
+ "field": "phone",
+ "type": "varchar(80)",
+ "table": "users",
+ "schema": "archer_test",
+ "alias": "phone",
+ },
]
sql = """select phone,* from users;"""
- rows = (('18888888888', '18888888888',),
- ('18888888889', '18888888889',))
- query_result = ReviewSet(column_list=['phone', 'phone'], rows=rows, full_sql=sql)
- r = data_masking(self.ins, 'archery', sql, query_result)
+ rows = (
+ (
+ "18888888888",
+ "18888888888",
+ ),
+ (
+ "18888888889",
+ "18888888889",
+ ),
+ )
+ query_result = ReviewSet(
+ column_list=["phone", "phone"], rows=rows, full_sql=sql
+ )
+ r = data_masking(self.ins, "archery", sql, query_result)
print("test_data_masking_hit_rules_column_and_star", r.rows)
- mask_result_rows = [['188****8888', '188****8888', ],
- ['188****8889', '188****8889', ]]
+ mask_result_rows = [
+ [
+ "188****8888",
+ "188****8888",
+ ],
+ [
+ "188****8889",
+ "188****8889",
+ ],
+ ]
self.assertEqual(r.rows, mask_result_rows)
- @patch('sql.utils.data_masking.GoInceptionEngine')
+ @patch("sql.utils.data_masking.GoInceptionEngine")
def test_data_masking_hit_rules_column_and_star_and_column(self, _inception):
"""[column_a,a.*,column_b]"""
_inception.return_value.query_data_masking.return_value = [
- {"index": 0, "field": "phone", "type": "varchar(80)", "table": "users", "schema": "archer_test",
- "alias": "phone"},
- {"index": 1, "field": "phone", "type": "varchar(80)", "table": "users", "schema": "archer_test",
- "alias": "phone"},
- {"index": 2, "field": "phone", "type": "varchar(80)", "table": "users", "schema": "archer_test",
- "alias": "phone"}
+ {
+ "index": 0,
+ "field": "phone",
+ "type": "varchar(80)",
+ "table": "users",
+ "schema": "archer_test",
+ "alias": "phone",
+ },
+ {
+ "index": 1,
+ "field": "phone",
+ "type": "varchar(80)",
+ "table": "users",
+ "schema": "archer_test",
+ "alias": "phone",
+ },
+ {
+ "index": 2,
+ "field": "phone",
+ "type": "varchar(80)",
+ "table": "users",
+ "schema": "archer_test",
+ "alias": "phone",
+ },
]
sql = """select phone,*,phone from users;"""
- rows = (('18888888888', '18888888888', '18888888888',),
- ('18888888889', '18888888889', '18888888889',))
- query_result = ReviewSet(column_list=['phone', 'phone', 'phone'], rows=rows, full_sql=sql)
- r = data_masking(self.ins, 'archery', sql, query_result)
+ rows = (
+ (
+ "18888888888",
+ "18888888888",
+ "18888888888",
+ ),
+ (
+ "18888888889",
+ "18888888889",
+ "18888888889",
+ ),
+ )
+ query_result = ReviewSet(
+ column_list=["phone", "phone", "phone"], rows=rows, full_sql=sql
+ )
+ r = data_masking(self.ins, "archery", sql, query_result)
print("test_data_masking_hit_rules_column_and_star_and_column", r.rows)
- mask_result_rows = [['188****8888', '188****8888', '188****8888', ],
- ['188****8889', '188****8889', '188****8889', ]]
+ mask_result_rows = [
+ [
+ "188****8888",
+ "188****8888",
+ "188****8888",
+ ],
+ [
+ "188****8889",
+ "188****8889",
+ "188****8889",
+ ],
+ ]
self.assertEqual(r.rows, mask_result_rows)
- @patch('sql.utils.data_masking.GoInceptionEngine')
+ @patch("sql.utils.data_masking.GoInceptionEngine")
def test_data_masking_hit_rules_star_and_column_and_star(self, _inception):
"""[a.*, column_a, b.*]"""
_inception.return_value.query_data_masking.return_value = [
- {"index": 0, "field": "phone", "type": "varchar(80)", "table": "users", "schema": "archer_test",
- "alias": "phone"},
- {"index": 1, "field": "phone", "type": "varchar(80)", "table": "users", "schema": "archer_test",
- "alias": "phone"},
- {"index": 2, "field": "phone", "type": "varchar(80)", "table": "users", "schema": "archer_test",
- "alias": "phone"}
+ {
+ "index": 0,
+ "field": "phone",
+ "type": "varchar(80)",
+ "table": "users",
+ "schema": "archer_test",
+ "alias": "phone",
+ },
+ {
+ "index": 1,
+ "field": "phone",
+ "type": "varchar(80)",
+ "table": "users",
+ "schema": "archer_test",
+ "alias": "phone",
+ },
+ {
+ "index": 2,
+ "field": "phone",
+ "type": "varchar(80)",
+ "table": "users",
+ "schema": "archer_test",
+ "alias": "phone",
+ },
]
sql = """select a.*,phone,a.* from users a;"""
- rows = (('18888888888', '18888888888', '18888888888',),
- ('18888888889', '18888888889', '18888888889',))
- query_result = ReviewSet(column_list=['phone', 'phone', 'phone'], rows=rows, full_sql=sql)
- r = data_masking(self.ins, 'archery', sql, query_result)
+ rows = (
+ (
+ "18888888888",
+ "18888888888",
+ "18888888888",
+ ),
+ (
+ "18888888889",
+ "18888888889",
+ "18888888889",
+ ),
+ )
+ query_result = ReviewSet(
+ column_list=["phone", "phone", "phone"], rows=rows, full_sql=sql
+ )
+ r = data_masking(self.ins, "archery", sql, query_result)
print("test_data_masking_hit_rules_star_and_column_and_star", r.rows)
- mask_result_rows = [['188****8888', '188****8888', '188****8888', ],
- ['188****8889', '188****8889', '188****8889', ]]
+ mask_result_rows = [
+ [
+ "188****8888",
+ "188****8888",
+ "188****8888",
+ ],
+ [
+ "188****8889",
+ "188****8889",
+ "188****8889",
+ ],
+ ]
self.assertEqual(r.rows, mask_result_rows)
- @patch('sql.utils.data_masking.GoInceptionEngine')
+ @patch("sql.utils.data_masking.GoInceptionEngine")
def test_data_masking_concat_function_support(self, _inception):
"""concat_函数支持"""
_inception.return_value.query_data_masking.return_value = [
- {"index": 0, "field": "phone", "type": "varchar(80)", "table": "users", "schema": "archer_test",
- "alias": "concat(phone,1)"}
+ {
+ "index": 0,
+ "field": "phone",
+ "type": "varchar(80)",
+ "table": "users",
+ "schema": "archer_test",
+ "alias": "concat(phone,1)",
+ }
]
sql = """select concat(phone,1) from users;"""
- rows = (('18888888888',), ('18888888889',), ('18888888810',))
- query_result = ReviewSet(column_list=['concat(phone,1)'], rows=rows, full_sql=sql)
- r = data_masking(self.ins, 'archery', sql, query_result)
- mask_result_rows = [['188****8888', ], ['188****8889', ], ['188****8810', ]]
+ rows = (("18888888888",), ("18888888889",), ("18888888810",))
+ query_result = ReviewSet(
+ column_list=["concat(phone,1)"], rows=rows, full_sql=sql
+ )
+ r = data_masking(self.ins, "archery", sql, query_result)
+ mask_result_rows = [
+ [
+ "188****8888",
+ ],
+ [
+ "188****8889",
+ ],
+ [
+ "188****8810",
+ ],
+ ]
print("test_data_masking_concat_function_support", r.rows)
self.assertEqual(r.rows, mask_result_rows)
- @patch('sql.utils.data_masking.GoInceptionEngine')
+ @patch("sql.utils.data_masking.GoInceptionEngine")
def test_data_masking_max_function_support(self, _inception):
"""max_函数支持"""
_inception.return_value.query_data_masking.return_value = [
- {"index": 0, "field": "phone", "type": "varchar(80)", "table": "users", "schema": "archer_test",
- "alias": "max(phone+1)"}
+ {
+ "index": 0,
+ "field": "phone",
+ "type": "varchar(80)",
+ "table": "users",
+ "schema": "archer_test",
+ "alias": "max(phone+1)",
+ }
]
sql = """select max(phone+1) from users;"""
- rows = (('18888888888',), ('18888888889',), ('18888888810',))
- query_result = ReviewSet(column_list=['max(phone+1)'], rows=rows, full_sql=sql)
- mask_result_rows = [['188****8888', ], ['188****8889', ], ['188****8810', ]]
- r = data_masking(self.ins, 'archery', sql, query_result)
+ rows = (("18888888888",), ("18888888889",), ("18888888810",))
+ query_result = ReviewSet(column_list=["max(phone+1)"], rows=rows, full_sql=sql)
+ mask_result_rows = [
+ [
+ "188****8888",
+ ],
+ [
+ "188****8889",
+ ],
+ [
+ "188****8810",
+ ],
+ ]
+ r = data_masking(self.ins, "archery", sql, query_result)
print("test_data_masking_max_function_support", r.rows)
self.assertEqual(r.rows, mask_result_rows)
- @patch('sql.utils.data_masking.GoInceptionEngine')
+ @patch("sql.utils.data_masking.GoInceptionEngine")
def test_data_masking_union_support_keyword(self, _inception):
"""union关键字"""
- self.sys_config.set('query_check', 'true')
+ self.sys_config.set("query_check", "true")
self.sys_config.get_all_config()
_inception.return_value.query_data_masking.return_value = [
- {'index': 0, 'field': 'phone', 'type': 'varchar(80)', 'table': 'users', 'schema': 'archer_test',
- 'alias': 'phone'},
- {'index': 1, 'field': 'phone', 'type': 'varchar(80)', 'table': 'users', 'schema': 'archer_test',
- 'alias': 'phone'}
-
+ {
+ "index": 0,
+ "field": "phone",
+ "type": "varchar(80)",
+ "table": "users",
+ "schema": "archer_test",
+ "alias": "phone",
+ },
+ {
+ "index": 1,
+ "field": "phone",
+ "type": "varchar(80)",
+ "table": "users",
+ "schema": "archer_test",
+ "alias": "phone",
+ },
+ ]
+ sqls = [
+ "select phone from users union select phone from users;",
+ "select phone from users union all select phone from users;",
+ ]
+ rows = (("18888888888",), ("18888888889",), ("18888888810",))
+ mask_result_rows = [
+ [
+ "188****8888",
+ ],
+ [
+ "188****8889",
+ ],
+ [
+ "188****8810",
+ ],
]
- sqls = ["select phone from users union select phone from users;",
- "select phone from users union all select phone from users;"]
- rows = (('18888888888',), ('18888888889',), ('18888888810',))
- mask_result_rows = [['188****8888', ], ['188****8889', ], ['188****8810', ]]
for sql in sqls:
- query_result = ReviewSet(column_list=['phone'], rows=rows, full_sql=sql)
- r = data_masking(self.ins, 'archery', sql, query_result)
+ query_result = ReviewSet(column_list=["phone"], rows=rows, full_sql=sql)
+ r = data_masking(self.ins, "archery", sql, query_result)
print("test_data_masking_union_support_keyword", r.rows)
self.assertEqual(r.rows, mask_result_rows)
def test_brute_mask(self):
sql = """select * from users;"""
- rows = (('18888888888',), ('18888888889',), ('18888888810',))
- query_result = ReviewSet(column_list=['phone'], rows=rows, full_sql=sql)
+ rows = (("18888888888",), ("18888888889",), ("18888888810",))
+ query_result = ReviewSet(column_list=["phone"], rows=rows, full_sql=sql)
r = brute_mask(self.ins, query_result)
- mask_result_rows = [('188****8888',), ('188****8889',), ('188****8810',)]
+ mask_result_rows = [("188****8888",), ("188****8889",), ("188****8810",)]
self.assertEqual(r.rows, mask_result_rows)
def test_simple_column_mask(self):
sql = """select * from users;"""
- rows = (('18888888888',), ('18888888889',), ('18888888810',))
- query_result = ReviewSet(column_list=['phone'], rows=rows, full_sql=sql)
+ rows = (("18888888888",), ("18888888889",), ("18888888810",))
+ query_result = ReviewSet(column_list=["phone"], rows=rows, full_sql=sql)
r = simple_column_mask(self.ins, query_result)
- mask_result_rows = [('188****8888',), ('188****8889',), ('188****8810',)]
+ mask_result_rows = [("188****8888",), ("188****8889",), ("188****8810",)]
self.assertEqual(r.rows, mask_result_rows)
class TestResourceGroup(TestCase):
def setUp(self):
self.sys_config = SysConfig()
- self.user = User.objects.create(username='test_user', display='中文显示', is_active=True)
- self.su = User.objects.create(username='s_user', display='中文显示', is_active=True, is_superuser=True)
- self.ins1 = Instance.objects.create(instance_name='some_ins1', type='slave', db_type='mysql',
- host='some_host',
- port=3306, user='ins_user', password='some_str')
- self.ins2 = Instance.objects.create(instance_name='some_ins2', type='slave', db_type='mysql',
- host='some_host',
- port=3306, user='ins_user', password='some_str')
- self.rgp1 = ResourceGroup.objects.create(group_name='group1')
- self.rgp2 = ResourceGroup.objects.create(group_name='group2')
- self.agp = Group.objects.create(name='auth_group')
+ self.user = User.objects.create(
+ username="test_user", display="中文显示", is_active=True
+ )
+ self.su = User.objects.create(
+ username="s_user", display="中文显示", is_active=True, is_superuser=True
+ )
+ self.ins1 = Instance.objects.create(
+ instance_name="some_ins1",
+ type="slave",
+ db_type="mysql",
+ host="some_host",
+ port=3306,
+ user="ins_user",
+ password="some_str",
+ )
+ self.ins2 = Instance.objects.create(
+ instance_name="some_ins2",
+ type="slave",
+ db_type="mysql",
+ host="some_host",
+ port=3306,
+ user="ins_user",
+ password="some_str",
+ )
+ self.rgp1 = ResourceGroup.objects.create(group_name="group1")
+ self.rgp2 = ResourceGroup.objects.create(group_name="group2")
+ self.agp = Group.objects.create(name="auth_group")
def tearDown(self):
self.sys_config.purge()
@@ -1426,5 +1903,7 @@ def test_auth_group_users(self):
# 用户关联资源组
self.user.resource_group.add(self.rgp1)
# 获取资源组内关联指定权限组的用户
- users = auth_group_users(auth_group_names=[self.agp.name], group_id=self.rgp1.group_id)
+ users = auth_group_users(
+ auth_group_names=[self.agp.name], group_id=self.rgp1.group_id
+ )
self.assertIn(self.user, users)
diff --git a/sql/utils/workflow_audit.py b/sql/utils/workflow_audit.py
index 2dbdfec5e5..b684328218 100644
--- a/sql/utils/workflow_audit.py
+++ b/sql/utils/workflow_audit.py
@@ -5,8 +5,17 @@
from sql.utils.resource_group import user_groups, auth_group_users
from sql.utils.sql_review import is_auto_review
from common.utils.const import WorkflowDict
-from sql.models import WorkflowAudit, WorkflowAuditDetail, WorkflowAuditSetting, WorkflowLog, ResourceGroup, \
- SqlWorkflow, QueryPrivilegesApply, Users, ArchiveConfig
+from sql.models import (
+ WorkflowAudit,
+ WorkflowAuditDetail,
+ WorkflowAuditSetting,
+ WorkflowLog,
+ ResourceGroup,
+ SqlWorkflow,
+ QueryPrivilegesApply,
+ Users,
+ ArchiveConfig,
+)
from common.config import SysConfig
@@ -14,17 +23,20 @@ class Audit(object):
# 新增工单审核
@staticmethod
def add(workflow_type, workflow_id):
- result = {'status': 0, 'msg': '', 'data': []}
+ result = {"status": 0, "msg": "", "data": []}
# 检查是否已存在待审核数据
- workflow_info = WorkflowAudit.objects.filter(workflow_type=workflow_type, workflow_id=workflow_id,
- current_status=WorkflowDict.workflow_status['audit_wait'])
+ workflow_info = WorkflowAudit.objects.filter(
+ workflow_type=workflow_type,
+ workflow_id=workflow_id,
+ current_status=WorkflowDict.workflow_status["audit_wait"],
+ )
if len(workflow_info) >= 1:
- result['msg'] = '该工单当前状态为待审核,请勿重复提交'
- raise Exception(result['msg'])
+ result["msg"] = "该工单当前状态为待审核,请勿重复提交"
+ raise Exception(result["msg"])
# 获取工单信息
- if workflow_type == WorkflowDict.workflow_type['query']:
+ if workflow_type == WorkflowDict.workflow_type["query"]:
workflow_detail = QueryPrivilegesApply.objects.get(apply_id=workflow_id)
workflow_title = workflow_detail.title
group_id = workflow_detail.group_id
@@ -32,8 +44,8 @@ def add(workflow_type, workflow_id):
create_user = workflow_detail.user_name
create_user_display = workflow_detail.user_display
audit_auth_groups = workflow_detail.audit_auth_groups
- workflow_remark = ''
- elif workflow_type == WorkflowDict.workflow_type['sqlreview']:
+ workflow_remark = ""
+ elif workflow_type == WorkflowDict.workflow_type["sqlreview"]:
workflow_detail = SqlWorkflow.objects.get(pk=workflow_id)
workflow_title = workflow_detail.workflow_name
group_id = workflow_detail.group_id
@@ -41,8 +53,8 @@ def add(workflow_type, workflow_id):
create_user = workflow_detail.engineer
create_user_display = workflow_detail.engineer_display
audit_auth_groups = workflow_detail.audit_auth_groups
- workflow_remark = ''
- elif workflow_type == WorkflowDict.workflow_type['archive']:
+ workflow_remark = ""
+ elif workflow_type == WorkflowDict.workflow_type["archive"]:
workflow_detail = ArchiveConfig.objects.get(pk=workflow_id)
workflow_title = workflow_detail.title
group_id = workflow_detail.resource_group.group_id
@@ -50,25 +62,25 @@ def add(workflow_type, workflow_id):
create_user = workflow_detail.user_name
create_user_display = workflow_detail.user_display
audit_auth_groups = workflow_detail.audit_auth_groups
- workflow_remark = ''
+ workflow_remark = ""
else:
- result['msg'] = '工单类型不存在'
- raise Exception(result['msg'])
+ result["msg"] = "工单类型不存在"
+ raise Exception(result["msg"])
# 校验是否配置审批流程
- if audit_auth_groups == '':
- result['msg'] = '审批流程不能为空,请先配置审批流程'
- raise Exception(result['msg'])
+ if audit_auth_groups == "":
+ result["msg"] = "审批流程不能为空,请先配置审批流程"
+ raise Exception(result["msg"])
else:
- audit_auth_groups_list = audit_auth_groups.split(',')
+ audit_auth_groups_list = audit_auth_groups.split(",")
# 判断是否无需审核,并且修改审批人为空
- if SysConfig().get('auto_review', False):
- if workflow_type == WorkflowDict.workflow_type['sqlreview']:
+ if SysConfig().get("auto_review", False):
+ if workflow_type == WorkflowDict.workflow_type["sqlreview"]:
if is_auto_review(workflow_id):
sql_workflow = SqlWorkflow.objects.get(id=int(workflow_id))
- sql_workflow.audit_auth_groups = '无需审批'
- sql_workflow.status = 'workflow_review_pass'
+ sql_workflow.audit_auth_groups = "无需审批"
+ sql_workflow.status = "workflow_review_pass"
sql_workflow.save()
audit_auth_groups_list = None
@@ -82,23 +94,28 @@ def add(workflow_type, workflow_id):
audit_detail.workflow_type = workflow_type
audit_detail.workflow_title = workflow_title
audit_detail.workflow_remark = workflow_remark
- audit_detail.audit_auth_groups = ''
- audit_detail.current_audit = '-1'
- audit_detail.next_audit = '-1'
- audit_detail.current_status = WorkflowDict.workflow_status['audit_success'] # 审核通过
+ audit_detail.audit_auth_groups = ""
+ audit_detail.current_audit = "-1"
+ audit_detail.next_audit = "-1"
+ audit_detail.current_status = WorkflowDict.workflow_status[
+ "audit_success"
+ ] # 审核通过
audit_detail.create_user = create_user
audit_detail.create_user_display = create_user_display
audit_detail.save()
- result['data'] = {'workflow_status': WorkflowDict.workflow_status['audit_success']}
- result['msg'] = '无审核配置,直接审核通过'
+ result["data"] = {
+ "workflow_status": WorkflowDict.workflow_status["audit_success"]
+ }
+ result["msg"] = "无审核配置,直接审核通过"
# 增加工单日志
- Audit.add_log(audit_id=audit_detail.audit_id,
- operation_type=0,
- operation_type_desc='提交',
- operation_info='无需审批,系统直接审核通过',
- operator=audit_detail.create_user,
- operator_display=audit_detail.create_user_display
- )
+ Audit.add_log(
+ audit_id=audit_detail.audit_id,
+ operation_type=0,
+ operation_type_desc="提交",
+ operation_info="无需审批,系统直接审核通过",
+ operator=audit_detail.create_user,
+ operator_display=audit_detail.create_user_display,
+ )
else:
# 向审核主表插入待审核数据
audit_detail = WorkflowAudit()
@@ -108,158 +125,191 @@ def add(workflow_type, workflow_id):
audit_detail.workflow_type = workflow_type
audit_detail.workflow_title = workflow_title
audit_detail.workflow_remark = workflow_remark
- audit_detail.audit_auth_groups = ','.join(audit_auth_groups_list)
+ audit_detail.audit_auth_groups = ",".join(audit_auth_groups_list)
audit_detail.current_audit = audit_auth_groups_list[0]
# 判断有无下级审核
if len(audit_auth_groups_list) == 1:
- audit_detail.next_audit = '-1'
+ audit_detail.next_audit = "-1"
else:
audit_detail.next_audit = audit_auth_groups_list[1]
- audit_detail.current_status = WorkflowDict.workflow_status['audit_wait']
+ audit_detail.current_status = WorkflowDict.workflow_status["audit_wait"]
audit_detail.create_user = create_user
audit_detail.create_user_display = create_user_display
audit_detail.save()
- result['data'] = {'workflow_status': WorkflowDict.workflow_status['audit_wait']}
+ result["data"] = {
+ "workflow_status": WorkflowDict.workflow_status["audit_wait"]
+ }
# 增加工单日志
- audit_auth_group, current_audit_auth_group = Audit.review_info(workflow_id, workflow_type)
- Audit.add_log(audit_id=audit_detail.audit_id,
- operation_type=0,
- operation_type_desc='提交',
- operation_info='等待审批,审批流程:{}'.format(audit_auth_group),
- operator=audit_detail.create_user,
- operator_display=audit_detail.create_user_display
- )
+ audit_auth_group, current_audit_auth_group = Audit.review_info(
+ workflow_id, workflow_type
+ )
+ Audit.add_log(
+ audit_id=audit_detail.audit_id,
+ operation_type=0,
+ operation_type_desc="提交",
+ operation_info="等待审批,审批流程:{}".format(audit_auth_group),
+ operator=audit_detail.create_user,
+ operator_display=audit_detail.create_user_display,
+ )
# 增加审核id
- result['data']['audit_id'] = audit_detail.audit_id
+ result["data"]["audit_id"] = audit_detail.audit_id
# 返回添加结果
return result
# 工单审核
@staticmethod
def audit(audit_id, audit_status, audit_user, audit_remark):
- result = {'status': 0, 'msg': 'ok', 'data': 0}
+ result = {"status": 0, "msg": "ok", "data": 0}
audit_detail = WorkflowAudit.objects.get(audit_id=audit_id)
# 不同审核状态
- if audit_status == WorkflowDict.workflow_status['audit_success']:
+ if audit_status == WorkflowDict.workflow_status["audit_success"]:
# 判断当前工单是否为待审核状态
- if audit_detail.current_status != WorkflowDict.workflow_status['audit_wait']:
- result['msg'] = '工单不是待审核状态,请返回刷新'
- raise Exception(result['msg'])
+ if (
+ audit_detail.current_status
+ != WorkflowDict.workflow_status["audit_wait"]
+ ):
+ result["msg"] = "工单不是待审核状态,请返回刷新"
+ raise Exception(result["msg"])
# 判断是否还有下一级审核
- if audit_detail.next_audit == '-1':
+ if audit_detail.next_audit == "-1":
# 更新主表审核状态为审核通过
audit_result = WorkflowAudit()
audit_result.audit_id = audit_id
- audit_result.current_audit = '-1'
- audit_result.current_status = WorkflowDict.workflow_status['audit_success']
- audit_result.save(update_fields=['current_audit', 'current_status'])
+ audit_result.current_audit = "-1"
+ audit_result.current_status = WorkflowDict.workflow_status[
+ "audit_success"
+ ]
+ audit_result.save(update_fields=["current_audit", "current_status"])
else:
# 更新主表审核下级审核组和当前审核组
audit_result = WorkflowAudit()
audit_result.audit_id = audit_id
- audit_result.current_status = WorkflowDict.workflow_status['audit_wait']
+ audit_result.current_status = WorkflowDict.workflow_status["audit_wait"]
audit_result.current_audit = audit_detail.next_audit
# 判断后续是否还有下下一级审核组
- audit_auth_groups_list = audit_detail.audit_auth_groups.split(',')
+ audit_auth_groups_list = audit_detail.audit_auth_groups.split(",")
for index, auth_group in enumerate(audit_auth_groups_list):
if auth_group == audit_detail.next_audit:
# 无下下级审核组
if index == len(audit_auth_groups_list) - 1:
- audit_result.next_audit = '-1'
+ audit_result.next_audit = "-1"
break
# 存在下下级审核组
else:
audit_result.next_audit = audit_auth_groups_list[index + 1]
- audit_result.save(update_fields=['current_audit', 'next_audit', 'current_status'])
+ audit_result.save(
+ update_fields=["current_audit", "next_audit", "current_status"]
+ )
# 插入审核明细数据
audit_detail_result = WorkflowAuditDetail()
audit_detail_result.audit_id = audit_id
audit_detail_result.audit_user = audit_user
- audit_detail_result.audit_status = WorkflowDict.workflow_status['audit_success']
+ audit_detail_result.audit_status = WorkflowDict.workflow_status[
+ "audit_success"
+ ]
audit_detail_result.audit_time = timezone.now()
audit_detail_result.remark = audit_remark
audit_detail_result.save()
# 增加工单日志
- audit_auth_group, current_audit_auth_group = Audit.review_info(audit_detail.workflow_id,
- audit_detail.workflow_type)
- Audit.add_log(audit_id=audit_id,
- operation_type=1,
- operation_type_desc='审批通过',
- operation_info="审批备注:{},下级审批:{}".format(audit_remark, current_audit_auth_group),
- operator=audit_user,
- operator_display=Users.objects.get(username=audit_user).display
- )
- elif audit_status == WorkflowDict.workflow_status['audit_reject']:
+ audit_auth_group, current_audit_auth_group = Audit.review_info(
+ audit_detail.workflow_id, audit_detail.workflow_type
+ )
+ Audit.add_log(
+ audit_id=audit_id,
+ operation_type=1,
+ operation_type_desc="审批通过",
+ operation_info="审批备注:{},下级审批:{}".format(
+ audit_remark, current_audit_auth_group
+ ),
+ operator=audit_user,
+ operator_display=Users.objects.get(username=audit_user).display,
+ )
+ elif audit_status == WorkflowDict.workflow_status["audit_reject"]:
# 判断当前工单是否为待审核状态
- if audit_detail.current_status != WorkflowDict.workflow_status['audit_wait']:
- result['msg'] = '工单不是待审核状态,请返回刷新'
- raise Exception(result['msg'])
+ if (
+ audit_detail.current_status
+ != WorkflowDict.workflow_status["audit_wait"]
+ ):
+ result["msg"] = "工单不是待审核状态,请返回刷新"
+ raise Exception(result["msg"])
# 更新主表审核状态
audit_result = WorkflowAudit()
audit_result.audit_id = audit_id
- audit_result.current_audit = '-1'
- audit_result.next_audit = '-1'
- audit_result.current_status = WorkflowDict.workflow_status['audit_reject']
- audit_result.save(update_fields=['current_audit', 'next_audit', 'current_status'])
+ audit_result.current_audit = "-1"
+ audit_result.next_audit = "-1"
+ audit_result.current_status = WorkflowDict.workflow_status["audit_reject"]
+ audit_result.save(
+ update_fields=["current_audit", "next_audit", "current_status"]
+ )
# 插入审核明细数据
audit_detail_result = WorkflowAuditDetail()
audit_detail_result.audit_id = audit_id
audit_detail_result.audit_user = audit_user
- audit_detail_result.audit_status = WorkflowDict.workflow_status['audit_reject']
+ audit_detail_result.audit_status = WorkflowDict.workflow_status[
+ "audit_reject"
+ ]
audit_detail_result.audit_time = timezone.now()
audit_detail_result.remark = audit_remark
audit_detail_result.save()
# 增加工单日志
- Audit.add_log(audit_id=audit_id,
- operation_type=2,
- operation_type_desc='审批不通过',
- operation_info="审批备注:{}".format(audit_remark),
- operator=audit_user,
- operator_display=Users.objects.get(username=audit_user).display
- )
- elif audit_status == WorkflowDict.workflow_status['audit_abort']:
+ Audit.add_log(
+ audit_id=audit_id,
+ operation_type=2,
+ operation_type_desc="审批不通过",
+ operation_info="审批备注:{}".format(audit_remark),
+ operator=audit_user,
+ operator_display=Users.objects.get(username=audit_user).display,
+ )
+ elif audit_status == WorkflowDict.workflow_status["audit_abort"]:
# 判断当前工单是否为待审核/审核通过状态
- if audit_detail.current_status != WorkflowDict.workflow_status['audit_wait'] and \
- audit_detail.current_status != WorkflowDict.workflow_status['audit_success']:
- result['msg'] = '工单不是待审核态/审核通过状态,请返回刷新'
- raise Exception(result['msg'])
+ if (
+ audit_detail.current_status
+ != WorkflowDict.workflow_status["audit_wait"]
+ and audit_detail.current_status
+ != WorkflowDict.workflow_status["audit_success"]
+ ):
+ result["msg"] = "工单不是待审核态/审核通过状态,请返回刷新"
+ raise Exception(result["msg"])
# 更新主表审核状态
audit_result = WorkflowAudit()
audit_result.audit_id = audit_id
- audit_result.next_audit = '-1'
- audit_result.current_status = WorkflowDict.workflow_status['audit_abort']
- audit_result.save(update_fields=['current_status', 'next_audit'])
+ audit_result.next_audit = "-1"
+ audit_result.current_status = WorkflowDict.workflow_status["audit_abort"]
+ audit_result.save(update_fields=["current_status", "next_audit"])
# 插入审核明细数据
audit_detail_result = WorkflowAuditDetail()
audit_detail_result.audit_id = audit_id
audit_detail_result.audit_user = audit_user
- audit_detail_result.audit_status = WorkflowDict.workflow_status['audit_abort']
+ audit_detail_result.audit_status = WorkflowDict.workflow_status[
+ "audit_abort"
+ ]
audit_detail_result.audit_time = timezone.now()
audit_detail_result.remark = audit_remark
audit_detail_result.save()
# 增加工单日志
- Audit.add_log(audit_id=audit_id,
- operation_type=3,
- operation_type_desc='审批取消',
- operation_info="取消原因:{}".format(audit_remark),
- operator=audit_user,
- operator_display=Users.objects.get(username=audit_user).display
- )
+ Audit.add_log(
+ audit_id=audit_id,
+ operation_type=3,
+ operation_type_desc="审批取消",
+ operation_info="取消原因:{}".format(audit_remark),
+ operator=audit_user,
+ operator_display=Users.objects.get(username=audit_user).display,
+ )
else:
- result['msg'] = '审核异常'
- raise Exception(result['msg'])
+ result["msg"] = "审核异常"
+ raise Exception(result["msg"])
# 返回审核结果
- result['data'] = {'workflow_status': audit_result.current_status}
+ result["data"] = {"workflow_status": audit_result.current_status}
return result
# 获取用户待办工单数量
@@ -275,9 +325,10 @@ def todo(user):
auth_group_ids = [group.id for group in Group.objects.filter(user=user)]
return WorkflowAudit.objects.filter(
- current_status=WorkflowDict.workflow_status['audit_wait'],
+ current_status=WorkflowDict.workflow_status["audit_wait"],
group_id__in=group_ids,
- current_audit__in=auth_group_ids).count()
+ current_audit__in=auth_group_ids,
+ ).count()
# 通过审核id获取审核信息
@staticmethod
@@ -291,7 +342,9 @@ def detail(audit_id):
@staticmethod
def detail_by_workflow_id(workflow_id, workflow_type):
try:
- return WorkflowAudit.objects.get(workflow_id=workflow_id, workflow_type=workflow_type)
+ return WorkflowAudit.objects.get(
+ workflow_id=workflow_id, workflow_type=workflow_type
+ )
except Exception:
return None
@@ -299,7 +352,9 @@ def detail_by_workflow_id(workflow_id, workflow_type):
@staticmethod
def settings(group_id, workflow_type):
try:
- return WorkflowAuditSetting.objects.get(workflow_type=workflow_type, group_id=group_id).audit_auth_groups
+ return WorkflowAuditSetting.objects.get(
+ workflow_type=workflow_type, group_id=group_id
+ ).audit_auth_groups
except Exception:
return None
@@ -307,10 +362,12 @@ def settings(group_id, workflow_type):
@staticmethod
def change_settings(group_id, workflow_type, audit_auth_groups):
try:
- WorkflowAuditSetting.objects.get(workflow_type=workflow_type, group_id=group_id)
- WorkflowAuditSetting.objects.filter(workflow_type=workflow_type,
- group_id=group_id
- ).update(audit_auth_groups=audit_auth_groups)
+ WorkflowAuditSetting.objects.get(
+ workflow_type=workflow_type, group_id=group_id
+ )
+ WorkflowAuditSetting.objects.filter(
+ workflow_type=workflow_type, group_id=group_id
+ ).update(audit_auth_groups=audit_auth_groups)
except Exception:
inset = WorkflowAuditSetting()
inset.group_id = group_id
@@ -322,59 +379,84 @@ def change_settings(group_id, workflow_type, audit_auth_groups):
# 判断用户当前是否是可审核
@staticmethod
def can_review(user, workflow_id, workflow_type):
- audit_info = WorkflowAudit.objects.get(workflow_id=workflow_id, workflow_type=workflow_type)
+ audit_info = WorkflowAudit.objects.get(
+ workflow_id=workflow_id, workflow_type=workflow_type
+ )
group_id = audit_info.group_id
result = False
# 只有待审核状态数据才可以审核
- if audit_info.current_status == WorkflowDict.workflow_status['audit_wait']:
+ if audit_info.current_status == WorkflowDict.workflow_status["audit_wait"]:
try:
- auth_group_id = Audit.detail_by_workflow_id(workflow_id, workflow_type).current_audit
+ auth_group_id = Audit.detail_by_workflow_id(
+ workflow_id, workflow_type
+ ).current_audit
audit_auth_group = Group.objects.get(id=auth_group_id).name
except Exception:
- raise Exception('当前审批auth_group_id不存在,请检查并清洗历史数据')
- if auth_group_users([audit_auth_group], group_id).filter(id=user.id).exists() or user.is_superuser == 1:
+ raise Exception("当前审批auth_group_id不存在,请检查并清洗历史数据")
+ if (
+ auth_group_users([audit_auth_group], group_id)
+ .filter(id=user.id)
+ .exists()
+ or user.is_superuser == 1
+ ):
if workflow_type == 1:
- if user.has_perm('sql.query_review'):
+ if user.has_perm("sql.query_review"):
result = True
elif workflow_type == 2:
- if user.has_perm('sql.sql_review'):
+ if user.has_perm("sql.sql_review"):
result = True
elif workflow_type == 3:
- if user.has_perm('sql.archive_review'):
+ if user.has_perm("sql.archive_review"):
result = True
return result
# 获取当前工单审批流程和当前审核组
@staticmethod
def review_info(workflow_id, workflow_type):
- audit_info = WorkflowAudit.objects.get(workflow_id=workflow_id, workflow_type=workflow_type)
- if audit_info.audit_auth_groups == '':
- audit_auth_group = '无需审批'
+ audit_info = WorkflowAudit.objects.get(
+ workflow_id=workflow_id, workflow_type=workflow_type
+ )
+ if audit_info.audit_auth_groups == "":
+ audit_auth_group = "无需审批"
else:
try:
- audit_auth_group = '->'.join([Group.objects.get(id=auth_group_id).name for auth_group_id in
- audit_info.audit_auth_groups.split(',')])
+ audit_auth_group = "->".join(
+ [
+ Group.objects.get(id=auth_group_id).name
+ for auth_group_id in audit_info.audit_auth_groups.split(",")
+ ]
+ )
except Exception:
audit_auth_group = audit_info.audit_auth_groups
- if audit_info.current_audit == '-1':
+ if audit_info.current_audit == "-1":
current_audit_auth_group = None
else:
try:
- current_audit_auth_group = Group.objects.get(id=audit_info.current_audit).name
+ current_audit_auth_group = Group.objects.get(
+ id=audit_info.current_audit
+ ).name
except Exception:
current_audit_auth_group = audit_info.current_audit
return audit_auth_group, current_audit_auth_group
# 新增工单日志
@staticmethod
- def add_log(audit_id, operation_type, operation_type_desc, operation_info, operator, operator_display):
- WorkflowLog(audit_id=audit_id,
- operation_type=operation_type,
- operation_type_desc=operation_type_desc,
- operation_info=operation_info,
- operator=operator,
- operator_display=operator_display
- ).save()
+ def add_log(
+ audit_id,
+ operation_type,
+ operation_type_desc,
+ operation_info,
+ operator,
+ operator_display,
+ ):
+ WorkflowLog(
+ audit_id=audit_id,
+ operation_type=operation_type,
+ operation_type_desc=operation_type_desc,
+ operation_info=operation_info,
+ operator=operator,
+ operator_display=operator_display,
+ ).save()
# 获取工单日志
@staticmethod
diff --git a/sql/views.py b/sql/views.py
index 8ea4b6df95..15bda95d3b 100644
--- a/sql/views.py
+++ b/sql/views.py
@@ -20,67 +20,99 @@
from sql.engines.models import ReviewResult, ReviewSet
from sql.utils.tasks import task_info
-from .models import Users, SqlWorkflow, QueryPrivileges, ResourceGroup, \
- QueryPrivilegesApply, Config, SQL_WORKFLOW_CHOICES, InstanceTag, Instance, \
- QueryLog, ArchiveConfig, AuditEntry, TwoFactorAuthConfig
+from .models import (
+ Users,
+ SqlWorkflow,
+ QueryPrivileges,
+ ResourceGroup,
+ QueryPrivilegesApply,
+ Config,
+ SQL_WORKFLOW_CHOICES,
+ InstanceTag,
+ Instance,
+ QueryLog,
+ ArchiveConfig,
+ AuditEntry,
+ TwoFactorAuthConfig,
+)
from sql.utils.workflow_audit import Audit
-from sql.utils.sql_review import can_execute, can_timingtask, can_cancel, can_view, can_rollback
+from sql.utils.sql_review import (
+ can_execute,
+ can_timingtask,
+ can_cancel,
+ can_view,
+ can_rollback,
+)
from common.utils.const import Const, WorkflowDict
from sql.utils.resource_group import user_groups, user_instances, auth_group_users
import logging
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
def index(request):
- index_path_url = SysConfig().get('index_path_url', 'sqlworkflow')
+ index_path_url = SysConfig().get("index_path_url", "sqlworkflow")
return HttpResponseRedirect(f"/{index_path_url.strip('/')}/")
def login(request):
"""登录页面"""
if request.user and request.user.is_authenticated:
- return HttpResponseRedirect('/')
+ return HttpResponseRedirect("/")
- return render(request, 'login.html', context={'sign_up_enabled': SysConfig().get('sign_up_enabled')})
+ return render(
+ request,
+ "login.html",
+ context={"sign_up_enabled": SysConfig().get("sign_up_enabled")},
+ )
def twofa(request):
"""2fa认证页面"""
if request.user.is_authenticated:
- return HttpResponseRedirect('/')
+ return HttpResponseRedirect("/")
- username = request.session.get('user')
+ username = request.session.get("user")
if username:
- verify_mode = request.session.get('verify_mode')
+ verify_mode = request.session.get("verify_mode")
twofa_enabled = TwoFactorAuthConfig.objects.filter(username=username)
user_auth_types = [twofa.auth_type for twofa in twofa_enabled]
auth_types = []
for user_auth_type in user_auth_types:
auth_type = {}
- auth_type['code'] = user_auth_type
- if user_auth_type == 'totp':
- auth_type['display'] = 'Google身份验证器'
- elif user_auth_type == 'sms':
- auth_type['display'] = '短信验证码'
+ auth_type["code"] = user_auth_type
+ if user_auth_type == "totp":
+ auth_type["display"] = "Google身份验证器"
+ elif user_auth_type == "sms":
+ auth_type["display"] = "短信验证码"
auth_types.append(auth_type)
- if 'sms' in user_auth_types:
- phone = TwoFactorAuthConfig.objects.get(username=username, auth_type='sms').phone
+ if "sms" in user_auth_types:
+ phone = TwoFactorAuthConfig.objects.get(
+ username=username, auth_type="sms"
+ ).phone
else:
phone = 0
else:
- return HttpResponseRedirect('/login/')
+ return HttpResponseRedirect("/login/")
- return render(request, '2fa.html', context={'verify_mode': verify_mode, 'auth_types': auth_types,
- 'username': username, 'phone': phone})
+ return render(
+ request,
+ "2fa.html",
+ context={
+ "verify_mode": verify_mode,
+ "auth_types": auth_types,
+ "username": username,
+ "phone": phone,
+ },
+ )
-@permission_required('sql.menu_dashboard', raise_exception=True)
+@permission_required("sql.menu_dashboard", raise_exception=True)
def dashboard(request):
"""dashboard页面"""
- return render(request, 'dashboard.html')
+ return render(request, "dashboard.html")
def sqlworkflow(request):
@@ -89,28 +121,42 @@ def sqlworkflow(request):
# 过滤筛选项的数据
filter_dict = dict()
# 管理员,可查看所有工单
- if user.is_superuser or user.has_perm('sql.audit_user'):
+ if user.is_superuser or user.has_perm("sql.audit_user"):
pass
# 非管理员,拥有审核权限、资源组粒度执行权限的,可以查看组内所有工单
- elif user.has_perm('sql.sql_review') or user.has_perm('sql.sql_execute_for_resource_group'):
+ elif user.has_perm("sql.sql_review") or user.has_perm(
+ "sql.sql_execute_for_resource_group"
+ ):
# 先获取用户所在资源组列表
group_list = user_groups(user)
group_ids = [group.group_id for group in group_list]
- filter_dict['group_id__in'] = group_ids
+ filter_dict["group_id__in"] = group_ids
# 其他人只能查看自己提交的工单
else:
- filter_dict['engineer'] = user.username
- instance_id = SqlWorkflow.objects.filter(**filter_dict).values('instance_id').distinct()
- instance = Instance.objects.filter(pk__in=instance_id).order_by(Convert('instance_name', 'gbk').asc())
- resource_group_id = SqlWorkflow.objects.filter(**filter_dict).values('group_id').distinct()
+ filter_dict["engineer"] = user.username
+ instance_id = (
+ SqlWorkflow.objects.filter(**filter_dict).values("instance_id").distinct()
+ )
+ instance = Instance.objects.filter(pk__in=instance_id).order_by(
+ Convert("instance_name", "gbk").asc()
+ )
+ resource_group_id = (
+ SqlWorkflow.objects.filter(**filter_dict).values("group_id").distinct()
+ )
resource_group = ResourceGroup.objects.filter(group_id__in=resource_group_id)
- return render(request, 'sqlworkflow.html',
- {'status_list': SQL_WORKFLOW_CHOICES,
- 'instance': instance, 'resource_group': resource_group})
+ return render(
+ request,
+ "sqlworkflow.html",
+ {
+ "status_list": SQL_WORKFLOW_CHOICES,
+ "instance": instance,
+ "resource_group": resource_group,
+ },
+ )
-@permission_required('sql.sql_submit', raise_exception=True)
+@permission_required("sql.sql_submit", raise_exception=True)
def submit_sql(request):
"""提交SQL的页面"""
user = request.user
@@ -124,11 +170,16 @@ def submit_sql(request):
archer_config = SysConfig()
# 主动创建标签
- InstanceTag.objects.get_or_create(tag_code='can_write', defaults={'tag_name': '支持上线', 'active': True})
+ InstanceTag.objects.get_or_create(
+ tag_code="can_write", defaults={"tag_name": "支持上线", "active": True}
+ )
- context = {'active_user': active_user, 'group_list': group_list,
- 'enable_backup_switch': archer_config.get('enable_backup_switch')}
- return render(request, 'sqlsubmit.html', context)
+ context = {
+ "active_user": active_user,
+ "group_list": group_list,
+ "enable_backup_switch": archer_config.get("enable_backup_switch"),
+ }
+ return render(request, "sqlsubmit.html", context)
def detail(request, workflow_id):
@@ -137,7 +188,7 @@ def detail(request, workflow_id):
if not can_view(request.user, workflow_id):
raise PermissionDenied
# 自动审批不通过的不需要获取下列信息
- if workflow_detail.status != 'workflow_autoreviewwrong':
+ if workflow_detail.status != "workflow_autoreviewwrong":
# 获取当前审批和审批流程
audit_auth_group, current_audit_auth_group = Audit.review_info(workflow_id, 2)
@@ -154,22 +205,30 @@ def detail(request, workflow_id):
# 获取审核日志
try:
- audit_detail = Audit.detail_by_workflow_id(workflow_id=workflow_id,
- workflow_type=WorkflowDict.workflow_type['sqlreview'])
+ audit_detail = Audit.detail_by_workflow_id(
+ workflow_id=workflow_id,
+ workflow_type=WorkflowDict.workflow_type["sqlreview"],
+ )
audit_id = audit_detail.audit_id
- last_operation_info = Audit.logs(audit_id=audit_id).latest('id').operation_info
+ last_operation_info = (
+ Audit.logs(audit_id=audit_id).latest("id").operation_info
+ )
# 等待审批的展示当前全部审批人
- if workflow_detail.status == 'workflow_manreviewing':
+ if workflow_detail.status == "workflow_manreviewing":
auth_group_name = Group.objects.get(id=audit_detail.current_audit).name
- current_audit_users = auth_group_users([auth_group_name], audit_detail.group_id)
- current_audit_users_display = [user.display for user in current_audit_users]
- last_operation_info += ',当前审批人:' + ','.join(current_audit_users_display)
+ current_audit_users = auth_group_users(
+ [auth_group_name], audit_detail.group_id
+ )
+ current_audit_users_display = [
+ user.display for user in current_audit_users
+ ]
+ last_operation_info += ",当前审批人:" + ",".join(current_audit_users_display)
except Exception as e:
- logger.debug(f'无审核日志记录,错误信息{e}')
- last_operation_info = ''
+ logger.debug(f"无审核日志记录,错误信息{e}")
+ last_operation_info = ""
else:
- audit_auth_group = '系统自动驳回'
- current_audit_auth_group = '系统自动驳回'
+ audit_auth_group = "系统自动驳回"
+ current_audit_auth_group = "系统自动驳回"
is_can_review = False
is_can_execute = False
is_can_timingtask = False
@@ -178,35 +237,44 @@ def detail(request, workflow_id):
last_operation_info = None
# 获取定时执行任务信息
- if workflow_detail.status == 'workflow_timingtask':
- job_id = Const.workflowJobprefix['sqlreview'] + '-' + str(workflow_id)
+ if workflow_detail.status == "workflow_timingtask":
+ job_id = Const.workflowJobprefix["sqlreview"] + "-" + str(workflow_id)
job = task_info(job_id)
if job:
run_date = job.next_run
else:
- run_date = ''
+ run_date = ""
else:
- run_date = ''
+ run_date = ""
# 获取是否开启手工执行确认
- manual = SysConfig().get('manual')
-
- context = {'workflow_detail': workflow_detail, 'last_operation_info': last_operation_info,
- 'is_can_review': is_can_review, 'is_can_execute': is_can_execute, 'is_can_timingtask': is_can_timingtask,
- 'is_can_cancel': is_can_cancel, 'is_can_rollback': is_can_rollback, 'audit_auth_group': audit_auth_group,
- 'manual': manual, 'current_audit_auth_group': current_audit_auth_group, 'run_date': run_date}
- return render(request, 'detail.html', context)
+ manual = SysConfig().get("manual")
+
+ context = {
+ "workflow_detail": workflow_detail,
+ "last_operation_info": last_operation_info,
+ "is_can_review": is_can_review,
+ "is_can_execute": is_can_execute,
+ "is_can_timingtask": is_can_timingtask,
+ "is_can_cancel": is_can_cancel,
+ "is_can_rollback": is_can_rollback,
+ "audit_auth_group": audit_auth_group,
+ "manual": manual,
+ "current_audit_auth_group": current_audit_auth_group,
+ "run_date": run_date,
+ }
+ return render(request, "detail.html", context)
def rollback(request):
"""展示回滚的SQL页面"""
- workflow_id = request.GET.get('workflow_id')
+ workflow_id = request.GET.get("workflow_id")
if not can_rollback(request.user, workflow_id):
raise PermissionDenied
- download = request.GET.get('download')
- if workflow_id == '' or workflow_id is None:
- context = {'errMsg': 'workflow_id参数为空.'}
- return render(request, 'error.html', context)
+ download = request.GET.get("download")
+ if workflow_id == "" or workflow_id is None:
+ context = {"errMsg": "workflow_id参数为空."}
+ return render(request, "error.html", context)
workflow = SqlWorkflow.objects.get(id=int(workflow_id))
# 直接下载回滚语句
@@ -216,55 +284,66 @@ def rollback(request):
list_backup_sql = query_engine.get_rollback(workflow=workflow)
except Exception as msg:
logger.error(traceback.format_exc())
- context = {'errMsg': msg}
- return render(request, 'error.html', context)
+ context = {"errMsg": msg}
+ return render(request, "error.html", context)
# 获取数据,存入目录
- path = os.path.join(settings.BASE_DIR, 'downloads/rollback')
+ path = os.path.join(settings.BASE_DIR, "downloads/rollback")
os.makedirs(path, exist_ok=True)
- file_name = f'{path}/rollback_{workflow_id}.sql'
- with open(file_name, 'w') as f:
+ file_name = f"{path}/rollback_{workflow_id}.sql"
+ with open(file_name, "w") as f:
for sql in list_backup_sql:
- f.write(f'/*{sql[0]}*/\n{sql[1]}\n')
+ f.write(f"/*{sql[0]}*/\n{sql[1]}\n")
# 返回
- response = FileResponse(open(file_name, 'rb'))
- response['Content-Type'] = 'application/octet-stream'
- response['Content-Disposition'] = f'attachment;filename="rollback_{workflow_id}.sql"'
+ response = FileResponse(open(file_name, "rb"))
+ response["Content-Type"] = "application/octet-stream"
+ response[
+ "Content-Disposition"
+ ] = f'attachment;filename="rollback_{workflow_id}.sql"'
return response
# 异步获取,并在页面展示,如果数据量大加载会缓慢
else:
rollback_workflow_name = f"【回滚工单】原工单Id:{workflow_id} ,{workflow.workflow_name}"
- context = {'workflow_detail': workflow, 'rollback_workflow_name': rollback_workflow_name}
- return render(request, 'rollback.html', context)
+ context = {
+ "workflow_detail": workflow,
+ "rollback_workflow_name": rollback_workflow_name,
+ }
+ return render(request, "rollback.html", context)
-@permission_required('sql.menu_sqlanalyze', raise_exception=True)
+@permission_required("sql.menu_sqlanalyze", raise_exception=True)
def sqlanalyze(request):
"""SQL分析页面"""
- return render(request, 'sqlanalyze.html')
+ return render(request, "sqlanalyze.html")
-@permission_required('sql.menu_query', raise_exception=True)
+@permission_required("sql.menu_query", raise_exception=True)
def sqlquery(request):
"""SQL在线查询页面"""
# 主动创建标签
- InstanceTag.objects.get_or_create(tag_code='can_read', defaults={'tag_name': '支持查询', 'active': True})
+ InstanceTag.objects.get_or_create(
+ tag_code="can_read", defaults={"tag_name": "支持查询", "active": True}
+ )
# 收藏语句
user = request.user
- favorites = QueryLog.objects.filter(username=user.username, favorite=True).values('id', 'alias')
- can_download = 1 if user.has_perm('sql.query_download') or user.is_superuser else 0
- return render(request, 'sqlquery.html', {'favorites': favorites, 'can_download':can_download})
+ favorites = QueryLog.objects.filter(username=user.username, favorite=True).values(
+ "id", "alias"
+ )
+ can_download = 1 if user.has_perm("sql.query_download") or user.is_superuser else 0
+ return render(
+ request, "sqlquery.html", {"favorites": favorites, "can_download": can_download}
+ )
-@permission_required('sql.menu_queryapplylist', raise_exception=True)
+@permission_required("sql.menu_queryapplylist", raise_exception=True)
def queryapplylist(request):
"""查询权限申请列表页面"""
user = request.user
# 获取资源组
group_list = user_groups(user)
- context = {'group_list': group_list}
- return render(request, 'queryapplylist.html', context)
+ context = {"group_list": group_list}
+ return render(request, "queryapplylist.html", context)
def queryapplydetail(request, apply_id):
@@ -278,100 +357,114 @@ def queryapplydetail(request, apply_id):
# 获取审核日志
if workflow_detail.status == 2:
try:
- audit_id = Audit.detail_by_workflow_id(workflow_id=apply_id, workflow_type=1).audit_id
- last_operation_info = Audit.logs(audit_id=audit_id).latest('id').operation_info
+ audit_id = Audit.detail_by_workflow_id(
+ workflow_id=apply_id, workflow_type=1
+ ).audit_id
+ last_operation_info = (
+ Audit.logs(audit_id=audit_id).latest("id").operation_info
+ )
except Exception as e:
- logger.debug(f'无审核日志记录,错误信息{e}')
- last_operation_info = ''
+ logger.debug(f"无审核日志记录,错误信息{e}")
+ last_operation_info = ""
else:
- last_operation_info = ''
+ last_operation_info = ""
- context = {'workflow_detail': workflow_detail, 'audit_auth_group': audit_auth_group,
- 'last_operation_info': last_operation_info, 'current_audit_auth_group': current_audit_auth_group,
- 'is_can_review': is_can_review}
- return render(request, 'queryapplydetail.html', context)
+ context = {
+ "workflow_detail": workflow_detail,
+ "audit_auth_group": audit_auth_group,
+ "last_operation_info": last_operation_info,
+ "current_audit_auth_group": current_audit_auth_group,
+ "is_can_review": is_can_review,
+ }
+ return render(request, "queryapplydetail.html", context)
def queryuserprivileges(request):
"""查询权限管理页面"""
# 获取所有用户
- user_list = QueryPrivileges.objects.filter(is_deleted=0).values('user_display').distinct()
- context = {'user_list': user_list}
- return render(request, 'queryuserprivileges.html', context)
+ user_list = (
+ QueryPrivileges.objects.filter(is_deleted=0).values("user_display").distinct()
+ )
+ context = {"user_list": user_list}
+ return render(request, "queryuserprivileges.html", context)
-@permission_required('sql.menu_sqladvisor', raise_exception=True)
+@permission_required("sql.menu_sqladvisor", raise_exception=True)
def sqladvisor(request):
"""SQL优化工具页面"""
- return render(request, 'sqladvisor.html')
+ return render(request, "sqladvisor.html")
-@permission_required('sql.menu_slowquery', raise_exception=True)
+@permission_required("sql.menu_slowquery", raise_exception=True)
def slowquery(request):
"""SQL慢日志页面"""
- return render(request, 'slowquery.html')
+ return render(request, "slowquery.html")
-@permission_required('sql.menu_instance', raise_exception=True)
+@permission_required("sql.menu_instance", raise_exception=True)
def instance(request):
"""实例管理页面"""
# 获取实例标签
tags = InstanceTag.objects.filter(active=True)
- return render(request, 'instance.html', {'tags': tags})
+ return render(request, "instance.html", {"tags": tags})
-@permission_required('sql.menu_instance_account', raise_exception=True)
+@permission_required("sql.menu_instance_account", raise_exception=True)
def instanceaccount(request):
"""实例账号管理页面"""
- return render(request, 'instanceaccount.html')
+ return render(request, "instanceaccount.html")
-@permission_required('sql.menu_database', raise_exception=True)
+@permission_required("sql.menu_database", raise_exception=True)
def database(request):
"""实例数据库管理页面"""
# 获取所有有效用户,通知对象
active_user = Users.objects.filter(is_active=1)
- return render(request, 'database.html', {"active_user": active_user})
+ return render(request, "database.html", {"active_user": active_user})
-@permission_required('sql.menu_dbdiagnostic', raise_exception=True)
+@permission_required("sql.menu_dbdiagnostic", raise_exception=True)
def dbdiagnostic(request):
"""会话管理页面"""
- return render(request, 'dbdiagnostic.html')
+ return render(request, "dbdiagnostic.html")
-@permission_required('sql.menu_data_dictionary', raise_exception=True)
+@permission_required("sql.menu_data_dictionary", raise_exception=True)
def data_dictionary(request):
"""数据字典页面"""
- return render(request, 'data_dictionary.html', locals())
+ return render(request, "data_dictionary.html", locals())
-@permission_required('sql.menu_param', raise_exception=True)
+@permission_required("sql.menu_param", raise_exception=True)
def instance_param(request):
"""实例参数管理页面"""
- return render(request, 'param.html')
+ return render(request, "param.html")
-@permission_required('sql.menu_my2sql', raise_exception=True)
+@permission_required("sql.menu_my2sql", raise_exception=True)
def my2sql(request):
"""my2sql页面"""
- return render(request, 'my2sql.html')
+ return render(request, "my2sql.html")
-@permission_required('sql.menu_schemasync', raise_exception=True)
+@permission_required("sql.menu_schemasync", raise_exception=True)
def schemasync(request):
"""数据库差异对比页面"""
- return render(request, 'schemasync.html')
+ return render(request, "schemasync.html")
-@permission_required('sql.menu_archive', raise_exception=True)
+@permission_required("sql.menu_archive", raise_exception=True)
def archive(request):
"""归档列表页面"""
# 获取资源组
group_list = user_groups(request.user)
- ins_list = user_instances(request.user, db_type=['mysql']).order_by(Convert('instance_name', 'gbk').asc())
- return render(request, 'archive.html', {'group_list': group_list, 'ins_list': ins_list})
+ ins_list = user_instances(request.user, db_type=["mysql"]).order_by(
+ Convert("instance_name", "gbk").asc()
+ )
+ return render(
+ request, "archive.html", {"group_list": group_list, "ins_list": ins_list}
+ )
def archive_detail(request, id):
@@ -382,24 +475,32 @@ def archive_detail(request, id):
audit_auth_group, current_audit_auth_group = Audit.review_info(id, 3)
is_can_review = Audit.can_review(request.user, id, 3)
except Exception as e:
- logger.debug(f'归档配置{id}无审核信息,{e}')
+ logger.debug(f"归档配置{id}无审核信息,{e}")
audit_auth_group, current_audit_auth_group = None, None
is_can_review = False
# 获取审核日志
if archive_config.status == 2:
try:
- audit_id = Audit.detail_by_workflow_id(workflow_id=id, workflow_type=3).audit_id
- last_operation_info = Audit.logs(audit_id=audit_id).latest('id').operation_info
+ audit_id = Audit.detail_by_workflow_id(
+ workflow_id=id, workflow_type=3
+ ).audit_id
+ last_operation_info = (
+ Audit.logs(audit_id=audit_id).latest("id").operation_info
+ )
except Exception as e:
- logger.debug(f'归档配置{id}无审核日志记录,错误信息{e}')
- last_operation_info = ''
+ logger.debug(f"归档配置{id}无审核日志记录,错误信息{e}")
+ last_operation_info = ""
else:
- last_operation_info = ''
+ last_operation_info = ""
- context = {'archive_config': archive_config, 'audit_auth_group': audit_auth_group,
- 'last_operation_info': last_operation_info, 'current_audit_auth_group': current_audit_auth_group,
- 'is_can_review': is_can_review}
- return render(request, 'archivedetail.html', context)
+ context = {
+ "archive_config": archive_config,
+ "audit_auth_group": audit_auth_group,
+ "last_operation_info": last_operation_info,
+ "current_audit_auth_group": current_audit_auth_group,
+ "is_can_review": is_can_review,
+ }
+ return render(request, "archivedetail.html", context)
@superuser_required
@@ -412,29 +513,35 @@ def config(request):
# 获取所有实例标签
instance_tags = InstanceTag.objects.all()
# 支持自动审核的数据库类型
- db_type = ['mysql', 'oracle', 'mongo', 'clickhouse']
+ db_type = ["mysql", "oracle", "mongo", "clickhouse"]
# 获取所有配置项
- all_config = Config.objects.all().values('item', 'value')
+ all_config = Config.objects.all().values("item", "value")
sys_config = {}
for items in all_config:
- sys_config[items['item']] = items['value']
+ sys_config[items["item"]] = items["value"]
- context = {'group_list': group_list, 'auth_group_list': auth_group_list, 'instance_tags': instance_tags,
- 'db_type': db_type, 'config': sys_config, 'WorkflowDict': WorkflowDict}
- return render(request, 'config.html', context)
+ context = {
+ "group_list": group_list,
+ "auth_group_list": auth_group_list,
+ "instance_tags": instance_tags,
+ "db_type": db_type,
+ "config": sys_config,
+ "WorkflowDict": WorkflowDict,
+ }
+ return render(request, "config.html", context)
@superuser_required
def group(request):
"""资源组管理页面"""
- return render(request, 'group.html')
+ return render(request, "group.html")
@superuser_required
def groupmgmt(request, group_id):
"""资源组组关系管理页面"""
group = ResourceGroup.objects.get(group_id=group_id)
- return render(request, 'groupmgmt.html', {'group': group})
+ return render(request, "groupmgmt.html", {"group": group})
def workflows(request):
@@ -448,38 +555,46 @@ def workflowsdetail(request, audit_id):
audit_detail = Audit.detail(audit_id)
if not audit_detail:
raise Http404("不存在对应的工单记录")
- if audit_detail.workflow_type == WorkflowDict.workflow_type['query']:
- return HttpResponseRedirect(reverse('sql:queryapplydetail', args=(audit_detail.workflow_id,)))
- elif audit_detail.workflow_type == WorkflowDict.workflow_type['sqlreview']:
- return HttpResponseRedirect(reverse('sql:detail', args=(audit_detail.workflow_id,)))
- elif audit_detail.workflow_type == WorkflowDict.workflow_type['archive']:
- return HttpResponseRedirect(reverse('sql:archive_detail', args=(audit_detail.workflow_id,)))
-
-
-@permission_required('sql.menu_document', raise_exception=True)
+ if audit_detail.workflow_type == WorkflowDict.workflow_type["query"]:
+ return HttpResponseRedirect(
+ reverse("sql:queryapplydetail", args=(audit_detail.workflow_id,))
+ )
+ elif audit_detail.workflow_type == WorkflowDict.workflow_type["sqlreview"]:
+ return HttpResponseRedirect(
+ reverse("sql:detail", args=(audit_detail.workflow_id,))
+ )
+ elif audit_detail.workflow_type == WorkflowDict.workflow_type["archive"]:
+ return HttpResponseRedirect(
+ reverse("sql:archive_detail", args=(audit_detail.workflow_id,))
+ )
+
+
+@permission_required("sql.menu_document", raise_exception=True)
def dbaprinciples(request):
"""SQL文档页面"""
# 读取MD文件
- file = os.path.join(settings.BASE_DIR, 'docs/docs.md')
- with open(file, 'r', encoding="utf-8") as f:
- md = f.read().replace('\n', '\\n')
- return render(request, 'dbaprinciples.html', {'md': md})
+ file = os.path.join(settings.BASE_DIR, "docs/docs.md")
+ with open(file, "r", encoding="utf-8") as f:
+ md = f.read().replace("\n", "\\n")
+ return render(request, "dbaprinciples.html", {"md": md})
-@permission_required('sql.audit_user', raise_exception=True)
+@permission_required("sql.audit_user", raise_exception=True)
def audit(request):
"""通用审计日志页面"""
- _action_types = AuditEntry.objects.values_list('action').distinct()
- action_types = [ i[0] for i in _action_types ]
- return render(request, 'audit.html', {'action_types': action_types})
+ _action_types = AuditEntry.objects.values_list("action").distinct()
+ action_types = [i[0] for i in _action_types]
+ return render(request, "audit.html", {"action_types": action_types})
-@permission_required('sql.audit_user', raise_exception=True)
+@permission_required("sql.audit_user", raise_exception=True)
def audit_sqlquery(request):
"""SQL在线查询页面审计"""
user = request.user
- favorites = QueryLog.objects.filter(username=user.username, favorite=True).values('id', 'alias')
- return render(request, 'audit_sqlquery.html', {'favorites': favorites})
+ favorites = QueryLog.objects.filter(username=user.username, favorite=True).values(
+ "id", "alias"
+ )
+ return render(request, "audit_sqlquery.html", {"favorites": favorites})
def audit_sqlworkflow(request):
@@ -488,22 +603,34 @@ def audit_sqlworkflow(request):
# 过滤筛选项的数据
filter_dict = dict()
# 管理员,可查看所有工单
- if user.is_superuser or user.has_perm('sql.audit_user'):
+ if user.is_superuser or user.has_perm("sql.audit_user"):
pass
# 非管理员,拥有审核权限、资源组粒度执行权限的,可以查看组内所有工单
- elif user.has_perm('sql.sql_review') or user.has_perm('sql.sql_execute_for_resource_group'):
+ elif user.has_perm("sql.sql_review") or user.has_perm(
+ "sql.sql_execute_for_resource_group"
+ ):
# 先获取用户所在资源组列表
group_list = user_groups(user)
group_ids = [group.group_id for group in group_list]
- filter_dict['group_id__in'] = group_ids
+ filter_dict["group_id__in"] = group_ids
# 其他人只能查看自己提交的工单
else:
- filter_dict['engineer'] = user.username
- instance_id = SqlWorkflow.objects.filter(**filter_dict).values('instance_id').distinct()
+ filter_dict["engineer"] = user.username
+ instance_id = (
+ SqlWorkflow.objects.filter(**filter_dict).values("instance_id").distinct()
+ )
instance = Instance.objects.filter(pk__in=instance_id)
- resource_group_id = SqlWorkflow.objects.filter(**filter_dict).values('group_id').distinct()
+ resource_group_id = (
+ SqlWorkflow.objects.filter(**filter_dict).values("group_id").distinct()
+ )
resource_group = ResourceGroup.objects.filter(group_id__in=resource_group_id)
- return render(request, 'audit_sqlworkflow.html',
- {'status_list': SQL_WORKFLOW_CHOICES,
- 'instance': instance, 'resource_group': resource_group})
+ return render(
+ request,
+ "audit_sqlworkflow.html",
+ {
+ "status_list": SQL_WORKFLOW_CHOICES,
+ "instance": instance,
+ "resource_group": resource_group,
+ },
+ )
diff --git a/sql_api/api_instance.py b/sql_api/api_instance.py
index 0229c19483..4cb50b51ba 100644
--- a/sql_api/api_instance.py
+++ b/sql_api/api_instance.py
@@ -1,8 +1,14 @@
from rest_framework import views, generics, status, serializers
from rest_framework.response import Response
from drf_spectacular.utils import extend_schema
-from .serializers import InstanceSerializer, InstanceDetailSerializer, TunnelSerializer, \
- AliyunRdsSerializer, InstanceResourceSerializer, InstanceResourceListSerializer
+from .serializers import (
+ InstanceSerializer,
+ InstanceDetailSerializer,
+ TunnelSerializer,
+ AliyunRdsSerializer,
+ InstanceResourceSerializer,
+ InstanceResourceListSerializer,
+)
from .pagination import CustomizedPagination
from .filters import InstanceFilter
from sql.models import Instance, Tunnel, AliyunRdsConfig
@@ -15,28 +21,31 @@ class InstanceList(generics.ListAPIView):
"""
列出所有的instance或者创建一个新的instance配置
"""
+
filterset_class = InstanceFilter
pagination_class = CustomizedPagination
serializer_class = InstanceSerializer
- queryset = Instance.objects.all().order_by('id')
-
- @extend_schema(summary="实例清单",
- request=InstanceSerializer,
- responses={200: InstanceSerializer},
- description="列出所有实例(过滤,分页)")
+ queryset = Instance.objects.all().order_by("id")
+
+ @extend_schema(
+ summary="实例清单",
+ request=InstanceSerializer,
+ responses={200: InstanceSerializer},
+ description="列出所有实例(过滤,分页)",
+ )
def get(self, request):
instances = self.filter_queryset(self.queryset)
page_ins = self.paginate_queryset(queryset=instances)
serializer_obj = self.get_serializer(page_ins, many=True)
- data = {
- 'data': serializer_obj.data
- }
+ data = {"data": serializer_obj.data}
return self.get_paginated_response(data)
- @extend_schema(summary="创建实例",
- request=InstanceSerializer,
- responses={201: InstanceSerializer},
- description="创建一个实例配置")
+ @extend_schema(
+ summary="创建实例",
+ request=InstanceSerializer,
+ responses={201: InstanceSerializer},
+ description="创建一个实例配置",
+ )
def post(self, request):
serializer = InstanceSerializer(data=request.data)
if serializer.is_valid():
@@ -49,6 +58,7 @@ class InstanceDetail(views.APIView):
"""
实例操作
"""
+
serializer_class = InstanceDetailSerializer
def get_object(self, pk):
@@ -57,10 +67,12 @@ def get_object(self, pk):
except Instance.DoesNotExist:
raise Http404
- @extend_schema(summary="更新实例",
- request=InstanceDetailSerializer,
- responses={200: InstanceDetailSerializer},
- description="更新一个实例配置")
+ @extend_schema(
+ summary="更新实例",
+ request=InstanceDetailSerializer,
+ responses={200: InstanceDetailSerializer},
+ description="更新一个实例配置",
+ )
def put(self, request, pk):
instance = self.get_object(pk)
serializer = InstanceDetailSerializer(instance, data=request.data)
@@ -69,8 +81,7 @@ def put(self, request, pk):
return Response(serializer.data)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
- @extend_schema(summary="删除实例",
- description="删除一个实例配置")
+ @extend_schema(summary="删除实例", description="删除一个实例配置")
def delete(self, request, pk):
instance = self.get_object(pk)
instance.delete()
@@ -81,27 +92,30 @@ class TunnelList(generics.ListAPIView):
"""
列出所有的tunnel或者创建一个新的tunnel配置
"""
+
pagination_class = CustomizedPagination
serializer_class = TunnelSerializer
- queryset = Tunnel.objects.all().order_by('id')
-
- @extend_schema(summary="隧道清单",
- request=TunnelSerializer,
- responses={200: TunnelSerializer},
- description="列出所有隧道(过滤,分页)")
+ queryset = Tunnel.objects.all().order_by("id")
+
+ @extend_schema(
+ summary="隧道清单",
+ request=TunnelSerializer,
+ responses={200: TunnelSerializer},
+ description="列出所有隧道(过滤,分页)",
+ )
def get(self, request):
tunnels = self.filter_queryset(self.queryset)
page_tunnels = self.paginate_queryset(queryset=tunnels)
serializer_obj = self.get_serializer(page_tunnels, many=True)
- data = {
- 'data': serializer_obj.data
- }
+ data = {"data": serializer_obj.data}
return self.get_paginated_response(data)
- @extend_schema(summary="创建隧道",
- request=TunnelSerializer,
- responses={201: TunnelSerializer},
- description="创建一个隧道配置")
+ @extend_schema(
+ summary="创建隧道",
+ request=TunnelSerializer,
+ responses={201: TunnelSerializer},
+ description="创建一个隧道配置",
+ )
def post(self, request):
serializer = TunnelSerializer(data=request.data)
if serializer.is_valid():
@@ -114,27 +128,30 @@ class AliyunRdsList(generics.ListAPIView):
"""
列出所有的AliyunRDS或者创建一个新的AliyunRDS配置
"""
+
pagination_class = CustomizedPagination
serializer_class = AliyunRdsSerializer
- queryset = AliyunRdsConfig.objects.all().select_related('ak').order_by('id')
-
- @extend_schema(summary="AliyunRDS清单",
- request=AliyunRdsSerializer,
- responses={200: AliyunRdsSerializer},
- description="列出所有AliyunRDS(过滤,分页)")
+ queryset = AliyunRdsConfig.objects.all().select_related("ak").order_by("id")
+
+ @extend_schema(
+ summary="AliyunRDS清单",
+ request=AliyunRdsSerializer,
+ responses={200: AliyunRdsSerializer},
+ description="列出所有AliyunRDS(过滤,分页)",
+ )
def get(self, request):
aliyunrds = self.filter_queryset(self.queryset)
page_rds = self.paginate_queryset(queryset=aliyunrds)
serializer_obj = self.get_serializer(page_rds, many=True)
- data = {
- 'data': serializer_obj.data
- }
+ data = {"data": serializer_obj.data}
return self.get_paginated_response(data)
- @extend_schema(summary="创建AliyunRDS",
- request=AliyunRdsSerializer,
- responses={201: AliyunRdsSerializer},
- description="创建一个AliyunRDS配置(包含一个CloudAccessKey)")
+ @extend_schema(
+ summary="创建AliyunRDS",
+ request=AliyunRdsSerializer,
+ responses={201: AliyunRdsSerializer},
+ description="创建一个AliyunRDS配置(包含一个CloudAccessKey)",
+ )
def post(self, request):
serializer = AliyunRdsSerializer(data=request.data)
if serializer.is_valid():
@@ -148,47 +165,54 @@ class InstanceResource(views.APIView):
获取实例内的资源信息,database、schema、table、column
"""
- @extend_schema(summary="实例资源",
- request=InstanceResourceSerializer,
- responses={200: InstanceResourceListSerializer},
- description="获取实例内的资源信息")
+ @extend_schema(
+ summary="实例资源",
+ request=InstanceResourceSerializer,
+ responses={200: InstanceResourceListSerializer},
+ description="获取实例内的资源信息",
+ )
def post(self, request):
# 参数验证
serializer = InstanceResourceSerializer(data=request.data)
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
- instance_id = request.data['instance_id']
- resource_type = request.data['resource_type']
- db_name = request.data['db_name'] if 'db_name' in request.data.keys() else ''
- schema_name = request.data['schema_name'] if 'schema_name' in request.data.keys() else ''
- tb_name = request.data['tb_name'] if 'tb_name' in request.data.keys() else ''
+ instance_id = request.data["instance_id"]
+ resource_type = request.data["resource_type"]
+ db_name = request.data["db_name"] if "db_name" in request.data.keys() else ""
+ schema_name = (
+ request.data["schema_name"] if "schema_name" in request.data.keys() else ""
+ )
+ tb_name = request.data["tb_name"] if "tb_name" in request.data.keys() else ""
instance = Instance.objects.get(pk=instance_id)
try:
# escape
- db_name = MySQLdb.escape_string(db_name).decode('utf-8')
- schema_name = MySQLdb.escape_string(schema_name).decode('utf-8')
- tb_name = MySQLdb.escape_string(tb_name).decode('utf-8')
+ db_name = MySQLdb.escape_string(db_name).decode("utf-8")
+ schema_name = MySQLdb.escape_string(schema_name).decode("utf-8")
+ tb_name = MySQLdb.escape_string(tb_name).decode("utf-8")
query_engine = get_engine(instance=instance)
- if resource_type == 'database':
+ if resource_type == "database":
resource = query_engine.get_all_databases()
- elif resource_type == 'schema' and db_name:
+ elif resource_type == "schema" and db_name:
resource = query_engine.get_all_schemas(db_name=db_name)
- elif resource_type == 'table' and db_name:
- resource = query_engine.get_all_tables(db_name=db_name, schema_name=schema_name)
- elif resource_type == 'column' and db_name and tb_name:
- resource = query_engine.get_all_columns_by_tb(db_name=db_name, tb_name=tb_name, schema_name=schema_name)
+ elif resource_type == "table" and db_name:
+ resource = query_engine.get_all_tables(
+ db_name=db_name, schema_name=schema_name
+ )
+ elif resource_type == "column" and db_name and tb_name:
+ resource = query_engine.get_all_columns_by_tb(
+ db_name=db_name, tb_name=tb_name, schema_name=schema_name
+ )
else:
- raise serializers.ValidationError({'errors': '不支持的资源类型或者参数不完整!'})
+ raise serializers.ValidationError({"errors": "不支持的资源类型或者参数不完整!"})
except Exception as msg:
- raise serializers.ValidationError({'errors': msg})
+ raise serializers.ValidationError({"errors": msg})
else:
if resource.error:
- raise serializers.ValidationError({'errors': resource.error})
+ raise serializers.ValidationError({"errors": resource.error})
else:
- resource = {'count': len(resource.rows),
- 'result': resource.rows}
+ resource = {"count": len(resource.rows), "result": resource.rows}
serializer_obj = InstanceResourceListSerializer(resource)
return Response(serializer_obj.data)
diff --git a/sql_api/api_user.py b/sql_api/api_user.py
index 5aec0a4151..d307066cc6 100644
--- a/sql_api/api_user.py
+++ b/sql_api/api_user.py
@@ -1,9 +1,17 @@
from rest_framework import views, generics, status, permissions
from rest_framework.response import Response
from drf_spectacular.utils import extend_schema
-from .serializers import UserSerializer, UserDetailSerializer, GroupSerializer, \
- ResourceGroupSerializer, TwoFASerializer, UserAuthSerializer, TwoFAVerifySerializer, \
- TwoFASaveSerializer, TwoFAStateSerializer
+from .serializers import (
+ UserSerializer,
+ UserDetailSerializer,
+ GroupSerializer,
+ ResourceGroupSerializer,
+ TwoFASerializer,
+ UserAuthSerializer,
+ TwoFAVerifySerializer,
+ TwoFASaveSerializer,
+ TwoFAStateSerializer,
+)
from .pagination import CustomizedPagination
from .permissions import IsOwner
from .filters import UserFilter
@@ -23,28 +31,31 @@ class UserList(generics.ListAPIView):
"""
列出所有的user或者创建一个新的user
"""
+
filterset_class = UserFilter
pagination_class = CustomizedPagination
serializer_class = UserSerializer
- queryset = Users.objects.all().order_by('id')
-
- @extend_schema(summary="用户清单",
- request=UserSerializer,
- responses={200: UserSerializer},
- description="列出所有用户(过滤,分页)")
+ queryset = Users.objects.all().order_by("id")
+
+ @extend_schema(
+ summary="用户清单",
+ request=UserSerializer,
+ responses={200: UserSerializer},
+ description="列出所有用户(过滤,分页)",
+ )
def get(self, request):
users = self.filter_queryset(self.queryset)
page_user = self.paginate_queryset(queryset=users)
serializer_obj = self.get_serializer(page_user, many=True)
- data = {
- 'data': serializer_obj.data
- }
+ data = {"data": serializer_obj.data}
return self.get_paginated_response(data)
- @extend_schema(summary="创建用户",
- request=UserSerializer,
- responses={201: UserSerializer},
- description="创建一个用户")
+ @extend_schema(
+ summary="创建用户",
+ request=UserSerializer,
+ responses={201: UserSerializer},
+ description="创建一个用户",
+ )
def post(self, request):
serializer = UserSerializer(data=request.data)
if serializer.is_valid():
@@ -57,6 +68,7 @@ class UserDetail(views.APIView):
"""
用户操作
"""
+
serializer_class = UserDetailSerializer
def get_object(self, pk):
@@ -65,10 +77,12 @@ def get_object(self, pk):
except Users.DoesNotExist:
raise Http404
- @extend_schema(summary="更新用户",
- request=UserDetailSerializer,
- responses={200: UserDetailSerializer},
- description="更新一个用户")
+ @extend_schema(
+ summary="更新用户",
+ request=UserDetailSerializer,
+ responses={200: UserDetailSerializer},
+ description="更新一个用户",
+ )
def put(self, request, pk):
user = self.get_object(pk)
serializer = UserDetailSerializer(user, data=request.data)
@@ -77,8 +91,7 @@ def put(self, request, pk):
return Response(serializer.data)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
- @extend_schema(summary="删除用户",
- description="删除一个用户")
+ @extend_schema(summary="删除用户", description="删除一个用户")
def delete(self, request, pk):
user = self.get_object(pk)
user.delete()
@@ -89,27 +102,30 @@ class GroupList(generics.ListAPIView):
"""
列出所有的group或者创建一个新的group
"""
+
pagination_class = CustomizedPagination
serializer_class = GroupSerializer
- queryset = Group.objects.all().order_by('id')
-
- @extend_schema(summary="用户组清单",
- request=GroupSerializer,
- responses={200: GroupSerializer},
- description="列出所有用户组(过滤,分页)")
+ queryset = Group.objects.all().order_by("id")
+
+ @extend_schema(
+ summary="用户组清单",
+ request=GroupSerializer,
+ responses={200: GroupSerializer},
+ description="列出所有用户组(过滤,分页)",
+ )
def get(self, request):
groups = self.filter_queryset(self.queryset)
page_groups = self.paginate_queryset(queryset=groups)
serializer_obj = self.get_serializer(page_groups, many=True)
- data = {
- 'data': serializer_obj.data
- }
+ data = {"data": serializer_obj.data}
return self.get_paginated_response(data)
- @extend_schema(summary="创建用户组",
- request=GroupSerializer,
- responses={201: GroupSerializer},
- description="创建一个用户组")
+ @extend_schema(
+ summary="创建用户组",
+ request=GroupSerializer,
+ responses={201: GroupSerializer},
+ description="创建一个用户组",
+ )
def post(self, request):
serializer = GroupSerializer(data=request.data)
if serializer.is_valid():
@@ -122,6 +138,7 @@ class GroupDetail(views.APIView):
"""
用户组操作
"""
+
serializer_class = GroupSerializer
def get_object(self, pk):
@@ -130,10 +147,12 @@ def get_object(self, pk):
except Group.DoesNotExist:
raise Http404
- @extend_schema(summary="更新用户组",
- request=GroupSerializer,
- responses={200: GroupSerializer},
- description="更新一个用户组")
+ @extend_schema(
+ summary="更新用户组",
+ request=GroupSerializer,
+ responses={200: GroupSerializer},
+ description="更新一个用户组",
+ )
def put(self, request, pk):
group = self.get_object(pk)
serializer = GroupSerializer(group, data=request.data)
@@ -142,8 +161,7 @@ def put(self, request, pk):
return Response(serializer.data)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
- @extend_schema(summary="删除用户组",
- description="删除一个用户组")
+ @extend_schema(summary="删除用户组", description="删除一个用户组")
def delete(self, request, pk):
group = self.get_object(pk)
group.delete()
@@ -154,27 +172,30 @@ class ResourceGroupList(generics.ListAPIView):
"""
列出所有的resourcegroup或者创建一个新的resourcegroup
"""
+
pagination_class = CustomizedPagination
serializer_class = ResourceGroupSerializer
- queryset = ResourceGroup.objects.all().order_by('group_id')
-
- @extend_schema(summary="资源组清单",
- request=ResourceGroupSerializer,
- responses={200: ResourceGroupSerializer},
- description="列出所有资源组(过滤,分页)")
+ queryset = ResourceGroup.objects.all().order_by("group_id")
+
+ @extend_schema(
+ summary="资源组清单",
+ request=ResourceGroupSerializer,
+ responses={200: ResourceGroupSerializer},
+ description="列出所有资源组(过滤,分页)",
+ )
def get(self, request):
groups = self.filter_queryset(self.queryset)
page_groups = self.paginate_queryset(queryset=groups)
serializer_obj = self.get_serializer(page_groups, many=True)
- data = {
- 'data': serializer_obj.data
- }
+ data = {"data": serializer_obj.data}
return self.get_paginated_response(data)
- @extend_schema(summary="创建资源组",
- request=ResourceGroupSerializer,
- responses={201: ResourceGroupSerializer},
- description="创建一个资源组")
+ @extend_schema(
+ summary="创建资源组",
+ request=ResourceGroupSerializer,
+ responses={201: ResourceGroupSerializer},
+ description="创建一个资源组",
+ )
def post(self, request):
serializer = ResourceGroupSerializer(data=request.data)
if serializer.is_valid():
@@ -187,6 +208,7 @@ class ResourceGroupDetail(views.APIView):
"""
资源组操作
"""
+
serializer_class = ResourceGroupSerializer
def get_object(self, pk):
@@ -195,10 +217,12 @@ def get_object(self, pk):
except ResourceGroup.DoesNotExist:
raise Http404
- @extend_schema(summary="更新资源组",
- request=ResourceGroupSerializer,
- responses={200: ResourceGroupSerializer},
- description="更新一个资源组")
+ @extend_schema(
+ summary="更新资源组",
+ request=ResourceGroupSerializer,
+ responses={200: ResourceGroupSerializer},
+ description="更新一个资源组",
+ )
def put(self, request, pk):
group = self.get_object(pk)
serializer = ResourceGroupSerializer(group, data=request.data)
@@ -207,8 +231,7 @@ def put(self, request, pk):
return Response(serializer.data)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
- @extend_schema(summary="删除资源组",
- description="删除一个资源组")
+ @extend_schema(summary="删除资源组", description="删除一个资源组")
def delete(self, request, pk):
group = self.get_object(pk)
group.delete()
@@ -219,24 +242,23 @@ class UserAuth(views.APIView):
"""
用户认证校验
"""
+
permission_classes = [IsOwner]
- @extend_schema(summary="用户认证校验",
- request=UserAuthSerializer,
- description="用户认证校验")
+ @extend_schema(summary="用户认证校验", request=UserAuthSerializer, description="用户认证校验")
def post(self, request):
# 参数验证
serializer = UserAuthSerializer(data=request.data)
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
- result = {'status': 0, 'msg': '认证成功'}
- engineer = request.data['engineer']
- password = request.data['password']
+ result = {"status": 0, "msg": "认证成功"}
+ engineer = request.data["engineer"]
+ password = request.data["password"]
user = authenticate(username=engineer, password=password)
if not user:
- result = {'status': 1, 'msg': '用户名或密码错误!'}
+ result = {"status": 1, "msg": "用户名或密码错误!"}
return Response(result)
@@ -245,44 +267,43 @@ class TwoFA(views.APIView):
"""
配置2fa
"""
+
permission_classes = [permissions.AllowAny]
- @extend_schema(summary="配置2fa",
- request=TwoFASerializer,
- description="配置2fa")
+ @extend_schema(summary="配置2fa", request=TwoFASerializer, description="配置2fa")
def post(self, request):
# 参数验证
serializer = TwoFASerializer(data=request.data)
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
- engineer = request.data['engineer']
- enable = request.data['enable']
- auth_type = request.data['auth_type']
+ engineer = request.data["engineer"]
+ enable = request.data["enable"]
+ auth_type = request.data["auth_type"]
user = Users.objects.get(username=engineer)
- request_user = request.session.get('user')
+ request_user = request.session.get("user")
if not request.user.is_authenticated:
if request_user:
if request_user != engineer:
- return Response({'status': 1, 'msg': '登录用户与校验用户不一致!'})
+ return Response({"status": 1, "msg": "登录用户与校验用户不一致!"})
else:
- return Response({'status': 1, 'msg': '需先校验用户密码!'})
+ return Response({"status": 1, "msg": "需先校验用户密码!"})
authenticator = get_authenticator(user=user, auth_type=auth_type)
- if enable == 'true':
- if auth_type == 'totp':
+ if enable == "true":
+ if auth_type == "totp":
# 启用2fa - 先生成secret key
result = authenticator.generate_key()
- elif auth_type == 'sms':
+ elif auth_type == "sms":
# 启用2fa - 先发送短信验证码
- phone = request.data['phone']
- otp = '{:06d}'.format(random.randint(0, 999999))
+ phone = request.data["phone"]
+ otp = "{:06d}".format(random.randint(0, 999999))
result = authenticator.get_captcha(phone=phone, otp=otp)
- if result['status'] == 0:
- r = get_redis_connection('default')
- data = {'otp': otp, 'update_time': int(time.time())}
- r.set(f'captcha-{phone}', json.dumps(data), 300)
+ if result["status"] == 0:
+ r = get_redis_connection("default")
+ data = {"otp": otp, "update_time": int(time.time())}
+ r.set(f"captcha-{phone}", json.dumps(data), 300)
else:
# 启用2fa
result = authenticator.enable()
@@ -296,23 +317,28 @@ class TwoFAState(views.APIView):
"""
查询用户2fa配置情况
"""
+
permission_classes = [IsOwner]
- @extend_schema(summary="查询2fa配置情况",
- request=TwoFAStateSerializer,
- description="查询2fa配置情况")
+ @extend_schema(
+ summary="查询2fa配置情况", request=TwoFAStateSerializer, description="查询2fa配置情况"
+ )
def post(self, request):
# 参数验证
serializer = TwoFAStateSerializer(data=request.data)
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
- result = {'status': 0, 'msg': 'ok', 'data': {}}
- engineer = request.data['engineer']
+ result = {"status": 0, "msg": "ok", "data": {}}
+ engineer = request.data["engineer"]
user = Users.objects.get(username=engineer)
configs = TwoFactorAuthConfig.objects.filter(user=user)
- result['data']['totp'] = 'enabled' if configs.filter(auth_type='totp') else 'disabled'
- result['data']['sms'] = 'enabled' if configs.filter(auth_type='sms') else 'disabled'
+ result["data"]["totp"] = (
+ "enabled" if configs.filter(auth_type="totp") else "disabled"
+ )
+ result["data"]["sms"] = (
+ "enabled" if configs.filter(auth_type="sms") else "disabled"
+ )
return Response(result)
@@ -321,25 +347,26 @@ class TwoFASave(views.APIView):
"""
保存2fa配置(TOTP)
"""
+
permission_classes = [IsOwner]
- @extend_schema(summary="保存2fa配置",
- request=TwoFASaveSerializer,
- description="保存2fa配置")
+ @extend_schema(
+ summary="保存2fa配置", request=TwoFASaveSerializer, description="保存2fa配置"
+ )
def post(self, request):
# 参数验证
serializer = TwoFASaveSerializer(data=request.data)
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
- engineer = request.data['engineer']
- auth_type = request.data['auth_type']
- key = request.data['key'] if 'key' in request.data.keys() else None
- phone = request.data['phone'] if 'phone' in request.data.keys() else None
+ engineer = request.data["engineer"]
+ auth_type = request.data["auth_type"]
+ key = request.data["key"] if "key" in request.data.keys() else None
+ phone = request.data["phone"] if "phone" in request.data.keys() else None
user = Users.objects.get(username=engineer)
authenticator = get_authenticator(user=user, auth_type=auth_type)
- if auth_type == 'sms':
+ if auth_type == "sms":
result = authenticator.save(phone)
else:
result = authenticator.save(key)
@@ -351,46 +378,47 @@ class TwoFAVerify(views.APIView):
"""
检验2fa密码
"""
+
permission_classes = [permissions.AllowAny]
- @extend_schema(summary="检验2fa密码",
- request=TwoFAVerifySerializer,
- description="检验2fa密码")
+ @extend_schema(
+ summary="检验2fa密码", request=TwoFAVerifySerializer, description="检验2fa密码"
+ )
def post(self, request):
# 参数验证
serializer = TwoFAVerifySerializer(data=request.data)
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
- engineer = request.data['engineer']
- otp = request.data['otp']
- key = request.data['key'] if 'key' in request.data.keys() else None
- phone = request.data['phone'] if 'phone' in request.data.keys() else None
+ engineer = request.data["engineer"]
+ otp = request.data["otp"]
+ key = request.data["key"] if "key" in request.data.keys() else None
+ phone = request.data["phone"] if "phone" in request.data.keys() else None
user = Users.objects.get(username=engineer)
- request_user = request.session.get('user')
+ request_user = request.session.get("user")
if not request.user.is_authenticated:
if request_user:
if request_user != engineer:
- return Response({'status': 1, 'msg': '登录用户与校验用户不一致!'})
+ return Response({"status": 1, "msg": "登录用户与校验用户不一致!"})
else:
- return Response({'status': 1, 'msg': '需先校验用户密码!'})
+ return Response({"status": 1, "msg": "需先校验用户密码!"})
twofa_config = TwoFactorAuthConfig.objects.filter(user=user)
if not twofa_config:
if not key:
- return Response({'status': 1, 'msg': '用户未配置2FA!'})
+ return Response({"status": 1, "msg": "用户未配置2FA!"})
- auth_type = request.data['auth_type']
+ auth_type = request.data["auth_type"]
authenticator = get_authenticator(user=user, auth_type=auth_type)
- if auth_type == 'sms':
+ if auth_type == "sms":
result = authenticator.verify(otp, phone)
else:
result = authenticator.verify(otp, key)
# 校验通过后自动登录,刷新expire_date
- if result['status'] == 0 and not request.user.is_authenticated:
- login(request, user, backend='django.contrib.auth.backends.ModelBackend')
+ if result["status"] == 0 and not request.user.is_authenticated:
+ login(request, user, backend="django.contrib.auth.backends.ModelBackend")
request.session.set_expiry(settings.SESSION_COOKIE_AGE)
return Response(result)
diff --git a/sql_api/api_workflow.py b/sql_api/api_workflow.py
index 3575fcc6e6..7bff6b5798 100644
--- a/sql_api/api_workflow.py
+++ b/sql_api/api_workflow.py
@@ -1,12 +1,28 @@
from rest_framework import views, generics, status, serializers
from rest_framework.response import Response
from drf_spectacular.utils import extend_schema
-from .serializers import WorkflowContentSerializer, ExecuteCheckSerializer, \
- ExecuteCheckResultSerializer, WorkflowAuditSerializer, WorkflowAuditListSerializer, \
- WorkflowLogSerializer, WorkflowLogListSerializer, AuditWorkflowSerializer, ExecuteWorkflowSerializer
+from .serializers import (
+ WorkflowContentSerializer,
+ ExecuteCheckSerializer,
+ ExecuteCheckResultSerializer,
+ WorkflowAuditSerializer,
+ WorkflowAuditListSerializer,
+ WorkflowLogSerializer,
+ WorkflowLogListSerializer,
+ AuditWorkflowSerializer,
+ ExecuteWorkflowSerializer,
+)
from .pagination import CustomizedPagination
from .filters import WorkflowFilter, WorkflowAuditFilter
-from sql.models import SqlWorkflow, SqlWorkflowContent, Instance, WorkflowAudit, Users, WorkflowLog, ArchiveConfig
+from sql.models import (
+ SqlWorkflow,
+ SqlWorkflowContent,
+ Instance,
+ WorkflowAudit,
+ Users,
+ WorkflowLog,
+ ArchiveConfig,
+)
from sql.utils.sql_review import can_cancel, can_execute, on_correct_time_period
from sql.utils.resource_group import user_groups
from sql.utils.workflow_audit import Audit
@@ -23,24 +39,27 @@
import datetime
import logging
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
class ExecuteCheck(views.APIView):
- @extend_schema(summary="SQL检查",
- request=ExecuteCheckSerializer,
- responses={200: ExecuteCheckResultSerializer},
- description="对提供的SQL进行语法检查")
+ @extend_schema(
+ summary="SQL检查",
+ request=ExecuteCheckSerializer,
+ responses={200: ExecuteCheckResultSerializer},
+ description="对提供的SQL进行语法检查",
+ )
def post(self, request):
# 参数验证
serializer = ExecuteCheckSerializer(data=request.data)
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
- instance = Instance.objects.get(pk=request.data['instance_id'])
+ instance = Instance.objects.get(pk=request.data["instance_id"])
check_engine = get_engine(instance=instance)
- check_result = check_engine.execute_check(db_name=request.data['db_name'],
- sql=request.data['full_sql'].strip())
+ check_result = check_engine.execute_check(
+ db_name=request.data["db_name"], sql=request.data["full_sql"].strip()
+ )
review_result_list = []
for r in check_result.rows:
review_result_list += [r.__dict__]
@@ -53,28 +72,33 @@ class WorkflowList(generics.ListAPIView):
"""
列出所有的workflow或者提交一条新的workflow
"""
+
filterset_class = WorkflowFilter
pagination_class = CustomizedPagination
serializer_class = WorkflowContentSerializer
- queryset = SqlWorkflowContent.objects.all().select_related('workflow').order_by('-id')
-
- @extend_schema(summary="SQL上线工单清单",
- request=WorkflowContentSerializer,
- responses={200: WorkflowContentSerializer},
- description="列出所有SQL上线工单(过滤,分页)")
+ queryset = (
+ SqlWorkflowContent.objects.all().select_related("workflow").order_by("-id")
+ )
+
+ @extend_schema(
+ summary="SQL上线工单清单",
+ request=WorkflowContentSerializer,
+ responses={200: WorkflowContentSerializer},
+ description="列出所有SQL上线工单(过滤,分页)",
+ )
def get(self, request):
workflows = self.filter_queryset(self.queryset)
page_wf = self.paginate_queryset(queryset=workflows)
serializer_obj = self.get_serializer(page_wf, many=True)
- data = {
- 'data': serializer_obj.data
- }
+ data = {"data": serializer_obj.data}
return self.get_paginated_response(data)
- @extend_schema(summary="提交SQL上线工单",
- request=WorkflowContentSerializer,
- responses={201: WorkflowContentSerializer},
- description="提交一条SQL上线工单")
+ @extend_schema(
+ summary="提交SQL上线工单",
+ request=WorkflowContentSerializer,
+ responses={201: WorkflowContentSerializer},
+ description="提交一条SQL上线工单",
+ )
def post(self, request):
serializer = WorkflowContentSerializer(data=request.data)
if serializer.is_valid():
@@ -87,20 +111,24 @@ class WorkflowAuditList(generics.ListAPIView):
"""
列出指定用户当前待自己审核的工单
"""
+
filterset_class = WorkflowAuditFilter
pagination_class = CustomizedPagination
serializer_class = WorkflowAuditListSerializer
queryset = WorkflowAudit.objects.filter(
- current_status=WorkflowDict.workflow_status['audit_wait']).order_by('-audit_id')
+ current_status=WorkflowDict.workflow_status["audit_wait"]
+ ).order_by("-audit_id")
@extend_schema(exclude=True)
def get(self, request):
return Response({"detail": "方法 “GET” 不被允许。"})
- @extend_schema(summary="待审核清单",
- request=WorkflowAuditSerializer,
- responses={200: WorkflowAuditListSerializer},
- description="列出指定用户待审核清单(过滤,分页)")
+ @extend_schema(
+ summary="待审核清单",
+ request=WorkflowAuditSerializer,
+ responses={200: WorkflowAuditListSerializer},
+ description="列出指定用户待审核清单(过滤,分页)",
+ )
def post(self, request):
# 参数验证
serializer = WorkflowAuditSerializer(data=request.data)
@@ -108,7 +136,7 @@ def post(self, request):
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
# 先获取用户所在资源组列表
- user = Users.objects.get(username=request.data['engineer'])
+ user = Users.objects.get(username=request.data["engineer"])
group_list = user_groups(user)
group_ids = [group.group_id for group in group_list]
@@ -118,15 +146,15 @@ def post(self, request):
else:
auth_group_ids = [group.id for group in Group.objects.filter(user=user)]
- self.queryset = self.queryset.filter(current_status=WorkflowDict.workflow_status['audit_wait'],
- group_id__in=group_ids,
- current_audit__in=auth_group_ids)
+ self.queryset = self.queryset.filter(
+ current_status=WorkflowDict.workflow_status["audit_wait"],
+ group_id__in=group_ids,
+ current_audit__in=auth_group_ids,
+ )
audit = self.filter_queryset(self.queryset)
page_audit = self.paginate_queryset(queryset=audit)
serializer_obj = self.get_serializer(page_audit, many=True)
- data = {
- 'data': serializer_obj.data
- }
+ data = {"data": serializer_obj.data}
return self.get_paginated_response(data)
@@ -135,28 +163,28 @@ class AuditWorkflow(views.APIView):
审核workflow,包括查询权限申请、SQL上线申请、数据归档申请
"""
- @extend_schema(summary="审核工单",
- request=AuditWorkflowSerializer,
- description="审核一条工单(通过或终止)")
+ @extend_schema(
+ summary="审核工单", request=AuditWorkflowSerializer, description="审核一条工单(通过或终止)"
+ )
def post(self, request):
# 参数验证
serializer = AuditWorkflowSerializer(data=request.data)
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
- audit_type = request.data['audit_type']
- workflow_type = request.data['workflow_type']
- workflow_id = request.data['workflow_id']
- audit_remark = request.data['audit_remark']
- engineer = request.data['engineer']
+ audit_type = request.data["audit_type"]
+ workflow_type = request.data["workflow_type"]
+ workflow_id = request.data["workflow_id"]
+ audit_remark = request.data["audit_remark"]
+ engineer = request.data["engineer"]
user = Users.objects.get(username=engineer)
# 审核查询权限申请
if workflow_type == 1:
- audit_status = 1 if audit_type == 'pass' else 2
+ audit_status = 1 if audit_type == "pass" else 2
if audit_remark is None:
- audit_remark = ''
+ audit_remark = ""
if Audit.can_review(user, workflow_id, workflow_type) is False:
raise serializers.ValidationError({"errors": "你无权操作当前工单!"})
@@ -164,30 +192,49 @@ def post(self, request):
# 使用事务保持数据一致性
try:
with transaction.atomic():
- audit_id = Audit.detail_by_workflow_id(workflow_id=workflow_id,
- workflow_type=WorkflowDict.workflow_type['query']).audit_id
+ audit_id = Audit.detail_by_workflow_id(
+ workflow_id=workflow_id,
+ workflow_type=WorkflowDict.workflow_type["query"],
+ ).audit_id
# 调用工作流接口审核
- audit_result = Audit.audit(audit_id, audit_status, user.username, audit_remark)
+ audit_result = Audit.audit(
+ audit_id, audit_status, user.username, audit_remark
+ )
# 按照审核结果更新业务表审核状态
audit_detail = Audit.detail(audit_id)
- if audit_detail.workflow_type == WorkflowDict.workflow_type['query']:
+ if (
+ audit_detail.workflow_type
+ == WorkflowDict.workflow_type["query"]
+ ):
# 更新业务表审核状态,插入权限信息
- _query_apply_audit_call_back(audit_detail.workflow_id, audit_result['data']['workflow_status'])
+ _query_apply_audit_call_back(
+ audit_detail.workflow_id,
+ audit_result["data"]["workflow_status"],
+ )
except Exception as msg:
logger.error(traceback.format_exc())
- raise serializers.ValidationError({'errors': msg})
+ raise serializers.ValidationError({"errors": msg})
else:
# 消息通知
- async_task(notify_for_audit, audit_id=audit_id, audit_remark=audit_remark, timeout=60,
- task_name=f'query-priv-audit-{workflow_id}')
- return Response({'msg': 'passed'}) if audit_type == 'pass' else Response({'msg': 'canceled'})
+ async_task(
+ notify_for_audit,
+ audit_id=audit_id,
+ audit_remark=audit_remark,
+ timeout=60,
+ task_name=f"query-priv-audit-{workflow_id}",
+ )
+ return (
+ Response({"msg": "passed"})
+ if audit_type == "pass"
+ else Response({"msg": "canceled"})
+ )
# 审核SQL上线申请
elif workflow_type == 2:
# SQL上线申请通过
- if audit_type == 'pass':
+ if audit_type == "pass":
# 权限验证
if Audit.can_review(user, workflow_id, workflow_type) is False:
raise serializers.ValidationError({"errors": "你无权操作当前工单!"})
@@ -196,30 +243,48 @@ def post(self, request):
try:
with transaction.atomic():
# 调用工作流接口审核
- audit_id = Audit.detail_by_workflow_id(workflow_id=workflow_id,
- workflow_type=WorkflowDict.workflow_type[
- 'sqlreview']).audit_id
- audit_result = Audit.audit(audit_id, WorkflowDict.workflow_status['audit_success'],
- user.username, audit_remark)
+ audit_id = Audit.detail_by_workflow_id(
+ workflow_id=workflow_id,
+ workflow_type=WorkflowDict.workflow_type["sqlreview"],
+ ).audit_id
+ audit_result = Audit.audit(
+ audit_id,
+ WorkflowDict.workflow_status["audit_success"],
+ user.username,
+ audit_remark,
+ )
# 按照审核结果更新业务表审核状态
- if audit_result['data']['workflow_status'] == WorkflowDict.workflow_status['audit_success']:
+ if (
+ audit_result["data"]["workflow_status"]
+ == WorkflowDict.workflow_status["audit_success"]
+ ):
# 将流程状态修改为审核通过
- SqlWorkflow(id=workflow_id, status='workflow_review_pass').save(update_fields=['status'])
+ SqlWorkflow(
+ id=workflow_id, status="workflow_review_pass"
+ ).save(update_fields=["status"])
except Exception as msg:
logger.error(traceback.format_exc())
- raise serializers.ValidationError({'errors': msg})
+ raise serializers.ValidationError({"errors": msg})
else:
# 开启了Pass阶段通知参数才发送消息通知
sys_config = SysConfig()
- is_notified = 'Pass' in sys_config.get('notify_phase_control').split(',') \
- if sys_config.get('notify_phase_control') else True
+ is_notified = (
+ "Pass" in sys_config.get("notify_phase_control").split(",")
+ if sys_config.get("notify_phase_control")
+ else True
+ )
if is_notified:
- async_task(notify_for_audit, audit_id=audit_id, audit_remark=audit_remark, timeout=60,
- task_name=f'sqlreview-pass-{workflow_id}')
- return Response({'msg': 'passed'})
+ async_task(
+ notify_for_audit,
+ audit_id=audit_id,
+ audit_remark=audit_remark,
+ timeout=60,
+ task_name=f"sqlreview-pass-{workflow_id}",
+ )
+ return Response({"msg": "passed"})
# SQL上线申请驳回/取消
- elif audit_type == 'cancel':
+ elif audit_type == "cancel":
workflow_detail = SqlWorkflow.objects.get(id=workflow_id)
if audit_remark is None:
@@ -232,73 +297,93 @@ def post(self, request):
try:
with transaction.atomic():
# 调用工作流接口取消或者驳回
- audit_id = Audit.detail_by_workflow_id(workflow_id=workflow_id,
- workflow_type=WorkflowDict.workflow_type[
- 'sqlreview']).audit_id
+ audit_id = Audit.detail_by_workflow_id(
+ workflow_id=workflow_id,
+ workflow_type=WorkflowDict.workflow_type["sqlreview"],
+ ).audit_id
# 仅待审核的需要调用工作流,审核通过的不需要
- if workflow_detail.status != 'workflow_manreviewing':
+ if workflow_detail.status != "workflow_manreviewing":
# 增加工单日志
if user.username == workflow_detail.engineer:
- Audit.add_log(audit_id=audit_id,
- operation_type=3,
- operation_type_desc='取消执行',
- operation_info="取消原因:{}".format(audit_remark),
- operator=user.username,
- operator_display=user.display
- )
+ Audit.add_log(
+ audit_id=audit_id,
+ operation_type=3,
+ operation_type_desc="取消执行",
+ operation_info="取消原因:{}".format(audit_remark),
+ operator=user.username,
+ operator_display=user.display,
+ )
else:
- Audit.add_log(audit_id=audit_id,
- operation_type=2,
- operation_type_desc='审批不通过',
- operation_info="审批备注:{}".format(audit_remark),
- operator=user.username,
- operator_display=user.display
- )
+ Audit.add_log(
+ audit_id=audit_id,
+ operation_type=2,
+ operation_type_desc="审批不通过",
+ operation_info="审批备注:{}".format(audit_remark),
+ operator=user.username,
+ operator_display=user.display,
+ )
else:
if user.username == workflow_detail.engineer:
- Audit.audit(audit_id,
- WorkflowDict.workflow_status['audit_abort'],
- user.username, audit_remark)
+ Audit.audit(
+ audit_id,
+ WorkflowDict.workflow_status["audit_abort"],
+ user.username,
+ audit_remark,
+ )
# 非提交人需要校验审核权限
- elif user.has_perm('sql.sql_review'):
- Audit.audit(audit_id,
- WorkflowDict.workflow_status['audit_reject'],
- user.username, audit_remark)
+ elif user.has_perm("sql.sql_review"):
+ Audit.audit(
+ audit_id,
+ WorkflowDict.workflow_status["audit_reject"],
+ user.username,
+ audit_remark,
+ )
else:
- raise serializers.ValidationError({"errors": "Permission Denied"})
+ raise serializers.ValidationError(
+ {"errors": "Permission Denied"}
+ )
# 删除定时执行task
- if workflow_detail.status == 'workflow_timingtask':
+ if workflow_detail.status == "workflow_timingtask":
schedule_name = f"sqlreview-timing-{workflow_id}"
del_schedule(schedule_name)
# 将流程状态修改为人工终止流程
- workflow_detail.status = 'workflow_abort'
+ workflow_detail.status = "workflow_abort"
workflow_detail.save()
except Exception as msg:
logger.error(f"取消工单报错,错误信息:{traceback.format_exc()}")
- raise serializers.ValidationError({'errors': msg})
+ raise serializers.ValidationError({"errors": msg})
else:
# 发送取消、驳回通知,开启了Cancel阶段通知参数才发送消息通知
sys_config = SysConfig()
- is_notified = 'Cancel' in sys_config.get('notify_phase_control').split(',') \
- if sys_config.get('notify_phase_control') else True
+ is_notified = (
+ "Cancel" in sys_config.get("notify_phase_control").split(",")
+ if sys_config.get("notify_phase_control")
+ else True
+ )
if is_notified:
- audit_detail = Audit.detail_by_workflow_id(workflow_id=workflow_id,
- workflow_type=WorkflowDict.workflow_type[
- 'sqlreview'])
+ audit_detail = Audit.detail_by_workflow_id(
+ workflow_id=workflow_id,
+ workflow_type=WorkflowDict.workflow_type["sqlreview"],
+ )
if audit_detail.current_status in (
- WorkflowDict.workflow_status['audit_abort'],
- WorkflowDict.workflow_status['audit_reject']):
- async_task(notify_for_audit, audit_id=audit_detail.audit_id, audit_remark=audit_remark,
- timeout=60,
- task_name=f'sqlreview-cancel-{workflow_id}')
- return Response({'msg': 'canceled'})
+ WorkflowDict.workflow_status["audit_abort"],
+ WorkflowDict.workflow_status["audit_reject"],
+ ):
+ async_task(
+ notify_for_audit,
+ audit_id=audit_detail.audit_id,
+ audit_remark=audit_remark,
+ timeout=60,
+ task_name=f"sqlreview-cancel-{workflow_id}",
+ )
+ return Response({"msg": "canceled"})
# 审核数据归档申请
elif workflow_type == 3:
- audit_status = 1 if audit_type == 'pass' else 2
+ audit_status = 1 if audit_type == "pass" else 2
if audit_remark is None:
- audit_remark = ''
+ audit_remark = ""
if Audit.can_review(user, workflow_id, workflow_type) is False:
raise serializers.ValidationError({"errors": "你无权操作当前工单!"})
@@ -306,24 +391,39 @@ def post(self, request):
# 使用事务保持数据一致性
try:
with transaction.atomic():
- audit_id = Audit.detail_by_workflow_id(workflow_id=workflow_id,
- workflow_type=WorkflowDict.workflow_type['archive']).audit_id
+ audit_id = Audit.detail_by_workflow_id(
+ workflow_id=workflow_id,
+ workflow_type=WorkflowDict.workflow_type["archive"],
+ ).audit_id
# 调用工作流插入审核信息,更新业务表审核状态
- audit_status = Audit.audit(audit_id, audit_status, user.username, audit_remark)['data'][
- 'workflow_status']
- ArchiveConfig(id=workflow_id,
- status=audit_status,
- state=True if audit_status == WorkflowDict.workflow_status['audit_success'] else False
- ).save(update_fields=['status', 'state'])
+ audit_status = Audit.audit(
+ audit_id, audit_status, user.username, audit_remark
+ )["data"]["workflow_status"]
+ ArchiveConfig(
+ id=workflow_id,
+ status=audit_status,
+ state=True
+ if audit_status == WorkflowDict.workflow_status["audit_success"]
+ else False,
+ ).save(update_fields=["status", "state"])
except Exception as msg:
logger.error(traceback.format_exc())
- raise serializers.ValidationError({'errors': msg})
+ raise serializers.ValidationError({"errors": msg})
else:
# 消息通知
- async_task(notify_for_audit, audit_id=audit_id, audit_remark=audit_remark, timeout=60,
- task_name=f'archive-audit-{workflow_id}')
- return Response({'msg': 'passed'}) if audit_type == 'pass' else Response({'msg': 'canceled'})
+ async_task(
+ notify_for_audit,
+ audit_id=audit_id,
+ audit_remark=audit_remark,
+ timeout=60,
+ task_name=f"archive-audit-{workflow_id}",
+ )
+ return (
+ Response({"msg": "passed"})
+ if audit_type == "pass"
+ else Response({"msg": "canceled"})
+ )
class ExecuteWorkflow(views.APIView):
@@ -331,86 +431,116 @@ class ExecuteWorkflow(views.APIView):
执行workflow,包括SQL上线工单、数据归档工单
"""
- @extend_schema(summary="执行工单",
- request=ExecuteWorkflowSerializer,
- description="执行一条工单")
+ @extend_schema(
+ summary="执行工单", request=ExecuteWorkflowSerializer, description="执行一条工单"
+ )
def post(self, request):
# 参数验证
serializer = ExecuteWorkflowSerializer(data=request.data)
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
- workflow_type = request.data['workflow_type']
- workflow_id = request.data['workflow_id']
+ workflow_type = request.data["workflow_type"]
+ workflow_id = request.data["workflow_id"]
# 执行SQL上线工单
if workflow_type == 2:
- mode = request.data['mode']
- engineer = request.data['engineer']
+ mode = request.data["mode"]
+ engineer = request.data["engineer"]
user = Users.objects.get(username=engineer)
# 校验多个权限
- if not (user.has_perm('sql.sql_execute') or user.has_perm('sql.sql_execute_for_resource_group')):
+ if not (
+ user.has_perm("sql.sql_execute")
+ or user.has_perm("sql.sql_execute_for_resource_group")
+ ):
raise serializers.ValidationError({"errors": "你无权执行当前工单!"})
if can_execute(user, workflow_id) is False:
raise serializers.ValidationError({"errors": "你无权执行当前工单!"})
if on_correct_time_period(workflow_id) is False:
- raise serializers.ValidationError({"errors": "不在可执行时间范围内,如果需要修改执行时间请重新提交工单!"})
+ raise serializers.ValidationError(
+ {"errors": "不在可执行时间范围内,如果需要修改执行时间请重新提交工单!"}
+ )
# 获取审核信息
- audit_id = Audit.detail_by_workflow_id(workflow_id=workflow_id,
- workflow_type=WorkflowDict.workflow_type['sqlreview']).audit_id
+ audit_id = Audit.detail_by_workflow_id(
+ workflow_id=workflow_id,
+ workflow_type=WorkflowDict.workflow_type["sqlreview"],
+ ).audit_id
# 交由系统执行
if mode == "auto":
# 修改工单状态为排队中
- SqlWorkflow(id=workflow_id, status="workflow_queuing").save(update_fields=['status'])
+ SqlWorkflow(id=workflow_id, status="workflow_queuing").save(
+ update_fields=["status"]
+ )
# 删除定时执行任务
schedule_name = f"sqlreview-timing-{workflow_id}"
del_schedule(schedule_name)
# 加入执行队列
- async_task('sql.utils.execute_sql.execute', workflow_id, user,
- hook='sql.utils.execute_sql.execute_callback',
- timeout=-1, task_name=f'sqlreview-execute-{workflow_id}')
+ async_task(
+ "sql.utils.execute_sql.execute",
+ workflow_id,
+ user,
+ hook="sql.utils.execute_sql.execute_callback",
+ timeout=-1,
+ task_name=f"sqlreview-execute-{workflow_id}",
+ )
# 增加工单日志
- Audit.add_log(audit_id=audit_id,
- operation_type=5,
- operation_type_desc='执行工单',
- operation_info='工单执行排队中',
- operator=user.username,
- operator_display=user.display)
+ Audit.add_log(
+ audit_id=audit_id,
+ operation_type=5,
+ operation_type_desc="执行工单",
+ operation_info="工单执行排队中",
+ operator=user.username,
+ operator_display=user.display,
+ )
# 线下手工执行
elif mode == "manual":
# 将流程状态修改为执行结束
- SqlWorkflow(id=workflow_id, status="workflow_finish", finish_time=datetime.datetime.now()
- ).save(update_fields=['status', 'finish_time'])
+ SqlWorkflow(
+ id=workflow_id,
+ status="workflow_finish",
+ finish_time=datetime.datetime.now(),
+ ).save(update_fields=["status", "finish_time"])
# 增加工单日志
- Audit.add_log(audit_id=audit_id,
- operation_type=6,
- operation_type_desc='手工工单',
- operation_info='确认手工执行结束',
- operator=user.username,
- operator_display=user.display)
+ Audit.add_log(
+ audit_id=audit_id,
+ operation_type=6,
+ operation_type_desc="手工工单",
+ operation_info="确认手工执行结束",
+ operator=user.username,
+ operator_display=user.display,
+ )
# 开启了Execute阶段通知参数才发送消息通知
sys_config = SysConfig()
- is_notified = 'Execute' in sys_config.get('notify_phase_control').split(',') \
- if sys_config.get('notify_phase_control') else True
+ is_notified = (
+ "Execute" in sys_config.get("notify_phase_control").split(",")
+ if sys_config.get("notify_phase_control")
+ else True
+ )
if is_notified:
notify_for_execute(SqlWorkflow.objects.get(id=workflow_id))
# 执行数据归档工单
elif workflow_type == 3:
- async_task('sql.archiver.archive', workflow_id, timeout=-1, task_name=f'archive-{workflow_id}')
+ async_task(
+ "sql.archiver.archive",
+ workflow_id,
+ timeout=-1,
+ task_name=f"archive-{workflow_id}",
+ )
- return Response({'msg': '开始执行,执行结果请到工单详情页查看'})
+ return Response({"msg": "开始执行,执行结果请到工单详情页查看"})
class WorkflowLogList(generics.ListAPIView):
"""
获取某个工单的日志
"""
+
pagination_class = CustomizedPagination
serializer_class = WorkflowLogListSerializer
queryset = WorkflowLog.objects.all()
@@ -419,22 +549,24 @@ class WorkflowLogList(generics.ListAPIView):
def get(self, request):
return Response({"detail": "方法 “GET” 不被允许。"})
- @extend_schema(summary="工单日志",
- request=WorkflowLogSerializer,
- responses={200: WorkflowLogListSerializer},
- description="获取某个工单的日志(分页)")
+ @extend_schema(
+ summary="工单日志",
+ request=WorkflowLogSerializer,
+ responses={200: WorkflowLogListSerializer},
+ description="获取某个工单的日志(分页)",
+ )
def post(self, request):
# 参数验证
serializer = WorkflowLogSerializer(data=request.data)
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
- audit_id = WorkflowAudit.objects.get(workflow_id=request.data['workflow_id'],
- workflow_type=request.data['workflow_type']).audit_id
- workflow_logs = self.queryset.filter(audit_id=audit_id).order_by('-id')
+ audit_id = WorkflowAudit.objects.get(
+ workflow_id=request.data["workflow_id"],
+ workflow_type=request.data["workflow_type"],
+ ).audit_id
+ workflow_logs = self.queryset.filter(audit_id=audit_id).order_by("-id")
page_log = self.paginate_queryset(queryset=workflow_logs)
serializer_obj = self.get_serializer(page_log, many=True)
- data = {
- 'data': serializer_obj.data
- }
+ data = {"data": serializer_obj.data}
return self.get_paginated_response(data)
diff --git a/sql_api/apps.py b/sql_api/apps.py
index 8cec48388f..e3bdbe1fee 100644
--- a/sql_api/apps.py
+++ b/sql_api/apps.py
@@ -2,4 +2,4 @@
class SqlApi2Config(AppConfig):
- name = 'sql_api'
+ name = "sql_api"
diff --git a/sql_api/filters.py b/sql_api/filters.py
index db71b0f630..11e4bbc123 100644
--- a/sql_api/filters.py
+++ b/sql_api/filters.py
@@ -3,48 +3,44 @@
class UserFilter(filters.FilterSet):
-
class Meta:
model = Users
fields = {
- 'id': ['exact'],
- 'username': ['exact'],
+ "id": ["exact"],
+ "username": ["exact"],
}
class InstanceFilter(filters.FilterSet):
-
class Meta:
model = Instance
fields = {
- 'id': ['exact'],
- 'instance_name': ['icontains'],
- 'db_type': ['exact'],
- 'host': ['exact'],
+ "id": ["exact"],
+ "instance_name": ["icontains"],
+ "db_type": ["exact"],
+ "host": ["exact"],
}
class WorkflowFilter(filters.FilterSet):
-
class Meta:
model = SqlWorkflowContent
fields = {
- 'id': ['exact'],
- 'workflow_id': ['exact'],
- 'workflow__workflow_name': ['icontains'],
- 'workflow__instance_id': ['exact'],
- 'workflow__db_name': ['exact'],
- 'workflow__engineer': ['exact'],
- 'workflow__status': ['exact'],
- 'workflow__create_time': ['lt', 'gte'],
+ "id": ["exact"],
+ "workflow_id": ["exact"],
+ "workflow__workflow_name": ["icontains"],
+ "workflow__instance_id": ["exact"],
+ "workflow__db_name": ["exact"],
+ "workflow__engineer": ["exact"],
+ "workflow__status": ["exact"],
+ "workflow__create_time": ["lt", "gte"],
}
class WorkflowAuditFilter(filters.FilterSet):
-
class Meta:
model = WorkflowAudit
fields = {
- 'workflow_title': ['icontains'],
- 'workflow_type': ['exact'],
+ "workflow_title": ["icontains"],
+ "workflow_type": ["exact"],
}
diff --git a/sql_api/pagination.py b/sql_api/pagination.py
index 510015965c..b262fce9dc 100644
--- a/sql_api/pagination.py
+++ b/sql_api/pagination.py
@@ -8,15 +8,24 @@ class CustomizedPagination(PageNumberPagination):
"""
自定义分页器
"""
- page_size = settings.REST_FRAMEWORK['PAGE_SIZE'] if 'PAGE_SIZE' in settings.REST_FRAMEWORK.keys() else 20
- page_query_param = 'page'
- page_size_query_param = 'size'
+
+ page_size = (
+ settings.REST_FRAMEWORK["PAGE_SIZE"]
+ if "PAGE_SIZE" in settings.REST_FRAMEWORK.keys()
+ else 20
+ )
+ page_query_param = "page"
+ page_size_query_param = "size"
max_page_size = None
def get_paginated_response(self, data):
- return Response(OrderedDict([
- ('count', data.get('count', self.page.paginator.count)),
- ('next', self.get_next_link()),
- ('previous', self.get_previous_link()),
- ('results', data.get('data', data))
- ]))
+ return Response(
+ OrderedDict(
+ [
+ ("count", data.get("count", self.page.paginator.count)),
+ ("next", self.get_next_link()),
+ ("previous", self.get_previous_link()),
+ ("results", data.get("data", data)),
+ ]
+ )
+ )
diff --git a/sql_api/permissions.py b/sql_api/permissions.py
index ff017a2d49..bff4e508b1 100644
--- a/sql_api/permissions.py
+++ b/sql_api/permissions.py
@@ -6,9 +6,10 @@ class IsInUserWhitelist(permissions.BasePermission):
"""
自定义权限,只允许白名单用户调用api
"""
+
def has_permission(self, request, view):
- config = SysConfig().get('api_user_whitelist')
- user_list = config.split(',') if config else []
+ config = SysConfig().get("api_user_whitelist")
+ user_list = config.split(",") if config else []
api_user_whitelist = [int(uid) for uid in user_list]
# 只有在api_user_whitelist参数中的用户才有权限
@@ -19,9 +20,10 @@ class IsOwner(permissions.BasePermission):
"""
当参数engineer与请求用户一致时才有权限
"""
+
def has_permission(self, request, view):
try:
- engineer = request.data['engineer']
+ engineer = request.data["engineer"]
except KeyError as e:
return False
diff --git a/sql_api/serializers.py b/sql_api/serializers.py
index 874ef81352..93045ad673 100644
--- a/sql_api/serializers.py
+++ b/sql_api/serializers.py
@@ -1,6 +1,16 @@
from rest_framework import serializers
-from sql.models import Users, Instance, Tunnel, AliyunRdsConfig, CloudAccessKey, \
- SqlWorkflow, SqlWorkflowContent, ResourceGroup, WorkflowAudit, WorkflowLog
+from sql.models import (
+ Users,
+ Instance,
+ Tunnel,
+ AliyunRdsConfig,
+ CloudAccessKey,
+ SqlWorkflow,
+ SqlWorkflowContent,
+ ResourceGroup,
+ WorkflowAudit,
+ WorkflowLog,
+)
from django.contrib.auth.models import Group
from django.contrib.auth.password_validation import validate_password
from django.core.exceptions import ValidationError
@@ -15,7 +25,7 @@
import traceback
import logging
-logger = logging.getLogger('default')
+logger = logging.getLogger("default")
class UserSerializer(serializers.ModelSerializer):
@@ -35,20 +45,13 @@ def validate_password(self, password):
class Meta:
model = Users
fields = "__all__"
- extra_kwargs = {
- 'password': {
- 'write_only': True
- },
- 'display': {
- 'required': True
- }
- }
+ extra_kwargs = {"password": {"write_only": True}, "display": {"required": True}}
class UserDetailSerializer(serializers.ModelSerializer):
def update(self, instance, validated_data):
for attr, value in validated_data.items():
- if attr == 'password':
+ if attr == "password":
instance.set_password(value)
else:
setattr(instance, attr, value)
@@ -66,64 +69,58 @@ class Meta:
model = Users
fields = "__all__"
extra_kwargs = {
- 'password': {
- 'write_only': True,
- 'required': False
- },
- 'username': {
- 'required': False
- }
+ "password": {"write_only": True, "required": False},
+ "username": {"required": False},
}
class GroupSerializer(serializers.ModelSerializer):
-
class Meta:
model = Group
fields = "__all__"
class ResourceGroupSerializer(serializers.ModelSerializer):
-
class Meta:
model = ResourceGroup
fields = "__all__"
class UserAuthSerializer(serializers.Serializer):
- engineer = serializers.CharField(label='用户名')
- password = serializers.CharField(label='密码')
+ engineer = serializers.CharField(label="用户名")
+ password = serializers.CharField(label="密码")
class TwoFASerializer(serializers.Serializer):
- engineer = serializers.CharField(label='用户名')
- enable = serializers.ChoiceField(choices=['true', 'false'], label='启用or禁用')
- phone = serializers.CharField(required=False, label='手机号码')
- auth_type = serializers.ChoiceField(choices=['totp', 'sms'],
- label='验证类型:totp-Google身份验证器,sms-短信验证码')
+ engineer = serializers.CharField(label="用户名")
+ enable = serializers.ChoiceField(choices=["true", "false"], label="启用or禁用")
+ phone = serializers.CharField(required=False, label="手机号码")
+ auth_type = serializers.ChoiceField(
+ choices=["totp", "sms"], label="验证类型:totp-Google身份验证器,sms-短信验证码"
+ )
def validate(self, attrs):
- auth_type = attrs.get('auth_type')
- engineer = attrs.get('engineer')
- enable = attrs.get('enable')
+ auth_type = attrs.get("auth_type")
+ engineer = attrs.get("engineer")
+ enable = attrs.get("enable")
try:
Users.objects.get(username=engineer)
except Users.DoesNotExist:
raise serializers.ValidationError({"errors": "不存在该用户"})
- if auth_type == 'sms' and enable == 'true':
- if not attrs.get('phone'):
+ if auth_type == "sms" and enable == "true":
+ if not attrs.get("phone"):
raise serializers.ValidationError({"errors": "缺少 phone"})
return attrs
class TwoFAStateSerializer(serializers.Serializer):
- engineer = serializers.CharField(label='用户名')
+ engineer = serializers.CharField(label="用户名")
def validate(self, attrs):
- engineer = attrs.get('engineer')
+ engineer = attrs.get("engineer")
try:
Users.objects.get(username=engineer)
@@ -134,23 +131,25 @@ def validate(self, attrs):
class TwoFASaveSerializer(serializers.Serializer):
- engineer = serializers.CharField(label='用户名')
- key = serializers.CharField(required=False, label='密钥')
- phone = serializers.CharField(required=False, label='手机号码')
- auth_type = serializers.ChoiceField(choices=['disabled', 'totp', 'sms'],
- label='验证类型:disabled-关闭,totp-Google身份验证器,sms-短信验证码')
+ engineer = serializers.CharField(label="用户名")
+ key = serializers.CharField(required=False, label="密钥")
+ phone = serializers.CharField(required=False, label="手机号码")
+ auth_type = serializers.ChoiceField(
+ choices=["disabled", "totp", "sms"],
+ label="验证类型:disabled-关闭,totp-Google身份验证器,sms-短信验证码",
+ )
def validate(self, attrs):
- engineer = attrs.get('engineer')
- auth_type = attrs.get('auth_type')
- key = attrs.get('key')
- phone = attrs.get('phone')
+ engineer = attrs.get("engineer")
+ auth_type = attrs.get("auth_type")
+ key = attrs.get("key")
+ phone = attrs.get("phone")
- if auth_type == 'sms':
+ if auth_type == "sms":
if not phone:
raise serializers.ValidationError({"errors": "缺少 phone"})
- if auth_type == 'totp':
+ if auth_type == "totp":
if not key:
raise serializers.ValidationError({"errors": "缺少 key"})
@@ -163,18 +162,18 @@ def validate(self, attrs):
class TwoFAVerifySerializer(serializers.Serializer):
- engineer = serializers.CharField(label='用户名')
- otp = serializers.IntegerField(label='一次性密码/验证码')
- key = serializers.CharField(required=False, label='密钥')
- phone = serializers.CharField(required=False, label='手机号码')
- auth_type = serializers.CharField(label='验证方式')
+ engineer = serializers.CharField(label="用户名")
+ otp = serializers.IntegerField(label="一次性密码/验证码")
+ key = serializers.CharField(required=False, label="密钥")
+ phone = serializers.CharField(required=False, label="手机号码")
+ auth_type = serializers.CharField(label="验证方式")
def validate(self, attrs):
- engineer = attrs.get('engineer')
- auth_type = attrs.get('auth_type')
+ engineer = attrs.get("engineer")
+ auth_type = attrs.get("auth_type")
- if auth_type == 'sms':
- if not attrs.get('phone'):
+ if auth_type == "sms":
+ if not attrs.get("phone"):
raise serializers.ValidationError({"errors": "缺少 phone"})
try:
@@ -186,51 +185,33 @@ def validate(self, attrs):
class InstanceSerializer(serializers.ModelSerializer):
-
class Meta:
model = Instance
fields = "__all__"
- extra_kwargs = {
- 'password': {
- 'write_only': True
- }
- }
+ extra_kwargs = {"password": {"write_only": True}}
class InstanceDetailSerializer(serializers.ModelSerializer):
-
class Meta:
model = Instance
fields = "__all__"
extra_kwargs = {
- 'password': {
- 'write_only': True
- },
- 'instance_name': {
- 'required': False
- },
- 'type': {
- 'required': False
- },
- 'db_type': {
- 'required': False
- },
- 'host': {
- 'required': False
- }
+ "password": {"write_only": True},
+ "instance_name": {"required": False},
+ "type": {"required": False},
+ "db_type": {"required": False},
+ "host": {"required": False},
}
class TunnelSerializer(serializers.ModelSerializer):
-
class Meta:
model = Tunnel
fields = "__all__"
- write_only_fields = ['password', 'pkey', 'pkey_password']
+ write_only_fields = ["password", "pkey", "pkey_password"]
class CloudAccessKeySerializer(serializers.ModelSerializer):
-
class Meta:
model = CloudAccessKey
fields = "__all__"
@@ -241,7 +222,7 @@ class AliyunRdsSerializer(serializers.ModelSerializer):
def create(self, validated_data):
"""创建包含accesskey的aliyunrds实例"""
- rds_data = validated_data.pop('ak')
+ rds_data = validated_data.pop("ak")
try:
with transaction.atomic():
@@ -249,24 +230,26 @@ def create(self, validated_data):
rds = AliyunRdsConfig.objects.create(ak=ak, **validated_data)
except Exception as e:
logger.error(f"创建AliyunRds报错,错误信息:{traceback.format_exc()}")
- raise serializers.ValidationError({'errors': str(e)})
+ raise serializers.ValidationError({"errors": str(e)})
else:
return rds
class Meta:
model = AliyunRdsConfig
- fields = ('id', 'rds_dbinstanceid', 'is_enable', 'instance', 'ak')
+ fields = ("id", "rds_dbinstanceid", "is_enable", "instance", "ak")
class InstanceResourceSerializer(serializers.Serializer):
- instance_id = serializers.IntegerField(label='实例id')
- resource_type = serializers.ChoiceField(choices=['database', 'schema', 'table', 'column'], label='资源类型')
- db_name = serializers.CharField(required=False, label='数据库名')
- schema_name = serializers.CharField(required=False, label='schema名')
- tb_name = serializers.CharField(required=False, label='表名')
+ instance_id = serializers.IntegerField(label="实例id")
+ resource_type = serializers.ChoiceField(
+ choices=["database", "schema", "table", "column"], label="资源类型"
+ )
+ db_name = serializers.CharField(required=False, label="数据库名")
+ schema_name = serializers.CharField(required=False, label="schema名")
+ tb_name = serializers.CharField(required=False, label="表名")
def validate(self, attrs):
- instance_id = attrs.get('instance_id')
+ instance_id = attrs.get("instance_id")
try:
Instance.objects.get(id=instance_id)
@@ -282,9 +265,9 @@ class InstanceResourceListSerializer(serializers.Serializer):
class ExecuteCheckSerializer(serializers.Serializer):
- instance_id = serializers.IntegerField(label='实例id')
- db_name = serializers.CharField(label='数据库名')
- full_sql = serializers.CharField(label='SQL内容')
+ instance_id = serializers.IntegerField(label="实例id")
+ db_name = serializers.CharField(label="数据库名")
+ full_sql = serializers.CharField(label="SQL内容")
def validate_instance_id(self, instance_id):
try:
@@ -311,8 +294,8 @@ class ExecuteCheckResultSerializer(serializers.Serializer):
class WorkflowSerializer(serializers.ModelSerializer):
def validate(self, attrs):
- engineer = attrs.get('engineer')
- group_id = attrs.get('group_id')
+ engineer = attrs.get("engineer")
+ group_id = attrs.get("group_id")
try:
Users.objects.get(username=engineer)
@@ -329,13 +312,17 @@ def validate(self, attrs):
class Meta:
model = SqlWorkflow
fields = "__all__"
- read_only_fields = ['status', 'is_backup', 'syntax_type', 'audit_auth_groups', 'engineer_display',
- 'group_name', 'finish_time', 'is_manual']
- extra_kwargs = {
- 'demand_url': {
- 'required': False
- }
- }
+ read_only_fields = [
+ "status",
+ "is_backup",
+ "syntax_type",
+ "audit_auth_groups",
+ "engineer_display",
+ "group_name",
+ "finish_time",
+ "is_manual",
+ ]
+ extra_kwargs = {"demand_url": {"required": False}}
class WorkflowContentSerializer(serializers.ModelSerializer):
@@ -343,90 +330,119 @@ class WorkflowContentSerializer(serializers.ModelSerializer):
def create(self, validated_data):
"""使用原工单submit流程创建工单"""
- workflow_data = validated_data.pop('workflow')
- instance = workflow_data['instance']
- sql_content = validated_data['sql_content'].strip()
- user = Users.objects.get(username=workflow_data['engineer'])
- group = ResourceGroup.objects.get(pk=workflow_data['group_id'])
+ workflow_data = validated_data.pop("workflow")
+ instance = workflow_data["instance"]
+ sql_content = validated_data["sql_content"].strip()
+ user = Users.objects.get(username=workflow_data["engineer"])
+ group = ResourceGroup.objects.get(pk=workflow_data["group_id"])
active_user = Users.objects.filter(is_active=1)
# 验证组权限(用户是否在该组、该组是否有指定实例)
try:
- user_instances(user, tag_codes=['can_write']).get(id=instance.id)
+ user_instances(user, tag_codes=["can_write"]).get(id=instance.id)
except instance.DoesNotExist:
- raise serializers.ValidationError({'errors': '你所在组未关联该实例!'})
+ raise serializers.ValidationError({"errors": "你所在组未关联该实例!"})
# 再次交给engine进行检测,防止绕过
try:
check_engine = get_engine(instance=instance)
- check_result = check_engine.execute_check(db_name=workflow_data['db_name'],
- sql=sql_content)
+ check_result = check_engine.execute_check(
+ db_name=workflow_data["db_name"], sql=sql_content
+ )
except Exception as e:
- raise serializers.ValidationError({'errors': str(e)})
+ raise serializers.ValidationError({"errors": str(e)})
# 未开启备份选项,并且engine支持备份,强制设置备份
- is_backup = workflow_data['is_backup'] if 'is_backup' in workflow_data.keys() else False
+ is_backup = (
+ workflow_data["is_backup"] if "is_backup" in workflow_data.keys() else False
+ )
sys_config = SysConfig()
- if not sys_config.get('enable_backup_switch') and check_engine.auto_backup:
+ if not sys_config.get("enable_backup_switch") and check_engine.auto_backup:
is_backup = True
# 按照系统配置确定是自动驳回还是放行
- auto_review_wrong = sys_config.get('auto_review_wrong', '') # 1表示出现警告就驳回,2和空表示出现错误才驳回
- workflow_status = 'workflow_manreviewing'
- if check_result.warning_count > 0 and auto_review_wrong == '1':
- workflow_status = 'workflow_autoreviewwrong'
- elif check_result.error_count > 0 and auto_review_wrong in ('', '1', '2'):
- workflow_status = 'workflow_autoreviewwrong'
-
- workflow_data.update(status=workflow_status,
- is_backup=is_backup,
- is_manual=0,
- syntax_type=check_result.syntax_type,
- engineer_display=user.display,
- group_name=group.group_name,
- audit_auth_groups=Audit.settings(workflow_data['group_id'],
- WorkflowDict.workflow_type['sqlreview']))
+ auto_review_wrong = sys_config.get(
+ "auto_review_wrong", ""
+ ) # 1表示出现警告就驳回,2和空表示出现错误才驳回
+ workflow_status = "workflow_manreviewing"
+ if check_result.warning_count > 0 and auto_review_wrong == "1":
+ workflow_status = "workflow_autoreviewwrong"
+ elif check_result.error_count > 0 and auto_review_wrong in ("", "1", "2"):
+ workflow_status = "workflow_autoreviewwrong"
+
+ workflow_data.update(
+ status=workflow_status,
+ is_backup=is_backup,
+ is_manual=0,
+ syntax_type=check_result.syntax_type,
+ engineer_display=user.display,
+ group_name=group.group_name,
+ audit_auth_groups=Audit.settings(
+ workflow_data["group_id"], WorkflowDict.workflow_type["sqlreview"]
+ ),
+ )
try:
with transaction.atomic():
workflow = SqlWorkflow.objects.create(**workflow_data)
- validated_data['review_content'] = check_result.json()
- workflow_content = SqlWorkflowContent.objects.create(workflow=workflow, **validated_data)
+ validated_data["review_content"] = check_result.json()
+ workflow_content = SqlWorkflowContent.objects.create(
+ workflow=workflow, **validated_data
+ )
# 自动审核通过了,才调用工作流
- if workflow_status == 'workflow_manreviewing':
+ if workflow_status == "workflow_manreviewing":
# 调用工作流插入审核信息, SQL上线权限申请workflow_type=2
- Audit.add(WorkflowDict.workflow_type['sqlreview'], workflow.id)
+ Audit.add(WorkflowDict.workflow_type["sqlreview"], workflow.id)
except Exception as e:
logger.error(f"提交工单报错,错误信息:{traceback.format_exc()}")
- raise serializers.ValidationError({'errors': str(e)})
+ raise serializers.ValidationError({"errors": str(e)})
else:
# 自动审核通过且开启了Apply阶段通知参数才发送消息通知
- is_notified = 'Apply' in sys_config.get('notify_phase_control').split(',') \
- if sys_config.get('notify_phase_control') else True
- if workflow_status == 'workflow_manreviewing' and is_notified:
+ is_notified = (
+ "Apply" in sys_config.get("notify_phase_control").split(",")
+ if sys_config.get("notify_phase_control")
+ else True
+ )
+ if workflow_status == "workflow_manreviewing" and is_notified:
# 获取审核信息
- audit_id = Audit.detail_by_workflow_id(workflow_id=workflow.id,
- workflow_type=WorkflowDict.workflow_type['sqlreview']).audit_id
- async_task(notify_for_audit, audit_id=audit_id, cc_users=active_user, timeout=60,
- task_name=f'sqlreview-submit-{workflow.id}')
+ audit_id = Audit.detail_by_workflow_id(
+ workflow_id=workflow.id,
+ workflow_type=WorkflowDict.workflow_type["sqlreview"],
+ ).audit_id
+ async_task(
+ notify_for_audit,
+ audit_id=audit_id,
+ cc_users=active_user,
+ timeout=60,
+ task_name=f"sqlreview-submit-{workflow.id}",
+ )
return workflow_content
class Meta:
model = SqlWorkflowContent
- fields = ('id', 'workflow_id', 'workflow', 'sql_content', 'review_content', 'execute_result')
- read_only_fields = ['review_content', 'execute_result']
+ fields = (
+ "id",
+ "workflow_id",
+ "workflow",
+ "sql_content",
+ "review_content",
+ "execute_result",
+ )
+ read_only_fields = ["review_content", "execute_result"]
class AuditWorkflowSerializer(serializers.Serializer):
- engineer = serializers.CharField(label='操作用户')
- workflow_id = serializers.IntegerField(label='工单id')
- audit_remark = serializers.CharField(label='审批备注')
- workflow_type = serializers.ChoiceField(choices=[1, 2, 3], label='工单类型:1-查询权限申请,2-SQL上线申请,3-数据归档申请')
- audit_type = serializers.ChoiceField(choices=['pass', 'cancel'], label='审核类型')
+ engineer = serializers.CharField(label="操作用户")
+ workflow_id = serializers.IntegerField(label="工单id")
+ audit_remark = serializers.CharField(label="审批备注")
+ workflow_type = serializers.ChoiceField(
+ choices=[1, 2, 3], label="工单类型:1-查询权限申请,2-SQL上线申请,3-数据归档申请"
+ )
+ audit_type = serializers.ChoiceField(choices=["pass", "cancel"], label="审核类型")
def validate(self, attrs):
- engineer = attrs.get('engineer')
- workflow_id = attrs.get('workflow_id')
- workflow_type = attrs.get('workflow_type')
+ engineer = attrs.get("engineer")
+ workflow_id = attrs.get("workflow_id")
+ workflow_type = attrs.get("workflow_type")
try:
Users.objects.get(username=engineer)
@@ -434,7 +450,9 @@ def validate(self, attrs):
raise serializers.ValidationError({"errors": f"不存在该用户:{engineer}"})
try:
- WorkflowAudit.objects.get(workflow_id=workflow_id, workflow_type=workflow_type)
+ WorkflowAudit.objects.get(
+ workflow_id=workflow_id, workflow_type=workflow_type
+ )
except WorkflowAudit.DoesNotExist:
raise serializers.ValidationError({"errors": "不存在该工单"})
@@ -442,7 +460,7 @@ def validate(self, attrs):
class WorkflowAuditSerializer(serializers.Serializer):
- engineer = serializers.CharField(label='操作用户')
+ engineer = serializers.CharField(label="操作用户")
def validate_engineer(self, engineer):
try:
@@ -453,35 +471,51 @@ def validate_engineer(self, engineer):
class WorkflowAuditListSerializer(serializers.ModelSerializer):
-
class Meta:
model = WorkflowAudit
- exclude = ['group_id', 'workflow_id', 'workflow_remark', 'next_audit', 'create_user', 'sys_time']
+ exclude = [
+ "group_id",
+ "workflow_id",
+ "workflow_remark",
+ "next_audit",
+ "create_user",
+ "sys_time",
+ ]
class WorkflowLogSerializer(serializers.Serializer):
- workflow_id = serializers.IntegerField(label='工单id')
- workflow_type = serializers.ChoiceField(choices=[1, 2, 3], label='工单类型:1-查询权限申请,2-SQL上线申请,3-数据归档申请')
+ workflow_id = serializers.IntegerField(label="工单id")
+ workflow_type = serializers.ChoiceField(
+ choices=[1, 2, 3], label="工单类型:1-查询权限申请,2-SQL上线申请,3-数据归档申请"
+ )
class WorkflowLogListSerializer(serializers.ModelSerializer):
-
class Meta:
model = WorkflowLog
- fields = ['operation_type_desc', 'operation_info', 'operator_display', 'operation_time']
+ fields = [
+ "operation_type_desc",
+ "operation_info",
+ "operator_display",
+ "operation_time",
+ ]
class ExecuteWorkflowSerializer(serializers.Serializer):
- engineer = serializers.CharField(required=False, label='操作用户')
- workflow_id = serializers.IntegerField(label='工单id')
- workflow_type = serializers.ChoiceField(choices=[2, 3], label='工单类型:1-查询权限申请,2-SQL上线申请,3-数据归档申请')
- mode = serializers.ChoiceField(choices=['auto', 'manual'], label='执行模式:auto-线上执行,manual-已手动执行', required=False)
+ engineer = serializers.CharField(required=False, label="操作用户")
+ workflow_id = serializers.IntegerField(label="工单id")
+ workflow_type = serializers.ChoiceField(
+ choices=[2, 3], label="工单类型:1-查询权限申请,2-SQL上线申请,3-数据归档申请"
+ )
+ mode = serializers.ChoiceField(
+ choices=["auto", "manual"], label="执行模式:auto-线上执行,manual-已手动执行", required=False
+ )
def validate(self, attrs):
- engineer = attrs.get('engineer')
- workflow_id = attrs.get('workflow_id')
- workflow_type = attrs.get('workflow_type')
- mode = attrs.get('mode')
+ engineer = attrs.get("engineer")
+ workflow_id = attrs.get("workflow_id")
+ workflow_type = attrs.get("workflow_type")
+ mode = attrs.get("mode")
# SQL上线工单的mode和engineer为必需字段
if workflow_type == 2:
@@ -496,7 +530,9 @@ def validate(self, attrs):
raise serializers.ValidationError({"errors": f"不存在该用户:{engineer}"})
try:
- WorkflowAudit.objects.get(workflow_id=workflow_id, workflow_type=workflow_type)
+ WorkflowAudit.objects.get(
+ workflow_id=workflow_id, workflow_type=workflow_type
+ )
except WorkflowAudit.DoesNotExist:
raise serializers.ValidationError({"errors": "不存在该工单"})
diff --git a/sql_api/tests.py b/sql_api/tests.py
index 80b6ebe870..4c93e6c31a 100644
--- a/sql_api/tests.py
+++ b/sql_api/tests.py
@@ -5,9 +5,20 @@
from rest_framework.test import APITestCase
from rest_framework import status
from common.config import SysConfig
-from sql.models import ResourceGroup, Instance, AliyunRdsConfig, CloudAccessKey, Tunnel, \
- SqlWorkflow, SqlWorkflowContent, WorkflowAudit, WorkflowLog, InstanceTag, WorkflowAuditSetting, \
- TwoFactorAuthConfig
+from sql.models import (
+ ResourceGroup,
+ Instance,
+ AliyunRdsConfig,
+ CloudAccessKey,
+ Tunnel,
+ SqlWorkflow,
+ SqlWorkflowContent,
+ WorkflowAudit,
+ WorkflowLog,
+ InstanceTag,
+ WorkflowAuditSetting,
+ TwoFactorAuthConfig,
+)
import json
User = get_user_model()
@@ -15,36 +26,40 @@
class InfoTest(TestCase):
def setUp(self) -> None:
- self.superuser = User.objects.create(username='super', is_superuser=True)
+ self.superuser = User.objects.create(username="super", is_superuser=True)
self.client.force_login(self.superuser)
def tearDown(self) -> None:
self.superuser.delete()
def test_info_api(self):
- r = self.client.get('/api/info')
+ r = self.client.get("/api/info")
r_json = r.json()
- self.assertIsInstance(r_json['archery']['version'], str)
+ self.assertIsInstance(r_json["archery"]["version"], str)
def test_debug_api(self):
- r = self.client.get('/api/debug')
+ r = self.client.get("/api/debug")
r_json = r.json()
- self.assertIsInstance(r_json['archery']['version'], str)
+ self.assertIsInstance(r_json["archery"]["version"], str)
class TestUser(APITestCase):
"""测试用户相关接口"""
def setUp(self):
- self.user = User(username='test_user', display='测试用户', is_active=True)
- self.user.set_password('test_password')
+ self.user = User(username="test_user", display="测试用户", is_active=True)
+ self.user.set_password("test_password")
self.user.save()
- self.group = Group.objects.create(id=1, name='DBA')
- self.res_group = ResourceGroup.objects.create(group_id=1, group_name='test')
- r = self.client.post('/api/auth/token/', {'username': 'test_user', 'password': 'test_password'}, format='json')
- self.token = r.data['access']
- self.client.credentials(HTTP_AUTHORIZATION='Bearer ' + self.token)
- SysConfig().set('api_user_whitelist', self.user.id)
+ self.group = Group.objects.create(id=1, name="DBA")
+ self.res_group = ResourceGroup.objects.create(group_id=1, group_name="test")
+ r = self.client.post(
+ "/api/auth/token/",
+ {"username": "test_user", "password": "test_password"},
+ format="json",
+ )
+ self.token = r.data["access"]
+ self.client.credentials(HTTP_AUTHORIZATION="Bearer " + self.token)
+ SysConfig().set("api_user_whitelist", self.user.id)
def tearDown(self):
self.user.delete()
@@ -54,131 +69,124 @@ def tearDown(self):
def test_user_not_in_whitelist(self):
"""测试api用户白名单参数"""
- SysConfig().set('api_user_whitelist', '')
- r = self.client.get('/api/v1/user/', format='json')
+ SysConfig().set("api_user_whitelist", "")
+ r = self.client.get("/api/v1/user/", format="json")
self.assertEqual(r.status_code, status.HTTP_403_FORBIDDEN)
- self.assertDictEqual(r.json(), {'detail': '您没有执行该操作的权限。'})
+ self.assertDictEqual(r.json(), {"detail": "您没有执行该操作的权限。"})
def test_get_user_list(self):
"""测试获取用户清单"""
- r = self.client.get('/api/v1/user/', format='json')
+ r = self.client.get("/api/v1/user/", format="json")
self.assertEqual(r.status_code, status.HTTP_200_OK)
- self.assertEqual(r.json()['count'], 1)
+ self.assertEqual(r.json()["count"], 1)
def test_create_user(self):
"""测试创建用户"""
json_data = {
- 'username': 'test_user2',
- 'password': 'test_password2',
- 'display': '测试用户2'
+ "username": "test_user2",
+ "password": "test_password2",
+ "display": "测试用户2",
}
- r = self.client.post('/api/v1/user/', json_data, format='json')
+ r = self.client.post("/api/v1/user/", json_data, format="json")
self.assertEqual(r.status_code, status.HTTP_201_CREATED)
- self.assertEqual(r.json()['username'], 'test_user2')
+ self.assertEqual(r.json()["username"], "test_user2")
def test_update_user(self):
"""测试更新用户"""
- json_data = {
- 'display': '更新中文名'
- }
- r = self.client.put(f'/api/v1/user/{self.user.id}/', json_data, format='json')
+ json_data = {"display": "更新中文名"}
+ r = self.client.put(f"/api/v1/user/{self.user.id}/", json_data, format="json")
user = User.objects.get(pk=self.user.id)
self.assertEqual(r.status_code, status.HTTP_200_OK)
- self.assertEqual(user.display, '更新中文名')
+ self.assertEqual(user.display, "更新中文名")
def test_delete_user(self):
"""测试删除用户"""
json_data = {
- 'username': 'test_user2',
- 'password': 'test_password2',
- 'display': '测试用户2'
+ "username": "test_user2",
+ "password": "test_password2",
+ "display": "测试用户2",
}
- r1 = self.client.post('/api/v1/user/', json_data, format='json')
- r2 = self.client.delete(f'/api/v1/user/{r1.json()["id"]}/', format='json')
+ r1 = self.client.post("/api/v1/user/", json_data, format="json")
+ r2 = self.client.delete(f'/api/v1/user/{r1.json()["id"]}/', format="json")
self.assertEqual(r2.status_code, status.HTTP_204_NO_CONTENT)
- self.assertEqual(User.objects.filter(username='test_user2').count(), 0)
+ self.assertEqual(User.objects.filter(username="test_user2").count(), 0)
def test_get_user_group_list(self):
"""测试获取用户组清单"""
- r = self.client.get('/api/v1/user/group/', format='json')
+ r = self.client.get("/api/v1/user/group/", format="json")
self.assertEqual(r.status_code, status.HTTP_200_OK)
- self.assertEqual(r.json()['count'], 1)
+ self.assertEqual(r.json()["count"], 1)
def test_create_user_group(self):
"""测试创建用户组"""
- json_data = {
- 'name': 'RD'
- }
- r = self.client.post('/api/v1/user/group/', json_data, format='json')
+ json_data = {"name": "RD"}
+ r = self.client.post("/api/v1/user/group/", json_data, format="json")
self.assertEqual(r.status_code, status.HTTP_201_CREATED)
- self.assertEqual(r.json()['name'], 'RD')
+ self.assertEqual(r.json()["name"], "RD")
def test_update_user_group(self):
"""测试更新用户组"""
- json_data = {
- 'name': '更新用户组名称'
- }
- r = self.client.put(f'/api/v1/user/group/{self.group.id}/', json_data, format='json')
+ json_data = {"name": "更新用户组名称"}
+ r = self.client.put(
+ f"/api/v1/user/group/{self.group.id}/", json_data, format="json"
+ )
group = Group.objects.get(pk=self.group.id)
self.assertEqual(r.status_code, status.HTTP_200_OK)
- self.assertEqual(group.name, '更新用户组名称')
+ self.assertEqual(group.name, "更新用户组名称")
def test_delete_user_group(self):
"""测试删除用户组"""
- r = self.client.delete(f'/api/v1/user/group/{self.group.id}/', format='json')
+ r = self.client.delete(f"/api/v1/user/group/{self.group.id}/", format="json")
self.assertEqual(r.status_code, status.HTTP_204_NO_CONTENT)
- self.assertEqual(Group.objects.filter(name='DBA').count(), 0)
+ self.assertEqual(Group.objects.filter(name="DBA").count(), 0)
def test_get_resource_group_list(self):
"""测试获取资源组清单"""
- r = self.client.get('/api/v1/user/resourcegroup/', format='json')
+ r = self.client.get("/api/v1/user/resourcegroup/", format="json")
self.assertEqual(r.status_code, status.HTTP_200_OK)
- self.assertEqual(r.json()['count'], 1)
+ self.assertEqual(r.json()["count"], 1)
def test_create_resource_group(self):
"""测试创建资源组"""
json_data = {
- 'group_name': 'prod',
- 'ding_webhook': 'https://oapi.dingtalk.com/robot/send?access_token=123'
+ "group_name": "prod",
+ "ding_webhook": "https://oapi.dingtalk.com/robot/send?access_token=123",
}
- r = self.client.post('/api/v1/user/resourcegroup/', json_data, format='json')
+ r = self.client.post("/api/v1/user/resourcegroup/", json_data, format="json")
self.assertEqual(r.status_code, status.HTTP_201_CREATED)
- self.assertEqual(r.json()['group_name'], 'prod')
+ self.assertEqual(r.json()["group_name"], "prod")
def test_update_resource_group(self):
"""测试更新资源组"""
- json_data = {
- 'group_name': '更新资源组名称'
- }
- r = self.client.put(f'/api/v1/user/resourcegroup/{self.res_group.group_id}/', json_data, format='json')
+ json_data = {"group_name": "更新资源组名称"}
+ r = self.client.put(
+ f"/api/v1/user/resourcegroup/{self.res_group.group_id}/",
+ json_data,
+ format="json",
+ )
group = ResourceGroup.objects.get(pk=self.res_group.group_id)
self.assertEqual(r.status_code, status.HTTP_200_OK)
- self.assertEqual(group.group_name, '更新资源组名称')
+ self.assertEqual(group.group_name, "更新资源组名称")
def test_delete_resource_group(self):
"""测试删除资源组"""
- r = self.client.delete(f'/api/v1/user/resourcegroup/{self.res_group.group_id}/', format='json')
+ r = self.client.delete(
+ f"/api/v1/user/resourcegroup/{self.res_group.group_id}/", format="json"
+ )
self.assertEqual(r.status_code, status.HTTP_204_NO_CONTENT)
- self.assertEqual(Group.objects.filter(name='test').count(), 0)
+ self.assertEqual(Group.objects.filter(name="test").count(), 0)
def test_user_auth(self):
"""测试用户认证校验"""
- json_data = {
- "engineer": "test_user",
- "password": "test_password"
- }
- r = self.client.post(f'/api/v1/user/auth/', json_data, format='json')
+ json_data = {"engineer": "test_user", "password": "test_password"}
+ r = self.client.post(f"/api/v1/user/auth/", json_data, format="json")
self.assertEqual(r.status_code, status.HTTP_200_OK)
- self.assertEqual(r.json(), {'status': 0, 'msg': '认证成功'})
+ self.assertEqual(r.json(), {"status": 0, "msg": "认证成功"})
def test_2fa_config(self):
"""测试用户配置2fa"""
- json_data = {
- "engineer": "test_user",
- "auth_type": "totp",
- "enable": "false"
- }
- r = self.client.post(f'/api/v1/user/2fa/', json_data, format='json')
+ json_data = {"engineer": "test_user", "auth_type": "totp", "enable": "false"}
+ r = self.client.post(f"/api/v1/user/2fa/", json_data, format="json")
self.assertEqual(r.status_code, status.HTTP_200_OK)
self.assertEqual(TwoFactorAuthConfig.objects.count(), 0)
@@ -187,9 +195,9 @@ def test_2fa_save(self):
json_data = {
"engineer": "test_user",
"auth_type": "totp",
- "key": "ZUGRIJZP6H7LIOAL4LH5JA4GSXXT3WOK"
+ "key": "ZUGRIJZP6H7LIOAL4LH5JA4GSXXT3WOK",
}
- r = self.client.post(f'/api/v1/user/2fa/save/', json_data, format='json')
+ r = self.client.post(f"/api/v1/user/2fa/save/", json_data, format="json")
self.assertEqual(r.status_code, status.HTTP_200_OK)
self.assertEqual(TwoFactorAuthConfig.objects.count(), 1)
@@ -199,29 +207,46 @@ def test_2fa_verify(self):
"engineer": "test_user",
"otp": 123456,
"key": "ZUGRIJZP6H7LIOAL4LH5JA4GSXXT3WOK",
- "auth_type": "totp"
+ "auth_type": "totp",
}
- r = self.client.post(f'/api/v1/user/2fa/verify/', json_data, format='json')
+ r = self.client.post(f"/api/v1/user/2fa/verify/", json_data, format="json")
self.assertEqual(r.status_code, status.HTTP_200_OK)
- self.assertEqual(r.json()['status'], 1)
+ self.assertEqual(r.json()["status"], 1)
class TestInstance(APITestCase):
"""测试实例相关接口"""
def setUp(self):
- self.user = User(username='test_user', display='测试用户', is_active=True)
- self.user.set_password('test_password')
+ self.user = User(username="test_user", display="测试用户", is_active=True)
+ self.user.set_password("test_password")
self.user.save()
- self.ins = Instance.objects.create(instance_name='some_ins', type='slave', db_type='mysql',
- host='some_host', port=3306, user='ins_user', password='some_str')
- self.ak = CloudAccessKey.objects.create(type='aliyun', key_id='abc', key_secret='abc')
- self.rds = AliyunRdsConfig.objects.create(rds_dbinstanceid='abc', ak_id=self.ak.id, instance=self.ins)
- self.tunnel = Tunnel.objects.create(tunnel_name='one_tunnel', host='one_host', port=22)
- r = self.client.post('/api/auth/token/', {'username': 'test_user', 'password': 'test_password'}, format='json')
- self.token = r.data['access']
- self.client.credentials(HTTP_AUTHORIZATION='Bearer ' + self.token)
- SysConfig().set('api_user_whitelist', self.user.id)
+ self.ins = Instance.objects.create(
+ instance_name="some_ins",
+ type="slave",
+ db_type="mysql",
+ host="some_host",
+ port=3306,
+ user="ins_user",
+ password="some_str",
+ )
+ self.ak = CloudAccessKey.objects.create(
+ type="aliyun", key_id="abc", key_secret="abc"
+ )
+ self.rds = AliyunRdsConfig.objects.create(
+ rds_dbinstanceid="abc", ak_id=self.ak.id, instance=self.ins
+ )
+ self.tunnel = Tunnel.objects.create(
+ tunnel_name="one_tunnel", host="one_host", port=22
+ )
+ r = self.client.post(
+ "/api/auth/token/",
+ {"username": "test_user", "password": "test_password"},
+ format="json",
+ )
+ self.token = r.data["access"]
+ self.client.credentials(HTTP_AUTHORIZATION="Bearer " + self.token)
+ SysConfig().set("api_user_whitelist", self.user.id)
def tearDown(self):
self.user.delete()
@@ -233,49 +258,54 @@ def tearDown(self):
def test_get_instance_list(self):
"""测试获取实例清单"""
- r = self.client.get('/api/v1/instance/', format='json')
+ r = self.client.get("/api/v1/instance/", format="json")
self.assertEqual(r.status_code, status.HTTP_200_OK)
- self.assertEqual(r.json()['count'], 1)
+ self.assertEqual(r.json()["count"], 1)
def test_create_instance(self):
"""测试创建实例"""
json_data = {
- 'instance_name': 'test_ins',
- 'type': 'master',
- 'db_type': 'mysql',
- 'host': 'some_host',
- 'port': 3306
+ "instance_name": "test_ins",
+ "type": "master",
+ "db_type": "mysql",
+ "host": "some_host",
+ "port": 3306,
}
- r = self.client.post('/api/v1/instance/', json_data, format='json')
+ r = self.client.post("/api/v1/instance/", json_data, format="json")
self.assertEqual(r.status_code, status.HTTP_201_CREATED)
- self.assertEqual(r.json()['instance_name'], 'test_ins')
+ self.assertEqual(r.json()["instance_name"], "test_ins")
def test_update_instance(self):
"""测试更新实例"""
- json_data = {
- 'instance_name': '更新实例名称'
- }
- r = self.client.put(f'/api/v1/instance/{self.ins.id}/', json_data, format='json')
+ json_data = {"instance_name": "更新实例名称"}
+ r = self.client.put(
+ f"/api/v1/instance/{self.ins.id}/", json_data, format="json"
+ )
ins = Instance.objects.get(pk=self.ins.id)
self.assertEqual(r.status_code, status.HTTP_200_OK)
- self.assertEqual(ins.instance_name, '更新实例名称')
+ self.assertEqual(ins.instance_name, "更新实例名称")
def test_delete_instance(self):
"""测试删除实例"""
- r = self.client.delete(f'/api/v1/instance/{self.ins.id}/', format='json')
+ r = self.client.delete(f"/api/v1/instance/{self.ins.id}/", format="json")
self.assertEqual(r.status_code, status.HTTP_204_NO_CONTENT)
- self.assertEqual(Instance.objects.filter(instance_name='some_ins').count(), 0)
+ self.assertEqual(Instance.objects.filter(instance_name="some_ins").count(), 0)
def test_get_aliyunrds_list(self):
"""测试获取aliyunrds清单"""
- r = self.client.get('/api/v1/instance/rds/', format='json')
+ r = self.client.get("/api/v1/instance/rds/", format="json")
self.assertEqual(r.status_code, status.HTTP_200_OK)
- self.assertEqual(r.json()['count'], 1)
+ self.assertEqual(r.json()["count"], 1)
def test_create_aliyunrds(self):
"""测试创建aliyunrds"""
- ins = Instance.objects.create(instance_name='another_ins', type='slave', db_type='mysql',
- host='another_host', port=3306)
+ ins = Instance.objects.create(
+ instance_name="another_ins",
+ type="slave",
+ db_type="mysql",
+ host="another_host",
+ port=3306,
+ )
json_data = {
"rds_dbinstanceid": "bbc",
"is_enable": True,
@@ -284,29 +314,25 @@ def test_create_aliyunrds(self):
"type": "aliyun",
"key_id": "bbc",
"key_secret": "bbc",
- "remark": "bbc"
- }
+ "remark": "bbc",
+ },
}
- r = self.client.post('/api/v1/instance/rds/', json_data, format='json')
+ r = self.client.post("/api/v1/instance/rds/", json_data, format="json")
self.assertEqual(r.status_code, status.HTTP_201_CREATED)
- self.assertEqual(r.json()['rds_dbinstanceid'], 'bbc')
+ self.assertEqual(r.json()["rds_dbinstanceid"], "bbc")
def test_get_tunnel_list(self):
"""测试获取隧道清单"""
- r = self.client.get('/api/v1/instance/tunnel/', format='json')
+ r = self.client.get("/api/v1/instance/tunnel/", format="json")
self.assertEqual(r.status_code, status.HTTP_200_OK)
- self.assertEqual(r.json()['count'], 1)
+ self.assertEqual(r.json()["count"], 1)
def test_create_tunnel(self):
"""测试创建隧道"""
- json_data = {
- "tunnel_name": "tunnel_test",
- "host": "one_host",
- "port": 22
- }
- r = self.client.post('/api/v1/instance/tunnel/', json_data, format='json')
+ json_data = {"tunnel_name": "tunnel_test", "host": "one_host", "port": 22}
+ r = self.client.post("/api/v1/instance/tunnel/", json_data, format="json")
self.assertEqual(r.status_code, status.HTTP_201_CREATED)
- self.assertEqual(r.json()['tunnel_name'], 'tunnel_test')
+ self.assertEqual(r.json()["tunnel_name"], "tunnel_test")
class TestWorkflow(APITestCase):
@@ -314,63 +340,82 @@ class TestWorkflow(APITestCase):
def setUp(self):
self.now = datetime.now()
- self.group = Group.objects.create(id=1, name='DBA')
- self.res_group = ResourceGroup.objects.create(group_id=1, group_name='test')
- self.ins_tag = InstanceTag.objects.create(tag_code='can_write', active=1)
- self.wfs = WorkflowAuditSetting.objects.create(group_id=self.res_group.group_id,
- workflow_type=2, audit_auth_groups=self.group.id)
- can_execute_permission = Permission.objects.get(codename='sql_execute')
- can_execute_resource_permission = Permission.objects.get(codename='sql_execute_for_resource_group')
- can_review_permission = Permission.objects.get(codename='sql_review')
- self.user = User(username='test_user', display='测试用户', is_active=True)
- self.user.set_password('test_password')
+ self.group = Group.objects.create(id=1, name="DBA")
+ self.res_group = ResourceGroup.objects.create(group_id=1, group_name="test")
+ self.ins_tag = InstanceTag.objects.create(tag_code="can_write", active=1)
+ self.wfs = WorkflowAuditSetting.objects.create(
+ group_id=self.res_group.group_id,
+ workflow_type=2,
+ audit_auth_groups=self.group.id,
+ )
+ can_execute_permission = Permission.objects.get(codename="sql_execute")
+ can_execute_resource_permission = Permission.objects.get(
+ codename="sql_execute_for_resource_group"
+ )
+ can_review_permission = Permission.objects.get(codename="sql_review")
+ self.user = User(username="test_user", display="测试用户", is_active=True)
+ self.user.set_password("test_password")
self.user.save()
- self.user.user_permissions.add(can_execute_permission, can_execute_resource_permission, can_review_permission)
+ self.user.user_permissions.add(
+ can_execute_permission,
+ can_execute_resource_permission,
+ can_review_permission,
+ )
self.user.groups.add(self.group.id)
self.user.resource_group.add(self.res_group.group_id)
- self.ins = Instance.objects.create(instance_name='some_ins', type='slave', db_type='redis',
- host='some_host', port=6379, user='ins_user', password='some_str')
+ self.ins = Instance.objects.create(
+ instance_name="some_ins",
+ type="slave",
+ db_type="redis",
+ host="some_host",
+ port=6379,
+ user="ins_user",
+ password="some_str",
+ )
self.ins.resource_group.add(self.res_group.group_id)
self.ins.instance_tag.add(self.ins_tag.id)
self.wf1 = SqlWorkflow.objects.create(
- workflow_name='some_name',
+ workflow_name="some_name",
group_id=1,
- group_name='g1',
+ group_name="g1",
engineer=self.user.username,
engineer_display=self.user.display,
- audit_auth_groups='1',
+ audit_auth_groups="1",
create_time=self.now - timedelta(days=1),
- status='workflow_manreviewing',
+ status="workflow_manreviewing",
is_backup=False,
instance=self.ins,
- db_name='some_db',
+ db_name="some_db",
syntax_type=1,
)
self.wfc1 = SqlWorkflowContent.objects.create(
workflow=self.wf1,
- sql_content='some_sql',
- execute_result=json.dumps([{
- 'id': 1,
- 'sql': 'some_content'
- }])
+ sql_content="some_sql",
+ execute_result=json.dumps([{"id": 1, "sql": "some_content"}]),
)
self.audit1 = WorkflowAudit.objects.create(
group_id=1,
- group_name='some_group',
+ group_name="some_group",
workflow_id=self.wf1.id,
workflow_type=2,
- workflow_title='申请标题',
- workflow_remark='申请备注',
- audit_auth_groups='1',
- current_audit='1',
- next_audit='-1',
- current_status=0)
- self.wl = WorkflowLog.objects.create(audit_id=self.audit1.audit_id,
- operation_type=1)
- r = self.client.post('/api/auth/token/', {'username': 'test_user', 'password': 'test_password'}, format='json')
- self.token = r.data['access']
- self.client.credentials(HTTP_AUTHORIZATION='Bearer ' + self.token)
- SysConfig().set('api_user_whitelist', self.user.id)
+ workflow_title="申请标题",
+ workflow_remark="申请备注",
+ audit_auth_groups="1",
+ current_audit="1",
+ next_audit="-1",
+ current_status=0,
+ )
+ self.wl = WorkflowLog.objects.create(
+ audit_id=self.audit1.audit_id, operation_type=1
+ )
+ r = self.client.post(
+ "/api/auth/token/",
+ {"username": "test_user", "password": "test_password"},
+ format="json",
+ )
+ self.token = r.data["access"]
+ self.client.credentials(HTTP_AUTHORIZATION="Bearer " + self.token)
+ SysConfig().set("api_user_whitelist", self.user.id)
def tearDown(self):
self.user.delete()
@@ -383,45 +428,43 @@ def tearDown(self):
def test_get_sql_workflow_list(self):
"""测试获取SQL上线工单列表"""
- r = self.client.get('/api/v1/workflow/', format='json')
+ r = self.client.get("/api/v1/workflow/", format="json")
self.assertEqual(r.status_code, status.HTTP_200_OK)
- self.assertEqual(r.json()['count'], 1)
+ self.assertEqual(r.json()["count"], 1)
def test_get_audit_list(self):
"""测试获取待审核工单列表"""
- json_data = {
- "engineer": "test_user"
- }
- r = self.client.post('/api/v1/workflow/auditlist/', json_data, format='json')
+ json_data = {"engineer": "test_user"}
+ r = self.client.post("/api/v1/workflow/auditlist/", json_data, format="json")
self.assertEqual(r.status_code, status.HTTP_200_OK)
- self.assertEqual(r.json()['count'], 1)
+ self.assertEqual(r.json()["count"], 1)
def test_get_workflow_log_list(self):
"""测试获工单日志"""
json_data = {
"workflow_id": self.wf1.id,
- "workflow_type": self.audit1.workflow_type
+ "workflow_type": self.audit1.workflow_type,
}
- r = self.client.post('/api/v1/workflow/log/', json_data, format='json')
+ r = self.client.post("/api/v1/workflow/log/", json_data, format="json")
self.assertEqual(r.status_code, status.HTTP_200_OK)
- self.assertEqual(r.json()['count'], 1)
+ self.assertEqual(r.json()["count"], 1)
def test_submit_workflow(self):
"""测试提交SQL上线工单"""
json_data = {
"workflow": {
- "workflow_name": "上线工单1",
- "demand_url": "test",
- "group_id": 1,
- "db_name": "test_db",
- "engineer": self.user.username,
- "instance": self.ins.id
+ "workflow_name": "上线工单1",
+ "demand_url": "test",
+ "group_id": 1,
+ "db_name": "test_db",
+ "engineer": self.user.username,
+ "instance": self.ins.id,
},
- "sql_content": "alter table abc add column note varchar(64);"
+ "sql_content": "alter table abc add column note varchar(64);",
}
- r = self.client.post('/api/v1/workflow/', json_data, format='json')
+ r = self.client.post("/api/v1/workflow/", json_data, format="json")
self.assertEqual(r.status_code, status.HTTP_201_CREATED)
- self.assertEqual(r.json()['workflow']['workflow_name'], '上线工单1')
+ self.assertEqual(r.json()["workflow"]["workflow_name"], "上线工单1")
def test_audit_workflow(self):
"""测试审核工单"""
@@ -430,11 +473,11 @@ def test_audit_workflow(self):
"workflow_id": self.wf1.id,
"audit_remark": "取消",
"workflow_type": self.audit1.workflow_type,
- "audit_type": "cancel"
+ "audit_type": "cancel",
}
- r = self.client.post('/api/v1/workflow/audit/', json_data, format='json')
+ r = self.client.post("/api/v1/workflow/audit/", json_data, format="json")
self.assertEqual(r.status_code, status.HTTP_200_OK)
- self.assertEqual(r.json(), {'msg': 'canceled'})
+ self.assertEqual(r.json(), {"msg": "canceled"})
def test_execute_workflow(self):
"""测试执行工单"""
@@ -444,16 +487,16 @@ def test_execute_workflow(self):
"workflow_id": self.wf1.id,
"audit_remark": "通过",
"workflow_type": self.audit1.workflow_type,
- "audit_type": "pass"
+ "audit_type": "pass",
}
- self.client.post('/api/v1/workflow/audit/', audit_data, format='json')
+ self.client.post("/api/v1/workflow/audit/", audit_data, format="json")
# 再执行
execute_data = {
"engineer": self.user.username,
"workflow_id": self.wf1.id,
"workflow_type": self.audit1.workflow_type,
- "mode": "manual"
+ "mode": "manual",
}
- r = self.client.post('/api/v1/workflow/execute/', execute_data, format='json')
+ r = self.client.post("/api/v1/workflow/execute/", execute_data, format="json")
self.assertEqual(r.status_code, status.HTTP_200_OK)
- self.assertEqual(r.json(), {'msg': '开始执行,执行结果请到工单详情页查看'})
+ self.assertEqual(r.json(), {"msg": "开始执行,执行结果请到工单详情页查看"})
diff --git a/sql_api/urls.py b/sql_api/urls.py
index a4c064b33b..1c098e6b64 100644
--- a/sql_api/urls.py
+++ b/sql_api/urls.py
@@ -1,43 +1,57 @@
from django.urls import path, include
from sql_api import views
from rest_framework import routers
-from rest_framework_simplejwt.views import TokenObtainPairView, TokenRefreshView, TokenVerifyView
-from drf_spectacular.views import SpectacularAPIView, SpectacularRedocView, SpectacularSwaggerView
+from rest_framework_simplejwt.views import (
+ TokenObtainPairView,
+ TokenRefreshView,
+ TokenVerifyView,
+)
+from drf_spectacular.views import (
+ SpectacularAPIView,
+ SpectacularRedocView,
+ SpectacularSwaggerView,
+)
from . import api_user, api_instance, api_workflow
router = routers.DefaultRouter()
urlpatterns = [
- path('v1/', include(router.urls)),
- path('auth/token/', TokenObtainPairView.as_view(), name='token_obtain_pair'),
- path('auth/token/refresh/', TokenRefreshView.as_view(), name='token_refresh'),
- path('auth/token/verify/', TokenVerifyView.as_view(), name='token_verify'),
- path('schema/', SpectacularAPIView.as_view(), name='schema'),
- path('swagger/', SpectacularSwaggerView.as_view(url_name='sql_api:schema'), name='swagger'),
- path('redoc/', SpectacularRedocView.as_view(url_name='sql_api:schema'), name='redoc'),
- path('v1/user/', api_user.UserList.as_view()),
- path('v1/user/<int:pk>/', api_user.UserDetail.as_view()),
- path('v1/user/group/', api_user.GroupList.as_view()),
- path('v1/user/group/<int:pk>/', api_user.GroupDetail.as_view()),
- path('v1/user/resourcegroup/', api_user.ResourceGroupList.as_view()),
- path('v1/user/resourcegroup/<int:pk>/', api_user.ResourceGroupDetail.as_view()),
- path('v1/user/auth/', api_user.UserAuth.as_view()),
- path('v1/user/2fa/', api_user.TwoFA.as_view()),
- path('v1/user/2fa/state/', api_user.TwoFAState.as_view()),
- path('v1/user/2fa/save/', api_user.TwoFASave.as_view()),
- path('v1/user/2fa/verify/', api_user.TwoFAVerify.as_view()),
- path('v1/instance/', api_instance.InstanceList.as_view()),
- path('v1/instance/<int:pk>/', api_instance.InstanceDetail.as_view()),
- path('v1/instance/resource/', api_instance.InstanceResource.as_view()),
- path('v1/instance/tunnel/', api_instance.TunnelList.as_view()),
- path('v1/instance/rds/', api_instance.AliyunRdsList.as_view()),
- path('v1/workflow/', api_workflow.WorkflowList.as_view()),
- path('v1/workflow/sqlcheck/', api_workflow.ExecuteCheck.as_view()),
- path('v1/workflow/audit/', api_workflow.AuditWorkflow.as_view()),
- path('v1/workflow/auditlist/', api_workflow.WorkflowAuditList.as_view()),
- path('v1/workflow/execute/', api_workflow.ExecuteWorkflow.as_view()),
- path('v1/workflow/log/', api_workflow.WorkflowLogList.as_view()),
- path('info', views.info),
- path('debug', views.debug),
- path('do_once/mirage', views.mirage)
+ path("v1/", include(router.urls)),
+ path("auth/token/", TokenObtainPairView.as_view(), name="token_obtain_pair"),
+ path("auth/token/refresh/", TokenRefreshView.as_view(), name="token_refresh"),
+ path("auth/token/verify/", TokenVerifyView.as_view(), name="token_verify"),
+ path("schema/", SpectacularAPIView.as_view(), name="schema"),
+ path(
+ "swagger/",
+ SpectacularSwaggerView.as_view(url_name="sql_api:schema"),
+ name="swagger",
+ ),
+ path(
+ "redoc/", SpectacularRedocView.as_view(url_name="sql_api:schema"), name="redoc"
+ ),
+ path("v1/user/", api_user.UserList.as_view()),
+ path("v1/user/<int:pk>/", api_user.UserDetail.as_view()),
+ path("v1/user/group/", api_user.GroupList.as_view()),
+ path("v1/user/group/<int:pk>/", api_user.GroupDetail.as_view()),
+ path("v1/user/resourcegroup/", api_user.ResourceGroupList.as_view()),
+ path("v1/user/resourcegroup/<int:pk>/", api_user.ResourceGroupDetail.as_view()),
+ path("v1/user/auth/", api_user.UserAuth.as_view()),
+ path("v1/user/2fa/", api_user.TwoFA.as_view()),
+ path("v1/user/2fa/state/", api_user.TwoFAState.as_view()),
+ path("v1/user/2fa/save/", api_user.TwoFASave.as_view()),
+ path("v1/user/2fa/verify/", api_user.TwoFAVerify.as_view()),
+ path("v1/instance/", api_instance.InstanceList.as_view()),
+ path("v1/instance/<int:pk>/", api_instance.InstanceDetail.as_view()),
+ path("v1/instance/resource/", api_instance.InstanceResource.as_view()),
+ path("v1/instance/tunnel/", api_instance.TunnelList.as_view()),
+ path("v1/instance/rds/", api_instance.AliyunRdsList.as_view()),
+ path("v1/workflow/", api_workflow.WorkflowList.as_view()),
+ path("v1/workflow/sqlcheck/", api_workflow.ExecuteCheck.as_view()),
+ path("v1/workflow/audit/", api_workflow.AuditWorkflow.as_view()),
+ path("v1/workflow/auditlist/", api_workflow.WorkflowAuditList.as_view()),
+ path("v1/workflow/execute/", api_workflow.ExecuteWorkflow.as_view()),
+ path("v1/workflow/log/", api_workflow.WorkflowLogList.as_view()),
+ path("info", views.info),
+ path("debug", views.debug),
+ path("do_once/mirage", views.mirage),
]
diff --git a/sql_api/views.py b/sql_api/views.py
index ec8e17e43a..5e470f5bc6 100644
--- a/sql_api/views.py
+++ b/sql_api/views.py
@@ -23,15 +23,13 @@
def info(request):
# 获取django_q信息
- django_q_version = '.'.join(str(i) for i in django_q.VERSION)
+ django_q_version = ".".join(str(i) for i in django_q.VERSION)
system_info = {
- 'archery': {
- 'version': archery.display_version
+ "archery": {"version": archery.display_version},
+ "django_q": {
+ "version": django_q_version,
},
- 'django_q': {
- 'version': django_q_version,
- }
}
return JsonResponse(system_info)
@@ -39,24 +37,24 @@ def info(request):
@superuser_required
def debug(request):
# 获取完整信息
- full = request.GET.get('full')
+ full = request.GET.get("full")
# 系统配置
sys_config = SysConfig().sys_config
# 敏感信息处理
secret_keys = [
- 'inception_remote_backup_password',
- 'ding_app_secret',
- 'feishu_app_secret',
- 'mail_smtp_password'
+ "inception_remote_backup_password",
+ "ding_app_secret",
+ "feishu_app_secret",
+ "mail_smtp_password",
]
sys_config.update({k: "******" for k in secret_keys})
# MySQL信息
cursor = connection.cursor()
mysql_info = {
- 'mysql_server_info': cursor.db.mysql_server_info,
- 'timezone_name': cursor.db.timezone_name
+ "mysql_server_info": cursor.db.mysql_server_info,
+ "timezone_name": cursor.db.timezone_name,
}
# Redis信息
@@ -64,30 +62,30 @@ def debug(request):
redis_conn = get_redis_connection("default")
full_redis_info = redis_conn.info()
redis_info = {
- 'redis_version': full_redis_info.get('redis_version'),
- 'redis_mode': full_redis_info.get('redis_mode'),
- 'role': full_redis_info.get('role'),
- 'maxmemory_human': full_redis_info.get('maxmemory_human'),
- 'used_memory_human': full_redis_info.get('used_memory_human'),
+ "redis_version": full_redis_info.get("redis_version"),
+ "redis_mode": full_redis_info.get("redis_mode"),
+ "role": full_redis_info.get("role"),
+ "maxmemory_human": full_redis_info.get("maxmemory_human"),
+ "used_memory_human": full_redis_info.get("used_memory_human"),
}
except Exception as e:
- redis_info = f'获取Redis信息报错:{e}'
+ redis_info = f"获取Redis信息报错:{e}"
full_redis_info = redis_info
# django_q
try:
- django_q_version = '.'.join(str(i) for i in django_q.VERSION)
+ django_q_version = ".".join(str(i) for i in django_q.VERSION)
broker = get_broker()
stats = Stat.get_all(broker=broker)
queue_size = broker.queue_size()
lock_size = broker.lock_size()
if lock_size:
- queue_size = '{}({})'.format(queue_size, lock_size)
+ queue_size = "{}({})".format(queue_size, lock_size)
q_broker_stats = {
- 'info': broker.info(),
- 'Queued': queue_size,
- 'Success': Success.objects.count(),
- 'Failures': Failure.objects.count(),
+ "info": broker.info(),
+ "Queued": queue_size,
+ "Success": Success.objects.count(),
+ "Failures": Failure.objects.count(),
}
q_cluster_stats = []
for stat in stats:
@@ -95,93 +93,104 @@ def debug(request):
uptime = (timezone.now() - stat.tob).total_seconds()
hours, remainder = divmod(uptime, 3600)
minutes, seconds = divmod(remainder, 60)
- uptime = '%d:%02d:%02d' % (hours, minutes, seconds)
- q_cluster_stats.append({
- 'host': stat.host,
- 'cluster_id': stat.cluster_id,
- 'state': stat.status,
- 'pool': len(stat.workers),
- 'tq': stat.task_q_size,
- 'rq': stat.done_q_size,
- 'rc': stat.reincarnations,
- 'up': uptime
- })
+ uptime = "%d:%02d:%02d" % (hours, minutes, seconds)
+ q_cluster_stats.append(
+ {
+ "host": stat.host,
+ "cluster_id": stat.cluster_id,
+ "state": stat.status,
+ "pool": len(stat.workers),
+ "tq": stat.task_q_size,
+ "rq": stat.done_q_size,
+ "rc": stat.reincarnations,
+ "up": uptime,
+ }
+ )
django_q_info = {
- 'version': django_q_version,
- 'conf': django_q.conf.Conf.conf,
- 'q_cluster_stats': q_cluster_stats if q_cluster_stats else '没有正在运行的集群信息,请检查django_q状态',
- 'q_broker_stats': q_broker_stats
+ "version": django_q_version,
+ "conf": django_q.conf.Conf.conf,
+ "q_cluster_stats": q_cluster_stats
+ if q_cluster_stats
+ else "没有正在运行的集群信息,请检查django_q状态",
+ "q_broker_stats": q_broker_stats,
}
except Exception as e:
- django_q_info = f'获取django_q信息报错:{e}'
+ django_q_info = f"获取django_q信息报错:{e}"
# Inception和goInception信息
- go_inception_host = sys_config.get('go_inception_host')
- go_inception_port = sys_config.get('go_inception_port', 0)
- inception_remote_backup_host = sys_config.get('inception_remote_backup_host', '')
- inception_remote_backup_port = sys_config.get('inception_remote_backup_port', '')
- inception_remote_backup_user = sys_config.get('inception_remote_backup_user', '')
- inception_remote_backup_password = sys_config.get('inception_remote_backup_password', '')
+ go_inception_host = sys_config.get("go_inception_host")
+ go_inception_port = sys_config.get("go_inception_port", 0)
+ inception_remote_backup_host = sys_config.get("inception_remote_backup_host", "")
+ inception_remote_backup_port = sys_config.get("inception_remote_backup_port", "")
+ inception_remote_backup_user = sys_config.get("inception_remote_backup_user", "")
+ inception_remote_backup_password = sys_config.get(
+ "inception_remote_backup_password", ""
+ )
# goInception
try:
- goinc_conn = MySQLdb.connect(host=go_inception_host, port=int(go_inception_port),
- connect_timeout=1, cursorclass=MySQLdb.cursors.DictCursor)
+ goinc_conn = MySQLdb.connect(
+ host=go_inception_host,
+ port=int(go_inception_port),
+ connect_timeout=1,
+ cursorclass=MySQLdb.cursors.DictCursor,
+ )
cursor = goinc_conn.cursor()
- cursor.execute('inception get variables')
+ cursor.execute("inception get variables")
rows = cursor.fetchall()
full_goinception_info = dict()
for row in rows:
- full_goinception_info[row.get('Variable_name')] = row.get('Value')
+ full_goinception_info[row.get("Variable_name")] = row.get("Value")
goinception_info = {
- 'version': full_goinception_info.get('version'),
- 'max_allowed_packet': full_goinception_info.get('max_allowed_packet'),
- 'lang': full_goinception_info.get('lang'),
- 'osc_on': full_goinception_info.get('osc_on'),
- 'osc_bin_dir': full_goinception_info.get('osc_bin_dir'),
- 'ghost_on': full_goinception_info.get('ghost_on'),
+ "version": full_goinception_info.get("version"),
+ "max_allowed_packet": full_goinception_info.get("max_allowed_packet"),
+ "lang": full_goinception_info.get("lang"),
+ "osc_on": full_goinception_info.get("osc_on"),
+ "osc_bin_dir": full_goinception_info.get("osc_bin_dir"),
+ "ghost_on": full_goinception_info.get("ghost_on"),
}
except Exception as e:
- goinception_info = f'获取goInception信息报错:{e}'
+ goinception_info = f"获取goInception信息报错:{e}"
full_goinception_info = goinception_info
# 备份库
try:
- bak_conn = MySQLdb.connect(host=inception_remote_backup_host,
- port=int(inception_remote_backup_port),
- user=inception_remote_backup_user,
- password=inception_remote_backup_password,
- connect_timeout=1)
+ bak_conn = MySQLdb.connect(
+ host=inception_remote_backup_host,
+ port=int(inception_remote_backup_port),
+ user=inception_remote_backup_user,
+ password=inception_remote_backup_password,
+ connect_timeout=1,
+ )
cursor = bak_conn.cursor()
- cursor.execute('select 1;')
- backup_info = 'normal'
+ cursor.execute("select 1;")
+ backup_info = "normal"
except Exception as e:
- backup_info = f'无法连接goInception备份库\n{e}'
+ backup_info = f"无法连接goInception备份库\n{e}"
# PACKAGES
installed_packages = pkg_resources.working_set
- installed_packages_list = sorted([
- "%s==%s" % (i.key, i.version) for i in installed_packages])
+ installed_packages_list = sorted(
+ ["%s==%s" % (i.key, i.version) for i in installed_packages]
+ )
# 最终集合
system_info = {
- 'archery': {
- 'version': archery.display_version
- },
- 'django_q': django_q_info,
- 'inception': {
- 'goinception_info': full_goinception_info if full else goinception_info,
- 'backup_info': backup_info
+ "archery": {"version": archery.display_version},
+ "django_q": django_q_info,
+ "inception": {
+ "goinception_info": full_goinception_info if full else goinception_info,
+ "backup_info": backup_info,
},
- 'runtime_info': {
- 'python_version': platform.python_version(),
- 'mysql_info': mysql_info,
- 'redis_info': full_redis_info if full else redis_info,
- 'sys_argv': sys.argv,
- 'platform': platform.uname()
+ "runtime_info": {
+ "python_version": platform.python_version(),
+ "mysql_info": mysql_info,
+ "redis_info": full_redis_info if full else redis_info,
+ "sys_argv": sys.argv,
+ "platform": platform.uname(),
},
- 'sys_config': sys_config,
- 'packages': installed_packages_list
+ "sys_config": sys_config,
+ "packages": installed_packages_list,
}
return JsonResponse(system_info)
@@ -197,7 +206,9 @@ def mirage(request):
for ins in Instance.objects.all():
# 忽略解密错误的数据(本身为异常数据)
try:
- Instance(pk=ins.pk, password=pc.decrypt(ins.password)).save(update_fields=['password'])
+ Instance(pk=ins.pk, password=pc.decrypt(ins.password)).save(
+ update_fields=["password"]
+ )
except:
pass
# 使用django-mirage-field重新加密