From d34d8fcef18aaa5fd9f9439d536f09ff3dde64e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B0=8F=E5=9C=88=E5=9C=88?= Date: Wed, 13 Jul 2022 09:10:24 +0800 Subject: [PATCH 1/3] Create black.yml MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 增加代码格式化检测 --- .github/workflows/black.yml | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 .github/workflows/black.yml diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml new file mode 100644 index 0000000000..b04fb15cb2 --- /dev/null +++ b/.github/workflows/black.yml @@ -0,0 +1,10 @@ +name: Lint + +on: [push, pull_request] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: psf/black@stable From 520d2f1ad9d4cc723c76785a8352f91c8ee68723 Mon Sep 17 00:00:00 2001 From: hhyo Date: Wed, 13 Jul 2022 23:37:45 +0800 Subject: [PATCH 2/3] black format --- archery/__init__.py | 2 +- archery/asgi.py | 2 +- archery/settings.py | 312 +- archery/urls.py | 6 +- common/auth.py | 130 +- common/check.py | 121 +- common/config.py | 43 +- common/dashboard.py | 143 +- common/middleware/check_login_middleware.py | 22 +- .../exception_logging_middleware.py | 3 +- common/storage.py | 2 +- common/tests.py | 509 +-- common/twofa/__init__.py | 11 +- common/twofa/sms.py | 54 +- common/twofa/totp.py | 39 +- common/utils/aes_decryptor.py | 26 +- common/utils/aliyun_sdk.py | 50 +- common/utils/aliyun_sms.py | 33 +- common/utils/chart_dao.py | 65 +- common/utils/const.py | 111 +- common/utils/convert.py | 8 +- common/utils/ding_api.py | 60 +- common/utils/extend_json_encoder.py | 18 +- common/utils/feishu_api.py | 35 +- common/utils/global_info.py | 16 +- common/utils/permission.py | 14 +- common/utils/sendmsg.py | 178 +- common/utils/tencent_sms.py | 37 +- common/utils/timer.py | 2 +- common/utils/wx_api.py | 16 +- common/views.py | 14 +- common/workflow.py | 65 +- sql/admin.py | 462 ++- sql/aliyun_rds.py | 139 +- sql/archiver.py | 373 +- sql/audit_log.py | 107 +- sql/binlog.py | 208 +- sql/data_dictionary.py | 120 +- sql/db_diagnostic.py | 224 +- sql/engines/__init__.py | 26 +- sql/engines/clickhouse.py | 376 +- sql/engines/goinception.py | 158 +- sql/engines/models.py | 93 +- sql/engines/mongo.py | 600 ++- sql/engines/mssql.py | 218 +- sql/engines/mysql.py | 341 +- sql/engines/odps.py | 48 +- sql/engines/oracle.py | 859 +++-- sql/engines/pgsql.py | 196 +- sql/engines/phoenix.py | 130 +- sql/engines/redis.py | 186 +- sql/engines/tests.py | 1946 ++++++---- sql/form.py | 15 +- sql/instance.py | 328 +- sql/instance_account.py | 294 +- sql/instance_database.py | 111 +- sql/models.py | 1268 ++++--- sql/notify.py | 280 +- sql/plugins/my2sql.py | 49 +- sql/plugins/plugin.py | 38 +- sql/plugins/pt_archiver.py | 41 +- sql/plugins/schemasync.py | 29 +- sql/plugins/soar.py | 66 +- sql/plugins/sqladvisor.py | 12 +- sql/plugins/tests.py | 216 +- sql/query.py | 232 +- sql/query_privileges.py | 415 ++- sql/resource_group.py | 208 +- sql/slowlog.py | 246 +- sql/sql_analyze.py | 68 +- sql/sql_optimize.py | 280 +- sql/sql_tuning.py | 133 +- sql/sql_workflow.py | 630 ++-- sql/templatetags/format_tags.py | 6 +- sql/tests.py | 3291 ++++++++++------- sql/urls.py | 301 +- sql/user.py | 12 +- sql/utils/data_masking.py | 106 +- sql/utils/execute_sql.py | 90 +- sql/utils/extract_tables.py | 30 +- sql/utils/resource_group.py | 13 +- sql/utils/sql_review.py | 65 +- sql/utils/sql_utils.py | 114 +- sql/utils/ssh_tunnel.py | 17 +- sql/utils/tasks.py | 40 +- sql/utils/tests.py | 1585 +++++--- sql/utils/workflow_audit.py | 358 +- sql/views.py | 465 ++- sql_api/api_instance.py | 164 +- sql_api/api_user.py | 258 +- sql_api/api_workflow.py | 468 ++- sql_api/apps.py | 2 +- sql_api/filters.py | 36 +- sql_api/pagination.py | 27 +- sql_api/permissions.py | 8 +- sql_api/serializers.py | 370 +- sql_api/tests.py | 419 ++- sql_api/urls.py | 82 +- sql_api/views.py | 179 +- 99 files changed, 13520 insertions(+), 8902 deletions(-) diff --git a/archery/__init__.py b/archery/__init__.py index 980bd0c637..c1d56d39c8 100644 --- a/archery/__init__.py +++ b/archery/__init__.py @@ -1,2 +1,2 @@ version = (1, 9, 0) -display_version = '.'.join(str(i) for i in version) +display_version = ".".join(str(i) for i in version) diff --git a/archery/asgi.py b/archery/asgi.py index 52d4205405..656f0fa9d0 100644 --- a/archery/asgi.py +++ b/archery/asgi.py @@ -11,6 +11,6 @@ from django.core.asgi import get_asgi_application -os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'archery.settings') +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "archery.settings") application = get_asgi_application() diff --git a/archery/settings.py b/archery/settings.py index 16a996258f..f4da941302 100644 --- a/archery/settings.py +++ b/archery/settings.py @@ -9,24 +9,23 @@ BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -environ.Env.read_env(os.path.join(BASE_DIR, '.env')) +environ.Env.read_env(os.path.join(BASE_DIR, ".env")) env = environ.Env( DEBUG=(bool, False), ALLOWED_HOSTS=(List[str], ["*"]), - SECRET_KEY=(str, 'hfusaf2m4ot#7)fkw#di2bu6(cv0@opwmafx5n#6=3d%x^hpl6'), + SECRET_KEY=(str, "hfusaf2m4ot#7)fkw#di2bu6(cv0@opwmafx5n#6=3d%x^hpl6"), DATABASE_URL=(str, "mysql://root:@127.0.0.1:3306/archery"), CACHE_URL=(str, "redis://127.0.0.1:6379/0"), DINGDING_CACHE_URL=(str, "redis://127.0.0.1:6379/1"), ENABLE_LDAP=(bool, False), AUTH_LDAP_ALWAYS_UPDATE_USER=(bool, True), - AUTH_LDAP_USER_ATTR_MAP=(dict, { - "username": "cn", - "display": "displayname", - "email": "mail" - }), + AUTH_LDAP_USER_ATTR_MAP=( + dict, + {"username": "cn", "display": "displayname", "email": "mail"}, + ), Q_CLUISTER_SYNC=(bool, False), # qcluster 同步模式, debug 时可以调整为 True # CSRF_TRUSTED_ORIGINS=subdomain.example.com,subdomain.example2.com subdomain.example.com - CSRF_TRUSTED_ORIGINS=(list, []) + CSRF_TRUSTED_ORIGINS=(list, []), ) # SECURITY WARNING: keep the secret key used in production secret! @@ -48,59 +47,59 @@ # Application definition INSTALLED_APPS = ( - 'django.contrib.admin', - 'django.contrib.auth', - 'django.contrib.contenttypes', - 'django.contrib.sessions', - 'django.contrib.messages', - 'django.contrib.staticfiles', - 'django_q', - 'sql', - 'sql_api', - 'common', - 'rest_framework', - 'django_filters', - 'drf_spectacular', + "django.contrib.admin", + "django.contrib.auth", + "django.contrib.contenttypes", + "django.contrib.sessions", + "django.contrib.messages", + "django.contrib.staticfiles", + "django_q", + "sql", + "sql_api", + "common", + "rest_framework", + "django_filters", + "drf_spectacular", ) MIDDLEWARE = ( - 'django.contrib.sessions.middleware.SessionMiddleware', - 'django.middleware.common.CommonMiddleware', - 'django.middleware.csrf.CsrfViewMiddleware', - 'django.contrib.auth.middleware.AuthenticationMiddleware', - 'django.contrib.messages.middleware.MessageMiddleware', - 'django.middleware.clickjacking.XFrameOptionsMiddleware', - 'django.middleware.security.SecurityMiddleware', - 'django.middleware.gzip.GZipMiddleware', - 'common.middleware.check_login_middleware.CheckLoginMiddleware', - 'common.middleware.exception_logging_middleware.ExceptionLoggingMiddleware', + "django.contrib.sessions.middleware.SessionMiddleware", + "django.middleware.common.CommonMiddleware", + "django.middleware.csrf.CsrfViewMiddleware", + "django.contrib.auth.middleware.AuthenticationMiddleware", + "django.contrib.messages.middleware.MessageMiddleware", + "django.middleware.clickjacking.XFrameOptionsMiddleware", + "django.middleware.security.SecurityMiddleware", + "django.middleware.gzip.GZipMiddleware", + "common.middleware.check_login_middleware.CheckLoginMiddleware", + "common.middleware.exception_logging_middleware.ExceptionLoggingMiddleware", ) -ROOT_URLCONF = 'archery.urls' +ROOT_URLCONF = "archery.urls" TEMPLATES = [ { - 'BACKEND': 'django.template.backends.django.DjangoTemplates', - 'DIRS': [os.path.join(BASE_DIR, 'common/templates')], - 'APP_DIRS': True, - 'OPTIONS': { - 'context_processors': [ - 'django.template.context_processors.debug', - 'django.template.context_processors.request', - 'django.contrib.auth.context_processors.auth', - 'django.contrib.messages.context_processors.messages', - 'common.utils.global_info.global_info', + "BACKEND": "django.template.backends.django.DjangoTemplates", + "DIRS": [os.path.join(BASE_DIR, "common/templates")], + "APP_DIRS": True, + "OPTIONS": { + "context_processors": [ + "django.template.context_processors.debug", + "django.template.context_processors.request", + "django.contrib.auth.context_processors.auth", + "django.contrib.messages.context_processors.messages", + "common.utils.global_info.global_info", ], }, }, ] -WSGI_APPLICATION = 'archery.wsgi.application' +WSGI_APPLICATION = "archery.wsgi.application" # Internationalization -LANGUAGE_CODE = 'zh-hans' +LANGUAGE_CODE = "zh-hans" -TIME_ZONE = 'Asia/Shanghai' +TIME_ZONE = "Asia/Shanghai" USE_I18N = True @@ -108,14 +107,16 @@ # 时间格式化 USE_L10N = False -DATETIME_FORMAT = 'Y-m-d H:i:s' -DATE_FORMAT = 'Y-m-d' +DATETIME_FORMAT = "Y-m-d H:i:s" +DATE_FORMAT = "Y-m-d" # Static files (CSS, JavaScript, Images) -STATIC_URL = '/static/' -STATIC_ROOT = os.path.join(BASE_DIR, 'static') -STATICFILES_DIRS = [os.path.join(BASE_DIR, 'common/static'), ] -STATICFILES_STORAGE = 'common.storage.ForgivingManifestStaticFilesStorage' +STATIC_URL = "/static/" +STATIC_ROOT = os.path.join(BASE_DIR, "static") +STATICFILES_DIRS = [ + os.path.join(BASE_DIR, "common/static"), +] +STATICFILES_STORAGE = "common.storage.ForgivingManifestStaticFilesStorage" # 扩展django admin里users字段用到,指定了sql/models.py里的class users AUTH_USER_MODEL = "sql.Users" @@ -123,19 +124,19 @@ # 密码校验 AUTH_PASSWORD_VALIDATORS = [ { - 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', + "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator", }, { - 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', - 'OPTIONS': { - 'min_length': 9, - } + "NAME": "django.contrib.auth.password_validation.MinimumLengthValidator", + "OPTIONS": { + "min_length": 9, + }, }, { - 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', + "NAME": "django.contrib.auth.password_validation.CommonPasswordValidator", }, { - 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', + "NAME": "django.contrib.auth.password_validation.NumericPasswordValidator", }, ] @@ -148,36 +149,36 @@ # 该项目本身的mysql数据库地址 DATABASES = { - 'default': { + "default": { **env.db(), **{ - 'DEFAULT_CHARSET': 'utf8mb4', - 'CONN_MAX_AGE': 50, - 'OPTIONS': { - 'init_command': "SET sql_mode='STRICT_TRANS_TABLES'", - 'charset': 'utf8mb4' + "DEFAULT_CHARSET": "utf8mb4", + "CONN_MAX_AGE": 50, + "OPTIONS": { + "init_command": "SET sql_mode='STRICT_TRANS_TABLES'", + "charset": "utf8mb4", }, - 'TEST': { - 'NAME': 'test_archery', - 'CHARSET': 'utf8mb4', - } - } + "TEST": { + "NAME": "test_archery", + "CHARSET": "utf8mb4", + }, + }, } } # Django-Q Q_CLUSTER = { - 'name': 'archery', - 'workers': 4, - 'recycle': 500, - 'timeout': 60, - 'compress': True, - 'cpu_affinity': 1, - 'save_limit': 0, - 'queue_limit': 50, - 'label': 'Django Q', - 'django_redis': 'default', - 'sync': env("Q_CLUISTER_SYNC") # 本地调试可以修改为True,使用同步模式 + "name": "archery", + "workers": 4, + "recycle": 500, + "timeout": 60, + "compress": True, + "cpu_affinity": 1, + "save_limit": 0, + "queue_limit": 50, + "label": "Django Q", + "django_redis": "default", + "sync": env("Q_CLUISTER_SYNC"), # 本地调试可以修改为True,使用同步模式 } # 缓存配置 @@ -186,49 +187,46 @@ } # https://docs.djangoproject.com/en/3.2/ref/settings/#std-setting-DEFAULT_AUTO_FIELD -DEFAULT_AUTO_FIELD = 'django.db.models.AutoField' +DEFAULT_AUTO_FIELD = "django.db.models.AutoField" # API Framework REST_FRAMEWORK = { - 'DEFAULT_SCHEMA_CLASS': 'drf_spectacular.openapi.AutoSchema', - 'DEFAULT_RENDERER_CLASSES': ('rest_framework.renderers.JSONRenderer',), + "DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema", + "DEFAULT_RENDERER_CLASSES": ("rest_framework.renderers.JSONRenderer",), # 鉴权 - 'DEFAULT_AUTHENTICATION_CLASSES': ( - 'rest_framework_simplejwt.authentication.JWTAuthentication', - 'rest_framework.authentication.SessionAuthentication', + "DEFAULT_AUTHENTICATION_CLASSES": ( + "rest_framework_simplejwt.authentication.JWTAuthentication", + "rest_framework.authentication.SessionAuthentication", ), # 权限 - 'DEFAULT_PERMISSION_CLASSES': ('sql_api.permissions.IsInUserWhitelist',), + "DEFAULT_PERMISSION_CLASSES": ("sql_api.permissions.IsInUserWhitelist",), # 限速(anon:未认证用户 user:认证用户) - 'DEFAULT_THROTTLE_CLASSES': ( - 'rest_framework.throttling.AnonRateThrottle', - 'rest_framework.throttling.UserRateThrottle', + "DEFAULT_THROTTLE_CLASSES": ( + "rest_framework.throttling.AnonRateThrottle", + "rest_framework.throttling.UserRateThrottle", ), - 'DEFAULT_THROTTLE_RATES': { - 'anon': '120/min', - 'user': '600/min' - }, + "DEFAULT_THROTTLE_RATES": {"anon": "120/min", "user": "600/min"}, # 过滤 - 'DEFAULT_FILTER_BACKENDS': ('django_filters.rest_framework.DjangoFilterBackend',), + "DEFAULT_FILTER_BACKENDS": ("django_filters.rest_framework.DjangoFilterBackend",), # 分页 - 'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination', - 'PAGE_SIZE': 5, + "DEFAULT_PAGINATION_CLASS": "rest_framework.pagination.PageNumberPagination", + "PAGE_SIZE": 5, } # Swagger UI SPECTACULAR_SETTINGS = { - 'TITLE': 'Archery API', - 'DESCRIPTION': 'OpenAPI 3.0', - 'VERSION': '1.0.0', + "TITLE": "Archery API", + "DESCRIPTION": "OpenAPI 3.0", + "VERSION": "1.0.0", } # API Authentication SIMPLE_JWT = { - 'ACCESS_TOKEN_LIFETIME': timedelta(hours=4), - 'REFRESH_TOKEN_LIFETIME': timedelta(days=3), - 'ALGORITHM': 'HS256', - 'SIGNING_KEY': SECRET_KEY, - 'AUTH_HEADER_TYPES': ('Bearer',), + "ACCESS_TOKEN_LIFETIME": timedelta(hours=4), + "REFRESH_TOKEN_LIFETIME": timedelta(days=3), + "ALGORITHM": "HS256", + "SIGNING_KEY": SECRET_KEY, + "AUTH_HEADER_TYPES": ("Bearer",), } # LDAP @@ -238,68 +236,78 @@ from django_auth_ldap.config import LDAPSearch AUTHENTICATION_BACKENDS = ( - 'django_auth_ldap.backend.LDAPBackend', # 配置为先使用LDAP认证,如通过认证则不再使用后面的认证方式 - 'django.contrib.auth.backends.ModelBackend', # django系统中手动创建的用户也可使用,优先级靠后。注意这2行的顺序 + "django_auth_ldap.backend.LDAPBackend", # 配置为先使用LDAP认证,如通过认证则不再使用后面的认证方式 + "django.contrib.auth.backends.ModelBackend", # django系统中手动创建的用户也可使用,优先级靠后。注意这2行的顺序 ) AUTH_LDAP_SERVER_URI = env("AUTH_LDAP_SERVER_URI", default="ldap://xxx") AUTH_LDAP_USER_DN_TEMPLATE = env("AUTH_LDAP_USER_DN_TEMPLATE", default=None) if not AUTH_LDAP_USER_DN_TEMPLATE: del AUTH_LDAP_USER_DN_TEMPLATE - AUTH_LDAP_BIND_DN = env("AUTH_LDAP_BIND_DN", default="cn=xxx,ou=xxx,dc=xxx,dc=xxx") + AUTH_LDAP_BIND_DN = env( + "AUTH_LDAP_BIND_DN", default="cn=xxx,ou=xxx,dc=xxx,dc=xxx" + ) AUTH_LDAP_BIND_PASSWORD = env("AUTH_LDAP_BIND_PASSWORD", default="***********") - AUTH_LDAP_USER_SEARCH_BASE = env("AUTH_LDAP_USER_SEARCH_BASE", default="ou=xxx,dc=xxx,dc=xxx") - AUTH_LDAP_USER_SEARCH_FILTER = env("AUTH_LDAP_USER_SEARCH_FILTER", default='(cn=%(user)s)') - AUTH_LDAP_USER_SEARCH = LDAPSearch(AUTH_LDAP_USER_SEARCH_BASE, ldap.SCOPE_SUBTREE, AUTH_LDAP_USER_SEARCH_FILTER) - AUTH_LDAP_ALWAYS_UPDATE_USER = env("AUTH_LDAP_ALWAYS_UPDATE_USER", default=True) # 每次登录从ldap同步用户信息 + AUTH_LDAP_USER_SEARCH_BASE = env( + "AUTH_LDAP_USER_SEARCH_BASE", default="ou=xxx,dc=xxx,dc=xxx" + ) + AUTH_LDAP_USER_SEARCH_FILTER = env( + "AUTH_LDAP_USER_SEARCH_FILTER", default="(cn=%(user)s)" + ) + AUTH_LDAP_USER_SEARCH = LDAPSearch( + AUTH_LDAP_USER_SEARCH_BASE, ldap.SCOPE_SUBTREE, AUTH_LDAP_USER_SEARCH_FILTER + ) + AUTH_LDAP_ALWAYS_UPDATE_USER = env( + "AUTH_LDAP_ALWAYS_UPDATE_USER", default=True + ) # 每次登录从ldap同步用户信息 AUTH_LDAP_USER_ATTR_MAP = env("AUTH_LDAP_USER_ATTR_MAP") # LOG配置 LOGGING = { - 'version': 1, - 'disable_existing_loggers': False, - 'formatters': { - 'verbose': { - 'format': '[%(asctime)s][%(threadName)s:%(thread)d][task_id:%(name)s][%(filename)s:%(lineno)d][%(levelname)s]- %(message)s' + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "verbose": { + "format": "[%(asctime)s][%(threadName)s:%(thread)d][task_id:%(name)s][%(filename)s:%(lineno)d][%(levelname)s]- %(message)s" }, }, - 'handlers': { - 'default': { - 'level': 'DEBUG', - 'class': 'logging.handlers.RotatingFileHandler', - 'filename': 'logs/archery.log', - 'maxBytes': 1024 * 1024 * 100, # 5 MB - 'backupCount': 5, - 'formatter': 'verbose', + "handlers": { + "default": { + "level": "DEBUG", + "class": "logging.handlers.RotatingFileHandler", + "filename": "logs/archery.log", + "maxBytes": 1024 * 1024 * 100, # 5 MB + "backupCount": 5, + "formatter": "verbose", + }, + "django-q": { + "level": "DEBUG", + "class": "logging.handlers.RotatingFileHandler", + "filename": "logs/qcluster.log", + "maxBytes": 1024 * 1024 * 100, # 5 MB + "backupCount": 5, + "formatter": "verbose", }, - 'django-q': { - 'level': 'DEBUG', - 'class': 'logging.handlers.RotatingFileHandler', - 'filename': 'logs/qcluster.log', - 'maxBytes': 1024 * 1024 * 100, # 5 MB - 'backupCount': 5, - 'formatter': 'verbose', + "console": { + "level": "DEBUG", + "class": "logging.StreamHandler", + "formatter": "verbose", }, - 'console': { - 'level': 'DEBUG', - 'class': 'logging.StreamHandler', - 'formatter': 'verbose' - } }, - 'loggers': { - 'default': { # default日志 - 'handlers': ['console', 'default'], - 'level': 'WARNING' + "loggers": { + "default": { # default日志 + "handlers": ["console", "default"], + "level": "WARNING", }, - 'django-q': { # django_q模块相关日志 - 'handlers': ['console', 'django-q'], - 'level': 'WARNING', - 'propagate': False + "django-q": { # django_q模块相关日志 + "handlers": ["console", "django-q"], + "level": "WARNING", + "propagate": False, }, - 'django_auth_ldap': { # django_auth_ldap模块相关日志 - 'handlers': ['console', 'default'], - 'level': 'WARNING', - 'propagate': False + "django_auth_ldap": { # django_auth_ldap模块相关日志 + "handlers": ["console", "default"], + "level": "WARNING", + "propagate": False, }, # 'django.db': { # 打印SQL语句,方便开发 # 'handlers': ['console', 'default'], @@ -311,14 +319,14 @@ # 'level': 'DEBUG', # 'propagate': False # }, - } + }, } -MEDIA_ROOT = os.path.join(BASE_DIR, 'media') +MEDIA_ROOT = os.path.join(BASE_DIR, "media") if not os.path.exists(MEDIA_ROOT): os.mkdir(MEDIA_ROOT) -PKEY_ROOT = os.path.join(MEDIA_ROOT, 'keys') +PKEY_ROOT = os.path.join(MEDIA_ROOT, "keys") if not os.path.exists(PKEY_ROOT): os.mkdir(PKEY_ROOT) diff --git a/archery/urls.py b/archery/urls.py index e1e35c40e4..89d8caf351 100644 --- a/archery/urls.py +++ b/archery/urls.py @@ -3,9 +3,9 @@ from common import views urlpatterns = [ - path('admin/', admin.site.urls), - path('api/', include(('sql_api.urls', 'sql_api'), namespace="sql_api")), - path('', include(('sql.urls', 'sql'), namespace="sql")), + path("admin/", admin.site.urls), + path("api/", include(("sql_api.urls", "sql_api"), namespace="sql_api")), + path("", include(("sql.urls", "sql"), namespace="sql")), ] handler400 = views.bad_request diff --git a/common/auth.py b/common/auth.py index f26e081c84..3bb227be6a 100644 --- a/common/auth.py +++ b/common/auth.py @@ -15,7 +15,7 @@ from common.utils.ding_api import get_ding_user_id from sql.models import Users, ResourceGroup, TwoFactorAuthConfig -logger = logging.getLogger('default') +logger = logging.getLogger("default") def init_user(user): @@ -25,17 +25,24 @@ def init_user(user): :return: """ # 添加到默认权限组 - default_auth_group = SysConfig().get('default_auth_group', '') + default_auth_group = SysConfig().get("default_auth_group", "") if default_auth_group: - default_auth_group = default_auth_group.split(',') - [user.groups.add(group) for group in Group.objects.filter(name__in=default_auth_group)] + default_auth_group = default_auth_group.split(",") + [ + user.groups.add(group) + for group in Group.objects.filter(name__in=default_auth_group) + ] # 添加到默认资源组 - default_resource_group = SysConfig().get('default_resource_group', '') + default_resource_group = SysConfig().get("default_resource_group", "") if default_resource_group: - default_resource_group = default_resource_group.split(',') - [user.resource_group.add(group) for group in - ResourceGroup.objects.filter(group_name__in=default_resource_group)] + default_resource_group = default_resource_group.split(",") + [ + user.resource_group.add(group) + for group in ResourceGroup.objects.filter( + group_name__in=default_resource_group + ) + ] class ArcheryAuth(object): @@ -55,8 +62,8 @@ def challenge(username=None, password=None): return user def authenticate(self): - username = self.request.POST.get('username') - password = self.request.POST.get('password') + username = self.request.POST.get("username") + password = self.request.POST.get("password") # 确认用户是否已经存在 try: user = Users.objects.get(username=username) @@ -65,23 +72,30 @@ def authenticate(self): if authenticated_user: # ldap 首次登录逻辑 init_user(authenticated_user) - return {'status': 0, 'msg': 'ok', 'data': authenticated_user} + return {"status": 0, "msg": "ok", "data": authenticated_user} else: - return {'status': 1, 'msg': '用户名或密码错误,请重新输入!', 'data': ''} + return {"status": 1, "msg": "用户名或密码错误,请重新输入!", "data": ""} except: - logger.error('验证用户密码时报错') + logger.error("验证用户密码时报错") logger.error(traceback.format_exc()) - return {'status': 1, 'msg': f'服务异常,请联系管理员处理', 'data': ''} + return {"status": 1, "msg": f"服务异常,请联系管理员处理", "data": ""} # 已存在用户, 验证是否在锁期间 # 读取配置文件 - lock_count = int(self.sys_config.get('lock_cnt_threshold', 5)) - lock_time = int(self.sys_config.get('lock_time_threshold', 60 * 5)) + lock_count = int(self.sys_config.get("lock_cnt_threshold", 5)) + lock_time = int(self.sys_config.get("lock_time_threshold", 60 * 5)) # 验证是否在锁, 分了几个if 防止代码太长 if user.failed_login_count and user.last_login_failed_at: if user.failed_login_count >= lock_count: now = datetime.datetime.now() - if user.last_login_failed_at + datetime.timedelta(seconds=lock_time) > now: - return {'status': 3, 'msg': f'登录失败超过限制,该账号已被锁定!请等候大约{lock_time}秒再试', 'data': ''} + if ( + user.last_login_failed_at + datetime.timedelta(seconds=lock_time) + > now + ): + return { + "status": 3, + "msg": f"登录失败超过限制,该账号已被锁定!请等候大约{lock_time}秒再试", + "data": "", + } else: # 如果锁已超时, 重置失败次数 user.failed_login_count = 0 @@ -90,11 +104,11 @@ def authenticate(self): if authenticated_user: if not authenticated_user.last_login: init_user(authenticated_user) - return {'status': 0, 'msg': 'ok', 'data': authenticated_user} + return {"status": 0, "msg": "ok", "data": authenticated_user} user.failed_login_count += 1 user.last_login_failed_at = datetime.datetime.now() user.save() - return {'status': 1, 'msg': '用户名或密码错误,请重新输入!', 'data': ''} + return {"status": 1, "msg": "用户名或密码错误,请重新输入!", "data": ""} # ajax接口,登录页面调用,用来验证用户名密码 @@ -102,69 +116,71 @@ def authenticate_entry(request): """接收http请求,然后把请求中的用户名密码传给ArcherAuth去验证""" new_auth = ArcheryAuth(request) result = new_auth.authenticate() - if result['status'] == 0: - authenticated_user = result['data'] + if result["status"] == 0: + authenticated_user = result["data"] twofa_enabled = TwoFactorAuthConfig.objects.filter(user=authenticated_user) # 是否开启全局2fa - if SysConfig().get('enforce_2fa'): + if SysConfig().get("enforce_2fa"): # 用户是否配置过2fa if twofa_enabled: - verify_mode = 'verify_only' + verify_mode = "verify_only" else: - verify_mode = 'verify_config' + verify_mode = "verify_config" # 设置无登录状态session s = SessionStore() - s['user'] = authenticated_user.username - s['verify_mode'] = verify_mode + s["user"] = authenticated_user.username + s["verify_mode"] = verify_mode s.set_expiry(300) s.create() - result = {'status': 0, 'msg': 'ok', 'data': s.session_key} + result = {"status": 0, "msg": "ok", "data": s.session_key} else: # 用户是否配置过2fa if twofa_enabled: # 设置无登录状态session s = SessionStore() - s['user'] = authenticated_user.username - s['verify_mode'] = 'verify_only' + s["user"] = authenticated_user.username + s["verify_mode"] = "verify_only" s.set_expiry(300) s.create() - result = {'status': 0, 'msg': 'ok', 'data': s.session_key} + result = {"status": 0, "msg": "ok", "data": s.session_key} else: # 未设置2fa直接登录 login(request, authenticated_user) # 从钉钉获取该用户的 dingding_id,用于单独给他发消息 - if SysConfig().get("ding_to_person") is True and "admin" not in request.POST.get('username'): - get_ding_user_id(request.POST.get('username')) - result = {'status': 0, 'msg': 'ok', 'data': None} + if SysConfig().get( + "ding_to_person" + ) is True and "admin" not in request.POST.get("username"): + get_ding_user_id(request.POST.get("username")) + result = {"status": 0, "msg": "ok", "data": None} - return HttpResponse(json.dumps(result), content_type='application/json') + return HttpResponse(json.dumps(result), content_type="application/json") # 注册用户 def sign_up(request): - sign_up_enabled = SysConfig().get('sign_up_enabled', False) + sign_up_enabled = SysConfig().get("sign_up_enabled", False) if not sign_up_enabled: - result = {'status': 1, 'msg': '注册未启用,请联系管理员开启', 'data': None} - return HttpResponse(json.dumps(result), content_type='application/json') - username = request.POST.get('username') - password = request.POST.get('password') - password2 = request.POST.get('password2') - display = request.POST.get('display') - email = request.POST.get('email') - result = {'status': 0, 'msg': 'ok', 'data': None} + result = {"status": 1, "msg": "注册未启用,请联系管理员开启", "data": None} + return HttpResponse(json.dumps(result), content_type="application/json") + username = request.POST.get("username") + password = request.POST.get("password") + password2 = request.POST.get("password2") + display = request.POST.get("display") + email = request.POST.get("email") + result = {"status": 0, "msg": "ok", "data": None} if not (username and password): - result['status'] = 1 - result['msg'] = '用户名和密码不能为空' + result["status"] = 1 + result["msg"] = "用户名和密码不能为空" elif len(Users.objects.filter(username=username)) > 0: - result['status'] = 1 - result['msg'] = '用户名已存在' + result["status"] = 1 + result["msg"] = "用户名已存在" elif password != password2: - result['status'] = 1 - result['msg'] = '两次输入密码不一致' + result["status"] = 1 + result["msg"] = "两次输入密码不一致" elif not display: - result['status'] = 1 - result['msg'] = '请填写中文名' + result["status"] = 1 + result["msg"] = "请填写中文名" else: # 验证密码 try: @@ -175,15 +191,15 @@ def sign_up(request): display=display, email=email, is_active=1, - is_staff=True + is_staff=True, ) except ValidationError as msg: - result['status'] = 1 - result['msg'] = str(msg) - return HttpResponse(json.dumps(result), content_type='application/json') + result["status"] = 1 + result["msg"] = str(msg) + return HttpResponse(json.dumps(result), content_type="application/json") # 退出登录 def sign_out(request): logout(request) - return HttpResponseRedirect(reverse('sql:login')) + return HttpResponseRedirect(reverse("sql:login")) diff --git a/common/check.py b/common/check.py index 300e17bc55..93ac84747b 100644 --- a/common/check.py +++ b/common/check.py @@ -11,106 +11,119 @@ from sql.models import Instance from common.utils.sendmsg import MsgSender -logger = logging.getLogger('default') +logger = logging.getLogger("default") # 检测inception配置 @superuser_required def go_inception(request): - result = {'status': 0, 'msg': 'ok', 'data': []} - go_inception_host = request.POST.get('go_inception_host', '') - go_inception_port = request.POST.get('go_inception_port', '') - inception_remote_backup_host = request.POST.get('inception_remote_backup_host', '') - inception_remote_backup_port = request.POST.get('inception_remote_backup_port', '') - inception_remote_backup_user = request.POST.get('inception_remote_backup_user', '') - inception_remote_backup_password = request.POST.get('inception_remote_backup_password', '') + result = {"status": 0, "msg": "ok", "data": []} + go_inception_host = request.POST.get("go_inception_host", "") + go_inception_port = request.POST.get("go_inception_port", "") + inception_remote_backup_host = request.POST.get("inception_remote_backup_host", "") + inception_remote_backup_port = request.POST.get("inception_remote_backup_port", "") + inception_remote_backup_user = request.POST.get("inception_remote_backup_user", "") + inception_remote_backup_password = request.POST.get( + "inception_remote_backup_password", "" + ) try: - conn = MySQLdb.connect(host=go_inception_host, port=int(go_inception_port), charset='utf8mb4', - connect_timeout=5) + conn = MySQLdb.connect( + host=go_inception_host, + port=int(go_inception_port), + charset="utf8mb4", + connect_timeout=5, + ) cur = conn.cursor() except Exception as e: logger.error(traceback.format_exc()) - result['status'] = 1 - result['msg'] = '无法连接goInception\n{}'.format(str(e)) - return HttpResponse(json.dumps(result), content_type='application/json') + result["status"] = 1 + result["msg"] = "无法连接goInception\n{}".format(str(e)) + return HttpResponse(json.dumps(result), content_type="application/json") else: cur.close() conn.close() try: - conn = MySQLdb.connect(host=inception_remote_backup_host, - port=int(inception_remote_backup_port), - user=inception_remote_backup_user, - password=inception_remote_backup_password, - charset='utf8mb4', - connect_timeout=5) + conn = MySQLdb.connect( + host=inception_remote_backup_host, + port=int(inception_remote_backup_port), + user=inception_remote_backup_user, + password=inception_remote_backup_password, + charset="utf8mb4", + connect_timeout=5, + ) cur = conn.cursor() except Exception as e: logger.error(traceback.format_exc()) - result['status'] = 1 - result['msg'] = '无法连接goInception备份库\n{}'.format(str(e)) + result["status"] = 1 + result["msg"] = "无法连接goInception备份库\n{}".format(str(e)) else: cur.close() conn.close() # 返回结果 - return HttpResponse(json.dumps(result), content_type='application/json') + return HttpResponse(json.dumps(result), content_type="application/json") # 检测email配置 @superuser_required def email(request): - result = {'status': 0, 'msg': 'ok', 'data': []} - mail = True if request.POST.get('mail', '') == 'true' else False - mail_ssl = True if request.POST.get('mail_ssl') == 'true' else False - mail_smtp_server = request.POST.get('mail_smtp_server', '') - mail_smtp_port = request.POST.get('mail_smtp_port', '') - mail_smtp_user = request.POST.get('mail_smtp_user', '') - mail_smtp_password = request.POST.get('mail_smtp_password', '') + result = {"status": 0, "msg": "ok", "data": []} + mail = True if request.POST.get("mail", "") == "true" else False + mail_ssl = True if request.POST.get("mail_ssl") == "true" else False + mail_smtp_server = request.POST.get("mail_smtp_server", "") + mail_smtp_port = request.POST.get("mail_smtp_port", "") + mail_smtp_user = request.POST.get("mail_smtp_user", "") + mail_smtp_password = request.POST.get("mail_smtp_password", "") if not mail: - result['status'] = 1 - result['msg'] = '请先开启邮件通知!' + result["status"] = 1 + result["msg"] = "请先开启邮件通知!" # 返回结果 - return HttpResponse(json.dumps(result), content_type='application/json') + return HttpResponse(json.dumps(result), content_type="application/json") try: mail_smtp_port = int(mail_smtp_port) if mail_smtp_port < 0: raise ValueError except ValueError: - result['status'] = 1 - result['msg'] = '端口号只能为正整数' - return HttpResponse(json.dumps(result), content_type='application/json') + result["status"] = 1 + result["msg"] = "端口号只能为正整数" + return HttpResponse(json.dumps(result), content_type="application/json") if not request.user.email: - result['status'] = 1 - result['msg'] = '请先完善当前用户邮箱信息!' - return HttpResponse(json.dumps(result), content_type='application/json') - bd = 'Archery 邮件发送测试...' - subj = 'Archery 邮件发送测试' - sender = MsgSender(server=mail_smtp_server, port=mail_smtp_port, user=mail_smtp_user, - password=mail_smtp_password, ssl=mail_ssl) + result["status"] = 1 + result["msg"] = "请先完善当前用户邮箱信息!" + return HttpResponse(json.dumps(result), content_type="application/json") + bd = "Archery 邮件发送测试..." + subj = "Archery 邮件发送测试" + sender = MsgSender( + server=mail_smtp_server, + port=mail_smtp_port, + user=mail_smtp_user, + password=mail_smtp_password, + ssl=mail_ssl, + ) sender_response = sender.send_email(subj, bd, [request.user.email]) - if sender_response != 'success': - result['status'] = 1 - result['msg'] = sender_response - return HttpResponse(json.dumps(result), content_type='application/json') - return HttpResponse(json.dumps(result), content_type='application/json') + if sender_response != "success": + result["status"] = 1 + result["msg"] = sender_response + return HttpResponse(json.dumps(result), content_type="application/json") + return HttpResponse(json.dumps(result), content_type="application/json") # 检测实例配置 @superuser_required def instance(request): - result = {'status': 0, 'msg': 'ok', 'data': []} - instance_id = request.POST.get('instance_id') + result = {"status": 0, "msg": "ok", "data": []} + instance_id = request.POST.get("instance_id") instance = Instance.objects.get(id=instance_id) try: engine = get_engine(instance=instance) test_result = engine.test_connection() if test_result.error: - result['status'] = 1 - result['msg'] = '无法连接实例,\n{}'.format(test_result.error) + result["status"] = 1 + result["msg"] = "无法连接实例,\n{}".format(test_result.error) except Exception as e: - result['status'] = 1 - result['msg'] = '无法连接实例,\n{}'.format(str(e)) + result["status"] = 1 + result["msg"] = "无法连接实例,\n{}".format(str(e)) # 返回结果 - return HttpResponse(json.dumps(result), content_type='application/json') + return HttpResponse(json.dumps(result), content_type="application/json") diff --git a/common/config.py b/common/config.py index de4865563a..89475a8c52 100644 --- a/common/config.py +++ b/common/config.py @@ -9,7 +9,7 @@ from sql.models import Config from django.db import transaction -logger = logging.getLogger('default') +logger = logging.getLogger("default") class SysConfig(object): @@ -20,14 +20,14 @@ def __init__(self): def get_all_config(self): try: # 获取系统配置信息 - all_config = Config.objects.all().values('item', 'value') + all_config = Config.objects.all().values("item", "value") sys_config = {} for items in all_config: - if items['value'] in ('true', 'True'): - items['value'] = True - elif items['value'] in ('false', 'False'): - items['value'] = False - sys_config[items['item']] = items['value'] + if items["value"] in ("true", "True"): + items["value"] = True + elif items["value"] in ("false", "False"): + items["value"] = False + sys_config[items["item"]] = items["value"] self.sys_config = sys_config except Exception as m: logger.error(f"获取系统配置信息失败:{m}{traceback.format_exc()}") @@ -36,34 +36,41 @@ def get_all_config(self): def get(self, key, default_value=None): value = self.sys_config.get(key, default_value) # 是字符串的话, 如果是空, 或者全是空格, 返回默认值 - if isinstance(value, str) and value.strip() == '': + if isinstance(value, str) and value.strip() == "": return default_value return value def set(self, key, value): if value is True: - db_value = 'true' + db_value = "true" elif value is False: - db_value = 'false' + db_value = "false" else: db_value = value - obj, created = Config.objects.update_or_create(item=key, defaults={"value": db_value}) + obj, created = Config.objects.update_or_create( + item=key, defaults={"value": db_value} + ) if created: self.sys_config.update({key: value}) def replace(self, configs): - result = {'status': 0, 'msg': 'ok', 'data': []} + result = {"status": 0, "msg": "ok", "data": []} # 清空并替换 try: with transaction.atomic(): self.purge() Config.objects.bulk_create( - [Config(item=items['key'].strip(), - value=str(items['value']).strip()) for items in json.loads(configs)]) + [ + Config( + item=items["key"].strip(), value=str(items["value"]).strip() + ) + for items in json.loads(configs) + ] + ) except Exception as e: logger.error(traceback.format_exc()) - result['status'] = 1 - result['msg'] = str(e) + result["status"] = 1 + result["msg"] = str(e) finally: self.get_all_config() return result @@ -81,8 +88,8 @@ def purge(self): # 修改系统配置 @superuser_required def change_config(request): - configs = request.POST.get('configs') + configs = request.POST.get("configs") archer_config = SysConfig() result = archer_config.replace(configs) # 返回结果 - return HttpResponse(json.dumps(result), content_type='application/json') + return HttpResponse(json.dumps(result), content_type="application/json") diff --git a/common/dashboard.py b/common/dashboard.py index ee2d274a72..af7291a847 100644 --- a/common/dashboard.py +++ b/common/dashboard.py @@ -11,10 +11,10 @@ from pyecharts import options as opts from pyecharts.charts import Pie, Bar, Line -CurrentConfig.ONLINE_HOST = '/static/echarts/' +CurrentConfig.ONLINE_HOST = "/static/echarts/" -@permission_required('sql.menu_dashboard', raise_exception=True) +@permission_required("sql.menu_dashboard", raise_exception=True) def pyecharts(request): """dashboard view""" # 工单数量统计 @@ -24,42 +24,44 @@ def pyecharts(request): one_month_before = today - relativedelta(days=+30) attr = chart_dao.get_date_list(one_month_before, today) _dict = {} - for row in data['rows']: + for row in data["rows"]: _dict[row[0]] = row[1] value = [_dict.get(day) if _dict.get(day) else 0 for day in attr] - bar1 = Bar(init_opts=opts.InitOpts(width='600', height='380px')) + bar1 = Bar(init_opts=opts.InitOpts(width="600", height="380px")) bar1.add_xaxis(attr) bar1.add_yaxis("", value) # 工单按组统计 data = chart_dao.workflow_by_group(30) - attr = [row[0] for row in data['rows']] - value = [row[1] for row in data['rows']] - pie1 = Pie(init_opts=opts.InitOpts(width='600', height='380px')) - pie1.set_global_opts(title_opts=opts.TitleOpts(title=''), - legend_opts=opts.LegendOpts( - orient="vertical", pos_top="15%", pos_left="2%", is_show=False - )) + attr = [row[0] for row in data["rows"]] + value = [row[1] for row in data["rows"]] + pie1 = Pie(init_opts=opts.InitOpts(width="600", height="380px")) + pie1.set_global_opts( + title_opts=opts.TitleOpts(title=""), + legend_opts=opts.LegendOpts( + orient="vertical", pos_top="15%", pos_left="2%", is_show=False + ), + ) pie1.set_series_opts(label_opts=opts.LabelOpts(formatter="{b}: {c}")) pie1.add("", [list(z) for z in zip(attr, value)]) if attr and data else None # 工单按人统计 data = chart_dao.workflow_by_user(30) - attr = [row[0] for row in data['rows']] - value = [row[1] for row in data['rows']] - bar2 = Bar(init_opts=opts.InitOpts(width='600', height='380px')) + attr = [row[0] for row in data["rows"]] + value = [row[1] for row in data["rows"]] + bar2 = Bar(init_opts=opts.InitOpts(width="600", height="380px")) bar2.add_xaxis(attr) bar2.add_yaxis("", value) # SQL语句类型统计 data = chart_dao.syntax_type() - attr = [row[0] for row in data['rows']] - value = [row[1] for row in data['rows']] + attr = [row[0] for row in data["rows"]] + value = [row[1] for row in data["rows"]] pie2 = Pie() - pie2.set_global_opts(title_opts=opts.TitleOpts(title='SQL上线工单统计(类型)'), - legend_opts=opts.LegendOpts( - orient="vertical", pos_top="15%", pos_left="2%" - )) + pie2.set_global_opts( + title_opts=opts.TitleOpts(title="SQL上线工单统计(类型)"), + legend_opts=opts.LegendOpts(orient="vertical", pos_top="15%", pos_left="2%"), + ) pie2.set_series_opts(label_opts=opts.LabelOpts(formatter="{b}: {c}")) pie2.add("", [list(z) for z in zip(attr, value)]) if attr and data else None @@ -67,65 +69,86 @@ def pyecharts(request): attr = chart_dao.get_date_list(one_month_before, today) effect_data = chart_dao.querylog_effect_row_by_date(30) effect_dict = {} - for row in effect_data['rows']: + for row in effect_data["rows"]: effect_dict[row[0]] = int(row[1]) effect_value = [effect_dict.get(day) if effect_dict.get(day) else 0 for day in attr] count_data = chart_dao.querylog_count_by_date(30) count_dict = {} - for row in count_data['rows']: + for row in count_data["rows"]: count_dict[row[0]] = int(row[1]) count_value = [count_dict.get(day) if count_dict.get(day) else 0 for day in attr] - line1 = Line(init_opts=opts.InitOpts(width='600', height='380px')) - line1.set_global_opts(title_opts=opts.TitleOpts(title=''), - legend_opts=opts.LegendOpts(selected_mode='single')) + line1 = Line(init_opts=opts.InitOpts(width="600", height="380px")) + line1.set_global_opts( + title_opts=opts.TitleOpts(title=""), + legend_opts=opts.LegendOpts(selected_mode="single"), + ) line1.add_xaxis(attr) - line1.add_yaxis("检索行数", effect_value, is_smooth=True, - markpoint_opts=opts.MarkPointOpts(data=[opts.MarkPointItem(type_="average")])) - line1.add_yaxis("检索次数", count_value, is_smooth=True, - markline_opts=opts.MarkLineOpts(data=[opts.MarkLineItem(type_="max"), - opts.MarkLineItem(type_="average")])) + line1.add_yaxis( + "检索行数", + effect_value, + is_smooth=True, + markpoint_opts=opts.MarkPointOpts(data=[opts.MarkPointItem(type_="average")]), + ) + line1.add_yaxis( + "检索次数", + count_value, + is_smooth=True, + markline_opts=opts.MarkLineOpts( + data=[opts.MarkLineItem(type_="max"), opts.MarkLineItem(type_="average")] + ), + ) # SQL查询统计(用户检索行数) data = chart_dao.querylog_effect_row_by_user(30) - attr = [row[0] for row in data['rows']] - value = [int(row[1]) for row in data['rows']] - pie4 = Pie(init_opts=opts.InitOpts(width='600', height='380px')) - pie4.set_global_opts(title_opts=opts.TitleOpts(title=''), - legend_opts=opts.LegendOpts( - orient="vertical", pos_top="15%", pos_left="2%", is_show=False - )) + attr = [row[0] for row in data["rows"]] + value = [int(row[1]) for row in data["rows"]] + pie4 = Pie(init_opts=opts.InitOpts(width="600", height="380px")) + pie4.set_global_opts( + title_opts=opts.TitleOpts(title=""), + legend_opts=opts.LegendOpts( + orient="vertical", pos_top="15%", pos_left="2%", is_show=False + ), + ) pie4.set_series_opts(label_opts=opts.LabelOpts(formatter="{b}: {c}")) pie4.add("", [list(z) for z in zip(attr, value)]) if attr and data else None # SQL查询统计(DB检索行数) data = chart_dao.querylog_effect_row_by_db(30) - attr = [row[0] for row in data['rows']] - value = [int(row[1]) for row in data['rows']] - pie5 = Pie(init_opts=opts.InitOpts(width='600', height='380px')) - pie5.set_global_opts(title_opts=opts.TitleOpts(title=''), - legend_opts=opts.LegendOpts( - orient="vertical", pos_top="15%", pos_left="2%", is_show=False - )) - pie5.set_series_opts(label_opts=opts.LabelOpts(formatter="{b}: {c}", position="left")) + attr = [row[0] for row in data["rows"]] + value = [int(row[1]) for row in data["rows"]] + pie5 = Pie(init_opts=opts.InitOpts(width="600", height="380px")) + pie5.set_global_opts( + title_opts=opts.TitleOpts(title=""), + legend_opts=opts.LegendOpts( + orient="vertical", pos_top="15%", pos_left="2%", is_show=False + ), + ) + pie5.set_series_opts( + label_opts=opts.LabelOpts(formatter="{b}: {c}", position="left") + ) pie5.add("", [list(z) for z in zip(attr, value)]) if attr and data else None # 慢查询db/user维度统计(最近1天) data = chart_dao.slow_query_count_by_db_by_user(1) - attr = [row[0] for row in data['rows']] - value = [int(row[1]) for row in data['rows']] - pie3 = Pie(init_opts=opts.InitOpts(width='600',height='380px')) - pie3.set_global_opts(title_opts=opts.TitleOpts(title=''), - legend_opts=opts.LegendOpts( - orient="vertical", pos_top="15%", pos_left="2%", is_show=False - )) - pie3.set_series_opts(label_opts=opts.LabelOpts(formatter="{b}: {c}", position="left")) + attr = [row[0] for row in data["rows"]] + value = [int(row[1]) for row in data["rows"]] + pie3 = Pie(init_opts=opts.InitOpts(width="600", height="380px")) + pie3.set_global_opts( + title_opts=opts.TitleOpts(title=""), + legend_opts=opts.LegendOpts( + orient="vertical", pos_top="15%", pos_left="2%", is_show=False + ), + ) + pie3.set_series_opts( + label_opts=opts.LabelOpts(formatter="{b}: {c}", position="left") + ) pie3.add("", [list(z) for z in zip(attr, value)]) if attr and data else None # 慢查询db维度统计(最近1天) data = chart_dao.slow_query_count_by_db(1) - attr = [row[0] for row in data['rows']] - value = [row[1] for row in data['rows']] - bar3 = Bar(init_opts=opts.InitOpts(width='600', height='380px')) + attr = [row[0] for row in data["rows"]] + value = [row[1] for row in data["rows"]] + bar3 = Bar(init_opts=opts.InitOpts(width="600", height="380px")) bar3.add_xaxis(attr) bar3.add_yaxis("", value) @@ -147,7 +170,11 @@ def pyecharts(request): "sql_wf_cnt": SqlWorkflow.objects.count(), "query_wf_cnt": QueryPrivilegesApply.objects.count(), "user_cnt": Users.objects.count(), - "ins_cnt": Instance.objects.count() + "ins_cnt": Instance.objects.count(), } - return render(request, "dashboard.html", {"chart": chart, "count_stats": dashboard_count_stats}) + return render( + request, + "dashboard.html", + {"chart": chart, "count_stats": dashboard_count_stats}, + ) diff --git a/common/middleware/check_login_middleware.py b/common/middleware/check_login_middleware.py index 16973b6385..59ca7876f4 100644 --- a/common/middleware/check_login_middleware.py +++ b/common/middleware/check_login_middleware.py @@ -3,15 +3,9 @@ from django.http import HttpResponseRedirect from django.utils.deprecation import MiddlewareMixin -IGNORE_URL = [ - '/login/', - '/login/2fa/', - '/authenticate/', - '/signup/', - '/api/info' -] +IGNORE_URL = ["/login/", "/login/2fa/", "/authenticate/", "/signup/", "/api/info"] -IGNORE_URL_RE = r'/api/(v1|auth)/\w+' +IGNORE_URL_RE = r"/api/(v1|auth)/\w+" class CheckLoginMiddleware(MiddlewareMixin): @@ -22,6 +16,12 @@ def process_request(request): """ if not request.user.is_authenticated: # 以下是不用跳转到login页面的url白名单 - if request.path not in IGNORE_URL and re.match(IGNORE_URL_RE, request.path) is None \ - and not (re.match(r'/user/qrcode/\w+', request.path) and request.session.get('user')): - return HttpResponseRedirect('/login/') + if ( + request.path not in IGNORE_URL + and re.match(IGNORE_URL_RE, request.path) is None + and not ( + re.match(r"/user/qrcode/\w+", request.path) + and request.session.get("user") + ) + ): + return HttpResponseRedirect("/login/") diff --git a/common/middleware/exception_logging_middleware.py b/common/middleware/exception_logging_middleware.py index 835993c5e6..cfcfa46f27 100644 --- a/common/middleware/exception_logging_middleware.py +++ b/common/middleware/exception_logging_middleware.py @@ -2,10 +2,11 @@ import logging from django.utils.deprecation import MiddlewareMixin -logger = logging.getLogger('default') +logger = logging.getLogger("default") class ExceptionLoggingMiddleware(MiddlewareMixin): def process_exception(self, request, exception): import traceback + logger.error(traceback.format_exc()) diff --git a/common/storage.py b/common/storage.py index 60097961a0..d4aeeb3500 100644 --- a/common/storage.py +++ b/common/storage.py @@ -6,7 +6,7 @@ @time: 2019/06/01 """ -__author__ = 'hhyo' +__author__ = "hhyo" from django.contrib.staticfiles.storage import ManifestStaticFilesStorage diff --git a/common/tests.py b/common/tests.py index f32eb55714..da30292cd6 100644 --- a/common/tests.py +++ b/common/tests.py @@ -8,7 +8,13 @@ from common.config import SysConfig from common.utils.sendmsg import MsgSender from sql.engines import EngineBase, ResultSet -from sql.models import Instance, SqlWorkflow, SqlWorkflowContent, QueryLog, ResourceGroup +from sql.models import ( + Instance, + SqlWorkflow, + SqlWorkflowContent, + QueryLog, + ResourceGroup, +) from common.utils.chart_dao import ChartDao from common.auth import init_user @@ -21,7 +27,7 @@ def setUp(self): def test_purge(self): archer_config = SysConfig() - archer_config.set('some_key', 'some_value') + archer_config.set("some_key", "some_value") archer_config.purge() self.assertEqual({}, archer_config.sys_config) archer_config2 = SysConfig() @@ -30,39 +36,42 @@ def test_purge(self): def test_replace_configs(self): archer_config = SysConfig() new_config = json.dumps( - [{'key': 'numconfig', 'value': 1}, - {'key': 'strconfig', 'value': 'strconfig'}, - {'key': 'boolconfig', 'value': 'false'}]) + [ + {"key": "numconfig", "value": 1}, + {"key": "strconfig", "value": "strconfig"}, + {"key": "boolconfig", "value": "false"}, + ] + ) archer_config.replace(new_config) archer_config.get_all_config() expected_config = { - 'numconfig': '1', - 'strconfig': 'strconfig', - 'boolconfig': False + "numconfig": "1", + "strconfig": "strconfig", + "boolconfig": False, } self.assertEqual(archer_config.sys_config, expected_config) def test_get_bool_transform(self): - bool_config = json.dumps([{'key': 'boolconfig2', 'value': 'false'}]) + bool_config = json.dumps([{"key": "boolconfig2", "value": "false"}]) archer_config = SysConfig() archer_config.replace(bool_config) - self.assertEqual(archer_config.sys_config['boolconfig2'], False) + self.assertEqual(archer_config.sys_config["boolconfig2"], False) def test_set_bool_transform(self): archer_config = SysConfig() - archer_config.set('boolconfig3', False) - self.assertEqual(archer_config.sys_config['boolconfig3'], False) + archer_config.set("boolconfig3", False) + self.assertEqual(archer_config.sys_config["boolconfig3"], False) def test_get_other_data(self): - new_config = json.dumps([{'key': 'other_config', 'value': 'testvalue'}]) + new_config = json.dumps([{"key": "other_config", "value": "testvalue"}]) archer_config = SysConfig() archer_config.replace(new_config) - self.assertEqual(archer_config.sys_config['other_config'], 'testvalue') + self.assertEqual(archer_config.sys_config["other_config"], "testvalue") def test_set_other_data(self): archer_config = SysConfig() - archer_config.set('other_config', 'testvalue3') - self.assertEqual(archer_config.sys_config['other_config'], 'testvalue3') + archer_config.set("other_config", "testvalue3") + self.assertEqual(archer_config.sys_config["other_config"], "testvalue3") class SendMessageTest(TestCase): @@ -70,74 +79,74 @@ class SendMessageTest(TestCase): def setUp(self): archer_config = SysConfig() - self.smtp_server = 'test_smtp_server' - self.smtp_user = 'test_smtp_user' - self.smtp_password = 'some_str' + self.smtp_server = "test_smtp_server" + self.smtp_user = "test_smtp_user" + self.smtp_password = "some_str" self.smtp_port = 1234 self.smtp_ssl = True - archer_config.set('mail_smtp_server', self.smtp_server) - archer_config.set('mail_smtp_user', self.smtp_user) - archer_config.set('mail_smtp_password', self.smtp_password) - archer_config.set('mail_smtp_port', self.smtp_port) - archer_config.set('mail_ssl', self.smtp_ssl) + archer_config.set("mail_smtp_server", self.smtp_server) + archer_config.set("mail_smtp_user", self.smtp_user) + archer_config.set("mail_smtp_password", self.smtp_password) + archer_config.set("mail_smtp_port", self.smtp_port) + archer_config.set("mail_ssl", self.smtp_ssl) def testSenderInit(self): sender = MsgSender() self.assertEqual(sender.MAIL_REVIEW_SMTP_PORT, self.smtp_port) archer_config = SysConfig() - archer_config.set('mail_smtp_port', '') + archer_config.set("mail_smtp_port", "") sender = MsgSender() self.assertEqual(sender.MAIL_REVIEW_SMTP_PORT, 465) - archer_config.set('mail_ssl', False) + archer_config.set("mail_ssl", False) sender = MsgSender() self.assertEqual(sender.MAIL_REVIEW_SMTP_PORT, 25) - @patch.object(smtplib.SMTP, '__init__', return_value=None) - @patch.object(smtplib.SMTP, 'login') - @patch.object(smtplib.SMTP, 'sendmail') - @patch.object(smtplib.SMTP, 'quit') + @patch.object(smtplib.SMTP, "__init__", return_value=None) + @patch.object(smtplib.SMTP, "login") + @patch.object(smtplib.SMTP, "sendmail") + @patch.object(smtplib.SMTP, "quit") def testNoPasswordSendMail(self, _quit, sendmail, login, _): """无密码测试""" - some_sub = 'test_subject' - some_body = 'mail_body' - some_to = ['mail_to'] + some_sub = "test_subject" + some_body = "mail_body" + some_to = ["mail_to"] archer_config = SysConfig() - archer_config.set('mail_ssl', '') + archer_config.set("mail_ssl", "") - archer_config.set('mail_smtp_password', '') + archer_config.set("mail_smtp_password", "") sender2 = MsgSender() sender2.send_email(some_sub, some_body, some_to) login.assert_not_called() - @patch.object(smtplib.SMTP, '__init__', return_value=None) - @patch.object(smtplib.SMTP, 'login') - @patch.object(smtplib.SMTP, 'sendmail') - @patch.object(smtplib.SMTP, 'quit') + @patch.object(smtplib.SMTP, "__init__", return_value=None) + @patch.object(smtplib.SMTP, "login") + @patch.object(smtplib.SMTP, "sendmail") + @patch.object(smtplib.SMTP, "quit") def testSendMail(self, _quit, sendmail, login, _): """有密码测试""" - some_sub = 'test_subject' - some_body = 'mail_body' - some_to = ['mail_to'] + some_sub = "test_subject" + some_body = "mail_body" + some_to = ["mail_to"] archer_config = SysConfig() - archer_config.set('mail_ssl', '') - archer_config.set('mail_smtp_password', self.smtp_password) + archer_config.set("mail_ssl", "") + archer_config.set("mail_smtp_password", self.smtp_password) sender = MsgSender() sender.send_email(some_sub, some_body, some_to) login.assert_called_once() sendmail.assert_called_with(self.smtp_user, some_to, ANY) _quit.assert_called_once() - @patch.object(smtplib.SMTP, '__init__', return_value=None) - @patch.object(smtplib.SMTP, 'login') - @patch.object(smtplib.SMTP, 'sendmail') - @patch.object(smtplib.SMTP, 'quit') + @patch.object(smtplib.SMTP, "__init__", return_value=None) + @patch.object(smtplib.SMTP, "login") + @patch.object(smtplib.SMTP, "sendmail") + @patch.object(smtplib.SMTP, "quit") def testSSLSendMail(self, _quit, sendmail, login, _): """SSL 测试""" - some_sub = 'test_subject' - some_body = 'mail_body' - some_to = ['mail_to'] + some_sub = "test_subject" + some_body = "mail_body" + some_to = ["mail_to"] archer_config = SysConfig() - archer_config.set('mail_ssl', True) + archer_config.set("mail_ssl", True) sender = MsgSender() sender.send_email(some_sub, some_body, some_to) sendmail.assert_called_with(self.smtp_user, some_to, ANY) @@ -145,36 +154,33 @@ def testSSLSendMail(self, _quit, sendmail, login, _): def tearDown(self): archer_config = SysConfig() - archer_config.set('mail_smtp_server', '') - archer_config.set('mail_smtp_user', '') - archer_config.set('mail_smtp_password', '') - archer_config.set('mail_smtp_port', '') - archer_config.set('mail_ssl', '') + archer_config.set("mail_smtp_server", "") + archer_config.set("mail_smtp_user", "") + archer_config.set("mail_smtp_password", "") + archer_config.set("mail_smtp_port", "") + archer_config.set("mail_ssl", "") class DingTest(TestCase): - def setUp(self): - self.url = 'some_url' - self.content = 'some_content' + self.url = "some_url" + self.content = "some_content" - @patch('requests.post') + @patch("requests.post") def testDing(self, post): sender = MsgSender() - post.return_value.json.return_value = {'errcode': 0} - with self.assertLogs('default', level='DEBUG') as lg: + post.return_value.json.return_value = {"errcode": 0} + with self.assertLogs("default", level="DEBUG") as lg: sender.send_ding(self.url, self.content) - post.assert_called_once_with(url=self.url, json={ - 'msgtype': 'text', - 'text': { - 'content': self.content - } - }) - self.assertIn('钉钉Webhook推送成功', lg.output[0]) - post.return_value.json.return_value = {'errcode': 1, 'errmsg': 'test_error'} - with self.assertLogs('default', level='ERROR') as lg: + post.assert_called_once_with( + url=self.url, + json={"msgtype": "text", "text": {"content": self.content}}, + ) + self.assertIn("钉钉Webhook推送成功", lg.output[0]) + post.return_value.json.return_value = {"errcode": 1, "errmsg": "test_error"} + with self.assertLogs("default", level="ERROR") as lg: sender.send_ding(self.url, self.content) - self.assertIn('test_error', lg.output[0]) + self.assertIn("test_error", lg.output[0]) def tearDown(self): pass @@ -182,26 +188,26 @@ def tearDown(self): class GlobalInfoTest(TestCase): def setUp(self): - self.u1 = User(username='test_user', display='中文显示', is_active=True) + self.u1 = User(username="test_user", display="中文显示", is_active=True) self.u1.save() - @patch('sql.utils.workflow_audit.Audit.todo') + @patch("sql.utils.workflow_audit.Audit.todo") def testGlobalInfo(self, todo): """测试""" c = Client() - r = c.get('/', follow=True) + r = c.get("/", follow=True) todo.assert_not_called() - self.assertEqual(r.context['todo'], 0) + self.assertEqual(r.context["todo"], 0) # 已登录用户 c.force_login(self.u1) todo.return_value = 3 - r = c.get('/', follow=True) + r = c.get("/", follow=True) todo.assert_called_once_with(self.u1) - self.assertEqual(r.context['todo'], 3) + self.assertEqual(r.context["todo"], 3) # 报异常 - todo.side_effect = NameError('some exception') - r = c.get('/', follow=True) - self.assertEqual(r.context['todo'], 0) + todo.side_effect = NameError("some exception") + r = c.get("/", follow=True) + self.assertEqual(r.context["todo"], 0) def tearDown(self): self.u1.delete() @@ -211,120 +217,153 @@ class CheckTest(TestCase): """检查功能测试""" def setUp(self): - self.superuser1 = User(username='test_user', display='中文显示', is_active=True, is_superuser=True, - email='XXX@xxx.com') + self.superuser1 = User( + username="test_user", + display="中文显示", + is_active=True, + is_superuser=True, + email="XXX@xxx.com", + ) self.superuser1.save() - self.slave1 = Instance(instance_name='some_name', host='some_host', type='slave', db_type='mysql', - user='some_user', port=1234, password='some_str') + self.slave1 = Instance( + instance_name="some_name", + host="some_host", + type="slave", + db_type="mysql", + user="some_user", + port=1234, + password="some_str", + ) self.slave1.save() def tearDown(self): self.superuser1.delete() - @patch.object(MsgSender, '__init__', return_value=None) - @patch.object(MsgSender, 'send_email') + @patch.object(MsgSender, "__init__", return_value=None) + @patch.object(MsgSender, "send_email") def testEmailCheck(self, send_email, mailsender): """邮箱配置检查""" - mail_switch = 'true' - smtp_ssl = 'false' - smtp_server = 'some_server' - smtp_port = '1234' - smtp_user = 'some_user' - smtp_pass = 'some_str' + mail_switch = "true" + smtp_ssl = "false" + smtp_server = "some_server" + smtp_port = "1234" + smtp_user = "some_user" + smtp_pass = "some_str" # 略过superuser校验 # 未开启mail开关 - mail_switch = 'false' + mail_switch = "false" c = Client() c.force_login(self.superuser1) - r = c.post('/check/email/', data={ - 'mail': mail_switch, - 'mail_ssl': smtp_ssl, - 'mail_smtp_server': smtp_server, - 'mail_smtp_port': smtp_port, - 'mail_smtp_user': smtp_user, - 'mail_smtp_password': smtp_pass - }) + r = c.post( + "/check/email/", + data={ + "mail": mail_switch, + "mail_ssl": smtp_ssl, + "mail_smtp_server": smtp_server, + "mail_smtp_port": smtp_port, + "mail_smtp_user": smtp_user, + "mail_smtp_password": smtp_pass, + }, + ) r_json = r.json() - self.assertEqual(r_json['status'], 1) - self.assertEqual(r_json['msg'], '请先开启邮件通知!') - mail_switch = 'true' + self.assertEqual(r_json["status"], 1) + self.assertEqual(r_json["msg"], "请先开启邮件通知!") + mail_switch = "true" # 填写非正整数端口号 - smtp_port = '-3' - r = c.post('/check/email/', data={ - 'mail': mail_switch, - 'mail_ssl': smtp_ssl, - 'mail_smtp_server': smtp_server, - 'mail_smtp_port': smtp_port, - 'mail_smtp_user': smtp_user, - 'mail_smtp_password': smtp_pass - }) + smtp_port = "-3" + r = c.post( + "/check/email/", + data={ + "mail": mail_switch, + "mail_ssl": smtp_ssl, + "mail_smtp_server": smtp_server, + "mail_smtp_port": smtp_port, + "mail_smtp_user": smtp_user, + "mail_smtp_password": smtp_pass, + }, + ) r_json = r.json() - self.assertEqual(r_json['status'], 1) - self.assertEqual(r_json['msg'], '端口号只能为正整数') - smtp_port = '1234' + self.assertEqual(r_json["status"], 1) + self.assertEqual(r_json["msg"], "端口号只能为正整数") + smtp_port = "1234" # 未填写用户邮箱 - self.superuser1.email = '' + self.superuser1.email = "" self.superuser1.save() - r = c.post('/check/email/', data={ - 'mail': mail_switch, - 'mail_ssl': smtp_ssl, - 'mail_smtp_server': smtp_server, - 'mail_smtp_port': smtp_port, - 'mail_smtp_user': smtp_user, - 'mail_smtp_password': smtp_pass - }) + r = c.post( + "/check/email/", + data={ + "mail": mail_switch, + "mail_ssl": smtp_ssl, + "mail_smtp_server": smtp_server, + "mail_smtp_port": smtp_port, + "mail_smtp_user": smtp_user, + "mail_smtp_password": smtp_pass, + }, + ) r_json = r.json() - self.assertEqual(r_json['status'], 1) - self.assertEqual(r_json['msg'], '请先完善当前用户邮箱信息!') - self.superuser1.email = 'XXX@xxx.com' + self.assertEqual(r_json["status"], 1) + self.assertEqual(r_json["msg"], "请先完善当前用户邮箱信息!") + self.superuser1.email = "XXX@xxx.com" self.superuser1.save() # 发送失败, 显示traceback - send_email.return_value = 'some traceback' - r = c.post('/check/email/', data={ - 'mail': mail_switch, - 'mail_ssl': smtp_ssl, - 'mail_smtp_server': smtp_server, - 'mail_smtp_port': smtp_port, - 'mail_smtp_user': smtp_user, - 'mail_smtp_password': smtp_pass - }) + send_email.return_value = "some traceback" + r = c.post( + "/check/email/", + data={ + "mail": mail_switch, + "mail_ssl": smtp_ssl, + "mail_smtp_server": smtp_server, + "mail_smtp_port": smtp_port, + "mail_smtp_user": smtp_user, + "mail_smtp_password": smtp_pass, + }, + ) r_json = r.json() - self.assertEqual(r_json['status'], 1) - self.assertIn('some traceback', r_json['msg']) + self.assertEqual(r_json["status"], 1) + self.assertIn("some traceback", r_json["msg"]) send_email.reset_mock() # 重置``Mock``的调用计数 mailsender.reset_mock() # 发送成功 - send_email.return_value = 'success' - r = c.post('/check/email/', data={ - 'mail': mail_switch, - 'mail_ssl': smtp_ssl, - 'mail_smtp_server': smtp_server, - 'mail_smtp_port': smtp_port, - 'mail_smtp_user': smtp_user, - 'mail_smtp_password': smtp_pass - }) + send_email.return_value = "success" + r = c.post( + "/check/email/", + data={ + "mail": mail_switch, + "mail_ssl": smtp_ssl, + "mail_smtp_server": smtp_server, + "mail_smtp_port": smtp_port, + "mail_smtp_user": smtp_user, + "mail_smtp_password": smtp_pass, + }, + ) r_json = r.json() - mailsender.assert_called_once_with(server=smtp_server, port=int(smtp_port), user=smtp_user, - password=smtp_pass, ssl=False) - send_email.called_once_with('Archery 邮件发送测试', 'Archery 邮件发送测试...', - [self.superuser1.email]) - self.assertEqual(r_json['status'], 0) - self.assertEqual(r_json['msg'], 'ok') - - @patch('MySQLdb.connect') - @patch('common.check.get_engine') + mailsender.assert_called_once_with( + server=smtp_server, + port=int(smtp_port), + user=smtp_user, + password=smtp_pass, + ssl=False, + ) + send_email.called_once_with( + "Archery 邮件发送测试", "Archery 邮件发送测试...", [self.superuser1.email] + ) + self.assertEqual(r_json["status"], 0) + self.assertEqual(r_json["msg"], "ok") + + @patch("MySQLdb.connect") + @patch("common.check.get_engine") def testInstanceCheck(self, _get_engine, _conn): _get_engine.return_value.get_connection = _conn - _get_engine.return_value.get_all_databases.return_value.rows.return_value = ResultSet( - rows=((),), - error='Wrong password') + _get_engine.return_value.get_all_databases.return_value.rows.return_value = ( + ResultSet(rows=((),), error="Wrong password") + ) c = Client() c.force_login(self.superuser1) - r = c.post('/check/instance/', data={'instance_id': self.slave1.id}) + r = c.post("/check/instance/", data={"instance_id": self.slave1.id}) r_json = r.json() - self.assertEqual(r_json['status'], 1) + self.assertEqual(r_json["status"], 1) - @patch('MySQLdb.connect') + @patch("MySQLdb.connect") def test_go_inception_check(self, _conn): c = Client() c.force_login(self.superuser1) @@ -334,11 +373,11 @@ def test_go_inception_check(self, _conn): "inception_remote_backup_host": "mysql", "inception_remote_backup_port": 3306, "inception_remote_backup_user": "mysql", - "inception_remote_backup_password": "123456" + "inception_remote_backup_password": "123456", } - r = c.post('/check/go_inception/', data=data) + r = c.post("/check/go_inception/", data=data) r_json = r.json() - self.assertEqual(r_json['status'], 0) + self.assertEqual(r_json["status"], 0) class ChartTest(TestCase): @@ -346,57 +385,78 @@ class ChartTest(TestCase): @classmethod def setUpClass(cls): - cls.u1 = User(username='some_user', display='用户1') + cls.u1 = User(username="some_user", display="用户1") cls.u1.save() - cls.u2 = User(username='some_other_user', display='用户2') + cls.u2 = User(username="some_other_user", display="用户2") cls.u2.save() - cls.superuser1 = User(username='super1', is_superuser=True) + cls.superuser1 = User(username="super1", is_superuser=True) cls.superuser1.save() cls.now = datetime.datetime.now() - cls.slave1 = Instance(instance_name='test_slave_instance', type='slave', db_type='mysql', - host='testhost', port=3306, user='mysql_user', password='mysql_password') + cls.slave1 = Instance( + instance_name="test_slave_instance", + type="slave", + db_type="mysql", + host="testhost", + port=3306, + user="mysql_user", + password="mysql_password", + ) cls.slave1.save() # 批量创建数据 ddl ,u1 ,g1, yesterday 组, 2 个数据 - ddl_workflow = [SqlWorkflow( - workflow_name='ddl %s' % i, - group_id=1, - group_name='g1', - engineer=cls.u1.username, - engineer_display=cls.u1.display, - audit_auth_groups='some_group', - create_time=cls.now - datetime.timedelta(days=1), - status='workflow_finish', - is_backup=True, - instance=cls.slave1, - db_name='some_db', - syntax_type=1 - ) for i in range(2)] + ddl_workflow = [ + SqlWorkflow( + workflow_name="ddl %s" % i, + group_id=1, + group_name="g1", + engineer=cls.u1.username, + engineer_display=cls.u1.display, + audit_auth_groups="some_group", + create_time=cls.now - datetime.timedelta(days=1), + status="workflow_finish", + is_backup=True, + instance=cls.slave1, + db_name="some_db", + syntax_type=1, + ) + for i in range(2) + ] # 批量创建数据 dml ,u1 ,g2, the day before yesterday 组, 3 个数据 - dml_workflow = [SqlWorkflow( - workflow_name='Test %s' % i, - group_id=2, - group_name='g2', - engineer=cls.u2.username, - engineer_display=cls.u2.display, - audit_auth_groups='some_group', - create_time=cls.now - datetime.timedelta(days=2), - status='workflow_finish', - is_backup=True, - instance=cls.slave1, - db_name='some_db', - syntax_type=2 - ) for i in range(3)] + dml_workflow = [ + SqlWorkflow( + workflow_name="Test %s" % i, + group_id=2, + group_name="g2", + engineer=cls.u2.username, + engineer_display=cls.u2.display, + audit_auth_groups="some_group", + create_time=cls.now - datetime.timedelta(days=2), + status="workflow_finish", + is_backup=True, + instance=cls.slave1, + db_name="some_db", + syntax_type=2, + ) + for i in range(3) + ] SqlWorkflow.objects.bulk_create(ddl_workflow + dml_workflow) # 保存内容数据 - ddl_workflow_content = [SqlWorkflowContent( - workflow=SqlWorkflow.objects.get(workflow_name='ddl %s' % i), - sql_content='some_sql', - ) for i in range(2)] - dml_workflow_content = [SqlWorkflowContent( - workflow=SqlWorkflow.objects.get(workflow_name='Test %s' % i), - sql_content='some_sql', - ) for i in range(3)] - SqlWorkflowContent.objects.bulk_create(ddl_workflow_content + dml_workflow_content) + ddl_workflow_content = [ + SqlWorkflowContent( + workflow=SqlWorkflow.objects.get(workflow_name="ddl %s" % i), + sql_content="some_sql", + ) + for i in range(2) + ] + dml_workflow_content = [ + SqlWorkflowContent( + workflow=SqlWorkflow.objects.get(workflow_name="Test %s" % i), + sql_content="some_sql", + ) + for i in range(3) + ] + SqlWorkflowContent.objects.bulk_create( + ddl_workflow_content + dml_workflow_content + ) # query_logs = [QueryLog( # instance_name = 'some_instance', @@ -419,35 +479,35 @@ def testGetDateList(self): begin = end - datetime.timedelta(days=3) result = dao.get_date_list(begin, end) self.assertEqual(len(result), 4) - self.assertEqual(result[0], begin.strftime('%Y-%m-%d')) - self.assertEqual(result[-1], end.strftime('%Y-%m-%d')) + self.assertEqual(result[0], begin.strftime("%Y-%m-%d")) + self.assertEqual(result[-1], end.strftime("%Y-%m-%d")) def testSyntaxList(self): """工单以语法类型分组""" dao = ChartDao() - expected_rows = (('DDL', 2), ('DML', 3)) + expected_rows = (("DDL", 2), ("DML", 3)) result = dao.syntax_type() - self.assertEqual(result['rows'], expected_rows) + self.assertEqual(result["rows"], expected_rows) def testWorkflowByDate(self): """TODO 按日分组工单数量统计测试""" dao = ChartDao() result = dao.workflow_by_date(30) - self.assertEqual(len(result['rows'][0]), 2) + self.assertEqual(len(result["rows"][0]), 2) def testWorkflowByGroup(self): """按组统计测试""" dao = ChartDao() result = dao.workflow_by_group(30) - expected_rows = (('g2', 3), ('g1', 2)) - self.assertEqual(result['rows'], expected_rows) + expected_rows = (("g2", 3), ("g1", 2)) + self.assertEqual(result["rows"], expected_rows) def testWorkflowByUser(self): """按用户统计测试""" dao = ChartDao() result = dao.workflow_by_user(30) expected_rows = ((self.u2.display, 3), (self.u1.display, 2)) - self.assertEqual(result['rows'], expected_rows) + self.assertEqual(result["rows"], expected_rows) def testDashboard(self): """Dashboard测试""" @@ -455,20 +515,19 @@ def testDashboard(self): # TODO 需要具体查看pyecharst有没有被调用, 以及调用的参数 c = Client() c.force_login(self.superuser1) - r = c.get('/dashboard/') + r = c.get("/dashboard/") self.assertEqual(r.status_code, 200) class AuthTest(TestCase): - def setUp(self): - self.username = 'some_user' - self.password = 'some_str' - self.u1 = User(username=self.username, password=self.password, display='用户1') + self.username = "some_user" + self.password = "some_str" + self.u1 = User(username=self.username, password=self.password, display="用户1") self.u1.save() - self.resource_group1 = ResourceGroup.objects.create(group_name='some_group') + self.resource_group1 = ResourceGroup.objects.create(group_name="some_group") sys_config = SysConfig() - sys_config.set('default_resource_group', self.resource_group1.group_name) + sys_config.set("default_resource_group", self.resource_group1.group_name) def tearDown(self): self.u1.delete() diff --git a/common/twofa/__init__.py b/common/twofa/__init__.py index 58b8e75af0..4dfdcdd6f8 100644 --- a/common/twofa/__init__.py +++ b/common/twofa/__init__.py @@ -2,7 +2,6 @@ class TwoFactorAuthBase: - def __init__(self, user=None): self.user = user @@ -14,16 +13,18 @@ def verify(self, otp): def enable(self): """启用""" - result = {'status': 1, 'msg': 'failed'} + result = {"status": 1, "msg": "failed"} return result def disable(self, auth_type): """禁用""" - result = {'status': 0, 'msg': 'ok'} + result = {"status": 0, "msg": "ok"} try: - TwoFactorAuthConfig.objects.get(user=self.user, auth_type=auth_type).delete() + TwoFactorAuthConfig.objects.get( + user=self.user, auth_type=auth_type + ).delete() except TwoFactorAuthConfig.DoesNotExist as e: - result = {'status': 0, 'msg': str(e)} + result = {"status": 0, "msg": str(e)} return result @property diff --git a/common/twofa/sms.py b/common/twofa/sms.py index 63e01cdeac..7e69f05415 100644 --- a/common/twofa/sms.py +++ b/common/twofa/sms.py @@ -9,7 +9,7 @@ import json import time -logger = logging.getLogger('default') +logger = logging.getLogger("default") class SMS(TwoFactorAuthBase): @@ -19,61 +19,65 @@ def __init__(self, user=None): super(SMS, self).__init__(user=user) self.user = user - sms_provider = SysConfig().get('sms_provider', 'disabled') - if sms_provider == 'aliyun': + sms_provider = SysConfig().get("sms_provider", "disabled") + if sms_provider == "aliyun": from common.utils.aliyun_sms import AliyunSMS + self.client = AliyunSMS() - elif sms_provider == 'tencent': + elif sms_provider == "tencent": from common.utils.tencent_sms import TencentSMS + self.client = TencentSMS() else: self.client = None def get_captcha(self, **kwargs): """获取验证码""" - result = {'status': 0, 'msg': 'ok'} - r = get_redis_connection('default') + result = {"status": 0, "msg": "ok"} + r = get_redis_connection("default") data = r.get(f"captcha-{kwargs['phone']}") if data: - captcha = json.loads(data.decode('utf8')) - if int(time.time()) - captcha['update_time'] > 60: + captcha = json.loads(data.decode("utf8")) + if int(time.time()) - captcha["update_time"] > 60: if self.client: result = self.client.send_code(**kwargs) else: - result = {'status': 1, 'msg': '系统未配置短信服务商!'} + result = {"status": 1, "msg": "系统未配置短信服务商!"} else: - result['status'] = 1 - result['msg'] = f"获取验证码太频繁,请于{captcha['update_time'] - int(time.time()) + 60}秒后再试" + result["status"] = 1 + result[ + "msg" + ] = f"获取验证码太频繁,请于{captcha['update_time'] - int(time.time()) + 60}秒后再试" else: if self.client: result = self.client.send_code(**kwargs) else: - result = {'status': 1, 'msg': '系统未配置短信服务商!'} + result = {"status": 1, "msg": "系统未配置短信服务商!"} return result def verify(self, otp, phone=None): """校验验证码""" - result = {'status': 0, 'msg': 'ok'} + result = {"status": 0, "msg": "ok"} if phone: phone = phone else: phone = TwoFactorAuthConfig.objects.get(username=self.user.username).phone - r = get_redis_connection('default') - data = r.get(f'captcha-{phone}') + r = get_redis_connection("default") + data = r.get(f"captcha-{phone}") if not data: - result['status'] = 1 - result['msg'] = '未获取验证码或验证码已过期!' + result["status"] = 1 + result["msg"] = "未获取验证码或验证码已过期!" else: - captcha = json.loads(data.decode('utf8')) - if otp != captcha['otp']: - result['status'] = 1 - result['msg'] = '验证码不正确!' + captcha = json.loads(data.decode("utf8")) + if otp != captcha["otp"]: + result["status"] = 1 + result["msg"] = "验证码不正确!" return result def save(self, phone): """保存2fa配置""" - result = {'status': 0, 'msg': 'ok'} + result = {"status": 0, "msg": "ok"} try: with transaction.atomic(): @@ -84,11 +88,11 @@ def save(self, phone): username=self.user.username, auth_type=self.auth_type, phone=phone, - user=self.user + user=self.user, ) except Exception as msg: - result['status'] = 1 - result['msg'] = str(msg) + result["status"] = 1 + result["msg"] = str(msg) logger.error(traceback.format_exc()) return result diff --git a/common/twofa/totp.py b/common/twofa/totp.py index 25a4aad2fb..c3a4f88bb1 100644 --- a/common/twofa/totp.py +++ b/common/twofa/totp.py @@ -8,7 +8,7 @@ import logging import pyotp -logger = logging.getLogger('default') +logger = logging.getLogger("default") class TOTP(TwoFactorAuthBase): @@ -20,34 +20,32 @@ def __init__(self, user=None): def verify(self, otp, key=None): """校验一次性验证码""" - result = {'status': 0, 'msg': 'ok'} + result = {"status": 0, "msg": "ok"} if key: secret_key = key else: - secret_key = TwoFactorAuthConfig.objects.get(username=self.user.username, - auth_type=self.auth_type).secret_key + secret_key = TwoFactorAuthConfig.objects.get( + username=self.user.username, auth_type=self.auth_type + ).secret_key t = pyotp.TOTP(secret_key) status = t.verify(otp) - result['status'] = 0 if status else 1 - result['msg'] = 'ok' if status else '验证码不正确!' + result["status"] = 0 if status else 1 + result["msg"] = "ok" if status else "验证码不正确!" return result def generate_key(self): """生成secret key""" - result = {'status': 0, 'msg': 'ok', 'data': {}} + result = {"status": 0, "msg": "ok", "data": {}} # 生成用户secret_key secret_key = pyotp.random_base32(32) - result['data'] = { - 'auth_type': self.auth_type, - 'key': secret_key - } + result["data"] = {"auth_type": self.auth_type, "key": secret_key} return result def save(self, secret_key): """保存2fa配置""" - result = {'status': 0, 'msg': 'ok'} + result = {"status": 0, "msg": "ok"} try: with transaction.atomic(): @@ -58,11 +56,11 @@ def save(self, secret_key): username=self.user.username, auth_type=self.auth_type, secret_key=secret_key, - user=self.user + user=self.user, ) except Exception as msg: - result['status'] = 1 - result['msg'] = str(msg) + result["status"] = 1 + result["msg"] = str(msg) logger.error(traceback.format_exc()) return result @@ -77,13 +75,16 @@ def generate_qrcode(request, data): """生成并返回二维码图片流""" user = request.user - username = user.username if user.is_authenticated else request.session.get('user') + username = user.username if user.is_authenticated else request.session.get("user") secret_key = data # 生成二维码 - qr_data = pyotp.totp.TOTP(secret_key).provisioning_uri(username, issuer_name="Archery") - qrcode = QRCode(version=1, error_correction=constants.ERROR_CORRECT_L, - box_size=6, border=4) + qr_data = pyotp.totp.TOTP(secret_key).provisioning_uri( + username, issuer_name="Archery" + ) + qrcode = QRCode( + version=1, error_correction=constants.ERROR_CORRECT_L, box_size=6, border=4 + ) try: qrcode.add_data(qr_data) qrcode.make(fit=True) diff --git a/common/utils/aes_decryptor.py b/common/utils/aes_decryptor.py index d51aff3579..10f9f6c040 100644 --- a/common/utils/aes_decryptor.py +++ b/common/utils/aes_decryptor.py @@ -2,42 +2,42 @@ from binascii import b2a_hex, a2b_hex -class Prpcrypt(): +class Prpcrypt: def __init__(self): - self.key = 'eCcGFZQj6PNoSSma31LR39rTzTbLkU8E'.encode('utf-8') + self.key = "eCcGFZQj6PNoSSma31LR39rTzTbLkU8E".encode("utf-8") self.mode = AES.MODE_CBC # 加密函数,如果text不足16位就用空格补足为16位, # 如果大于16当时不是16的倍数,那就补足为16的倍数。 def encrypt(self, text): - cryptor = AES.new(self.key, self.mode, b'0000000000000000') + cryptor = AES.new(self.key, self.mode, b"0000000000000000") # 这里密钥key 长度必须为16(AES-128), # 24(AES-192),或者32 (AES-256)Bytes 长度 # 目前AES-128 足够目前使用 length = 16 count = len(text) if count < length: - add = (length - count) + add = length - count # \0 backspace - text = text + ('\0' * add) + text = text + ("\0" * add) elif count > length: - add = (length - (count % length)) - text = text + ('\0' * add) - self.ciphertext = cryptor.encrypt(text.encode('utf-8')) + add = length - (count % length) + text = text + ("\0" * add) + self.ciphertext = cryptor.encrypt(text.encode("utf-8")) # 因为AES加密时候得到的字符串不一定是ascii字符集的,输出到终端或者保存时候可能存在问题 # 所以这里统一把加密后的字符串转化为16进制字符串 - return b2a_hex(self.ciphertext).decode(encoding='utf-8') + return b2a_hex(self.ciphertext).decode(encoding="utf-8") # 解密后,去掉补足的空格用strip() 去掉 def decrypt(self, text): - cryptor = AES.new(self.key, self.mode, b'0000000000000000') + cryptor = AES.new(self.key, self.mode, b"0000000000000000") plain_text = cryptor.decrypt(a2b_hex(text)) - return plain_text.decode().rstrip('\0') + return plain_text.decode().rstrip("\0") -if __name__ == '__main__': +if __name__ == "__main__": pc = Prpcrypt() # 初始化密钥 - e = pc.encrypt('123456') # 加密 + e = pc.encrypt("123456") # 加密 d = pc.decrypt(e) # 解密 print("加密:", str(e)) print("解密:", str(d)) diff --git a/common/utils/aliyun_sdk.py b/common/utils/aliyun_sdk.py index 60d61fc299..1d1edb1051 100644 --- a/common/utils/aliyun_sdk.py +++ b/common/utils/aliyun_sdk.py @@ -3,12 +3,15 @@ import traceback from aliyunsdkcore.client import AcsClient -from aliyunsdkrds.request.v20140815 import DescribeSlowLogsRequest, DescribeSlowLogRecordsRequest, \ - RequestServiceOfCloudDBARequest +from aliyunsdkrds.request.v20140815 import ( + DescribeSlowLogsRequest, + DescribeSlowLogRecordsRequest, + RequestServiceOfCloudDBARequest, +) import simplejson as json import logging -logger = logging.getLogger('default') +logger = logging.getLogger("default") class Aliyun(object): @@ -19,16 +22,21 @@ def __init__(self, rds): secret = rds.ak.raw_key_secret self.clt = AcsClient(ak=ak, secret=secret) except Exception as m: - raise Exception(f'阿里云认证失败:{m}{traceback.format_exc()}') + raise Exception(f"阿里云认证失败:{m}{traceback.format_exc()}") def request_api(self, request, *values): if values: for value in values: for k, v in value.items(): request.add_query_param(k, v) - request.set_accept_format('json') + request.set_accept_format("json") result = self.clt.do_action_with_exception(request) - return json.dumps(json.loads(result.decode('utf-8')), indent=4, sort_keys=False, ensure_ascii=False) + return json.dumps( + json.loads(result.decode("utf-8")), + indent=4, + sort_keys=False, + ensure_ascii=False, + ) # 阿里云UTC时间转换为本地时区时间 @staticmethod @@ -42,8 +50,13 @@ def utc2local(utc, utc_format): def DescribeSlowLogs(self, StartTime, EndTime, **kwargs): """获取实例慢日志列表DBName,SortKey、PageSize、PageNumber""" request = DescribeSlowLogsRequest.DescribeSlowLogsRequest() - values = {"action_name": "DescribeSlowLogs", "DBInstanceId": self.DBInstanceId, - "StartTime": StartTime, "EndTime": EndTime, "SortKey": "TotalExecutionCounts"} + values = { + "action_name": "DescribeSlowLogs", + "DBInstanceId": self.DBInstanceId, + "StartTime": StartTime, + "EndTime": EndTime, + "SortKey": "TotalExecutionCounts", + } values = dict(values, **kwargs) result = self.request_api(request, values) return result @@ -51,14 +64,19 @@ def DescribeSlowLogs(self, StartTime, EndTime, **kwargs): def DescribeSlowLogRecords(self, StartTime, EndTime, **kwargs): """查看慢日志明细SQLId,DBName、PageSize、PageNumber""" request = DescribeSlowLogRecordsRequest.DescribeSlowLogRecordsRequest() - values = {"action_name": "DescribeSlowLogRecords", "DBInstanceId": self.DBInstanceId, - "StartTime": StartTime, "EndTime": EndTime} + values = { + "action_name": "DescribeSlowLogRecords", + "DBInstanceId": self.DBInstanceId, + "StartTime": StartTime, + "EndTime": EndTime, + } values = dict(values, **kwargs) result = self.request_api(request, values) return result - def RequestServiceOfCloudDBA(self, ServiceRequestType, ServiceRequestParam, - **kwargs): + def RequestServiceOfCloudDBA( + self, ServiceRequestType, ServiceRequestParam, **kwargs + ): """ 获取统计信息:'GetTimedMonData',{"Language":"zh","KeyGroup":"mem_cpu_usage","KeyName":"","StartTime":"2018-01-15T04:03:26Z","EndTime":"2018-01-15T05:03:26Z"} mem_cpu_usage、iops_usage、detailed_disk_space @@ -68,8 +86,12 @@ def RequestServiceOfCloudDBA(self, ServiceRequestType, ServiceRequestParam, 获取资源利用信息:'GetResourceUsage',{"Language":"zh"} """ request = RequestServiceOfCloudDBARequest.RequestServiceOfCloudDBARequest() - values = {"action_name": "RequestServiceOfCloudDBA", "DBInstanceId": self.DBInstanceId, - "ServiceRequestType": ServiceRequestType, "ServiceRequestParam": ServiceRequestParam} + values = { + "action_name": "RequestServiceOfCloudDBA", + "DBInstanceId": self.DBInstanceId, + "ServiceRequestType": ServiceRequestType, + "ServiceRequestParam": ServiceRequestParam, + } values = dict(values, **kwargs) result = self.request_api(request, values) return result diff --git a/common/utils/aliyun_sms.py b/common/utils/aliyun_sms.py index 162b039fcf..92458ed8cf 100644 --- a/common/utils/aliyun_sms.py +++ b/common/utils/aliyun_sms.py @@ -7,45 +7,44 @@ import traceback import logging -logger = logging.getLogger('default') +logger = logging.getLogger("default") class AliyunSMS: def __init__(self): all_config = SysConfig() - self.access_key_id = all_config.get('aliyun_access_key_id', '') - self.access_key_secret = all_config.get('aliyun_access_key_secret', '') - self.sign_name = all_config.get('aliyun_sign_name', '') - self.template_code = all_config.get('aliyun_template_code', '') - self.variable_name = all_config.get('aliyun_variable_name', 'code') + self.access_key_id = all_config.get("aliyun_access_key_id", "") + self.access_key_secret = all_config.get("aliyun_access_key_secret", "") + self.sign_name = all_config.get("aliyun_sign_name", "") + self.template_code = all_config.get("aliyun_template_code", "") + self.variable_name = all_config.get("aliyun_variable_name", "code") def create_client(self): config = open_api_models.Config( - access_key_id=self.access_key_id, - access_key_secret=self.access_key_secret + access_key_id=self.access_key_id, access_key_secret=self.access_key_secret ) - config.endpoint = f'dysmsapi.aliyuncs.com' + config.endpoint = f"dysmsapi.aliyuncs.com" return Dysmsapi20170525Client(config) def send_code(self, **kwargs): - result = {'status': 0, 'msg': 'ok'} + result = {"status": 0, "msg": "ok"} client = AliyunSMS.create_client(self) send_sms_request = dysmsapi_20170525_models.SendSmsRequest( - phone_numbers=kwargs['phone'], + phone_numbers=kwargs["phone"], sign_name=self.sign_name, template_code=self.template_code, - template_param=f"{{{self.variable_name}: \'{kwargs['otp']}\'}}" + template_param=f"{{{self.variable_name}: '{kwargs['otp']}'}}", ) runtime = util_models.RuntimeOptions() try: response = client.send_sms_with_options(send_sms_request, runtime) - if response.body.code != 'OK': - result['status'] = 1 - result['msg'] = response.body.message + if response.body.code != "OK": + result["status"] = 1 + result["msg"] = response.body.message except Exception as e: - result['status'] = 1 - result['msg'] = str(e) + result["status"] = 1 + result["msg"] = str(e) logger.error(str(e)) logger.error(traceback.format_exc()) diff --git a/common/utils/chart_dao.py b/common/utils/chart_dao.py index 2e4b250ee0..c4858326e0 100644 --- a/common/utils/chart_dao.py +++ b/common/utils/chart_dao.py @@ -16,10 +16,7 @@ def __query(sql): if fields: for i in fields: column_list.append(i[0]) - return { - 'column_list': column_list, - 'rows': rows - } + return {"column_list": column_list, "rows": rows} # 获取连续时间 @staticmethod @@ -33,7 +30,7 @@ def get_date_list(begin_date, end_date): # 语法类型 def syntax_type(self): - sql = ''' + sql = """ select case when syntax_type = 1 then 'DDL' @@ -43,73 +40,83 @@ def syntax_type(self): end as syntax_type, count(*) from sql_workflow - group by syntax_type;''' + group by syntax_type;""" return self.__query(sql) # 工单数量统计 def workflow_by_date(self, cycle): - sql = ''' + sql = """ select date_format(create_time, '%Y-%m-%d'), count(*) from sql_workflow where create_time >= date_add(now(), interval -{} day) group by date_format(create_time, '%Y-%m-%d') - order by 1 asc;'''.format(cycle) + order by 1 asc;""".format( + cycle + ) return self.__query(sql) # 工单按组统计 def workflow_by_group(self, cycle): - sql = ''' + sql = """ select group_name, count(*) from sql_workflow where create_time >= date_add(now(), interval -{} day ) group by group_id - order by count(*) desc;'''.format(cycle) + order by count(*) desc;""".format( + cycle + ) return self.__query(sql) def workflow_by_user(self, cycle): """工单按人统计""" # TODO select 的对象应该为engineer ID, 查询时应作联合查询查出用户中文名 - sql = ''' + sql = """ select engineer_display, count(*) from sql_workflow where create_time >= date_add(now(), interval -{} day) group by engineer_display - order by count(*) desc;'''.format(cycle) + order by count(*) desc;""".format( + cycle + ) return self.__query(sql) # SQL查询统计(每日检索行数) def querylog_effect_row_by_date(self, cycle): - sql = ''' + sql = """ select date_format(create_time, '%Y-%m-%d'), sum(effect_row) from query_log where create_time >= date_add(now(), interval -{} day ) group by date_format(create_time, '%Y-%m-%d') - order by sum(effect_row) desc;'''.format(cycle) + order by sum(effect_row) desc;""".format( + cycle + ) return self.__query(sql) # SQL查询统计(每日检索次数) def querylog_count_by_date(self, cycle): - sql = ''' + sql = """ select date_format(create_time, '%Y-%m-%d'), count(*) from query_log where create_time >= date_add(now(), interval -{} day ) group by date_format(create_time, '%Y-%m-%d') - order by count(*) desc;'''.format(cycle) + order by count(*) desc;""".format( + cycle + ) return self.__query(sql) # SQL查询统计(用户检索行数) def querylog_effect_row_by_user(self, cycle): - sql = ''' + sql = """ select user_display, sum(effect_row) @@ -117,12 +124,14 @@ def querylog_effect_row_by_user(self, cycle): where create_time >= date_add(now(), interval -{} day) group by user_display order by sum(effect_row) desc - limit 10;'''.format(cycle) + limit 10;""".format( + cycle + ) return self.__query(sql) # SQL查询统计(DB检索行数) def querylog_effect_row_by_db(self, cycle): - sql = ''' + sql = """ select db_name, sum(effect_row) @@ -130,7 +139,9 @@ def querylog_effect_row_by_db(self, cycle): where create_time >= date_add(now(), interval -{} day) group by db_name order by sum(effect_row) desc - limit 10;'''.format(cycle) + limit 10;""".format( + cycle + ) return self.__query(sql) # 慢日志历史趋势图(按次数) @@ -151,24 +162,28 @@ def slow_query_review_history_by_pct_95_time(self, checksum): # 慢日志db/user维度统计 def slow_query_count_by_db_by_user(self, cycle): - sql = ''' + sql = """ select concat(db_max,' user: ' ,user_max), sum(ts_cnt) from mysql_slow_query_review_history where ts_min >= date_sub(now(), interval {} day) group by db_max,user_max order by sum(ts_cnt) desc limit 50; - '''.format(cycle) + """.format( + cycle + ) return self.__query(sql) # 慢日志db维度统计 def slow_query_count_by_db(self, cycle): - sql = ''' + sql = """ select db_max, sum(ts_cnt) from mysql_slow_query_review_history where ts_min >= date_sub(now(), interval {} day) group by db_max order by sum(ts_cnt) desc limit 50; - '''.format(cycle) - return self.__query(sql) \ No newline at end of file + """.format( + cycle + ) + return self.__query(sql) diff --git a/common/utils/const.py b/common/utils/const.py index dff832d080..96d4c363a5 100644 --- a/common/utils/const.py +++ b/common/utils/const.py @@ -1,76 +1,77 @@ # -*- coding: UTF-8 -*- + class Const(object): # 定时任务id的前缀 workflowJobprefix = { - 'query': 'query', - 'sqlreview': 'sqlreview', - 'archive': 'archive' + "query": "query", + "sqlreview": "sqlreview", + "archive": "archive", } class WorkflowDict: # 工作流申请类型,1.query,2.SQL上线申请 workflow_type = { - 'query': 1, - 'query_display': '查询权限申请', - 'sqlreview': 2, - 'sqlreview_display': 'SQL上线申请', - 'archive': 3, - 'archive_display': '数据归档申请', + "query": 1, + "query_display": "查询权限申请", + "sqlreview": 2, + "sqlreview_display": "SQL上线申请", + "archive": 3, + "archive_display": "数据归档申请", } # 工作流状态,0.待审核 1.审核通过 2.审核不通过 3.审核取消 workflow_status = { - 'audit_wait': 0, - 'audit_wait_display': '待审核', - 'audit_success': 1, - 'audit_success_display': '审核通过', - 'audit_reject': 2, - 'audit_reject_display': '审核不通过', - 'audit_abort': 3, - 'audit_abort_display': '审核取消', + "audit_wait": 0, + "audit_wait_display": "待审核", + "audit_success": 1, + "audit_success_display": "审核通过", + "audit_reject": 2, + "audit_reject_display": "审核不通过", + "audit_abort": 3, + "audit_abort_display": "审核取消", } class SQLTuning: SYS_PARM_FILTER = [ - 'BINLOG_CACHE_SIZE', - 'BULK_INSERT_BUFFER_SIZE', - 'HAVE_PARTITION_ENGINE', - 'HAVE_QUERY_CACHE', - 'INTERACTIVE_TIMEOUT', - 'JOIN_BUFFER_SIZE', - 'KEY_BUFFER_SIZE', - 'KEY_CACHE_AGE_THRESHOLD', - 'KEY_CACHE_BLOCK_SIZE', - 'KEY_CACHE_DIVISION_LIMIT', - 'LARGE_PAGES', - 'LOCKED_IN_MEMORY', - 'LONG_QUERY_TIME', - 'MAX_ALLOWED_PACKET', - 'MAX_BINLOG_CACHE_SIZE', - 'MAX_BINLOG_SIZE', - 'MAX_CONNECT_ERRORS', - 'MAX_CONNECTIONS', - 'MAX_JOIN_SIZE', - 'MAX_LENGTH_FOR_SORT_DATA', - 'MAX_SEEKS_FOR_KEY', - 'MAX_SORT_LENGTH', - 'MAX_TMP_TABLES', - 'MAX_USER_CONNECTIONS', - 'OPTIMIZER_PRUNE_LEVEL', - 'OPTIMIZER_SEARCH_DEPTH', - 'QUERY_CACHE_SIZE', - 'QUERY_CACHE_TYPE', - 'QUERY_PREALLOC_SIZE', - 'RANGE_ALLOC_BLOCK_SIZE', - 'READ_BUFFER_SIZE', - 'READ_RND_BUFFER_SIZE', - 'SORT_BUFFER_SIZE', - 'SQL_MODE', - 'TABLE_CACHE', - 'THREAD_CACHE_SIZE', - 'TMP_TABLE_SIZE', - 'WAIT_TIMEOUT' + "BINLOG_CACHE_SIZE", + "BULK_INSERT_BUFFER_SIZE", + "HAVE_PARTITION_ENGINE", + "HAVE_QUERY_CACHE", + "INTERACTIVE_TIMEOUT", + "JOIN_BUFFER_SIZE", + "KEY_BUFFER_SIZE", + "KEY_CACHE_AGE_THRESHOLD", + "KEY_CACHE_BLOCK_SIZE", + "KEY_CACHE_DIVISION_LIMIT", + "LARGE_PAGES", + "LOCKED_IN_MEMORY", + "LONG_QUERY_TIME", + "MAX_ALLOWED_PACKET", + "MAX_BINLOG_CACHE_SIZE", + "MAX_BINLOG_SIZE", + "MAX_CONNECT_ERRORS", + "MAX_CONNECTIONS", + "MAX_JOIN_SIZE", + "MAX_LENGTH_FOR_SORT_DATA", + "MAX_SEEKS_FOR_KEY", + "MAX_SORT_LENGTH", + "MAX_TMP_TABLES", + "MAX_USER_CONNECTIONS", + "OPTIMIZER_PRUNE_LEVEL", + "OPTIMIZER_SEARCH_DEPTH", + "QUERY_CACHE_SIZE", + "QUERY_CACHE_TYPE", + "QUERY_PREALLOC_SIZE", + "RANGE_ALLOC_BLOCK_SIZE", + "READ_BUFFER_SIZE", + "READ_RND_BUFFER_SIZE", + "SORT_BUFFER_SIZE", + "SQL_MODE", + "TABLE_CACHE", + "THREAD_CACHE_SIZE", + "TMP_TABLE_SIZE", + "WAIT_TIMEOUT", ] diff --git a/common/utils/convert.py b/common/utils/convert.py index d464e78c82..9fb4e0eace 100644 --- a/common/utils/convert.py +++ b/common/utils/convert.py @@ -9,9 +9,11 @@ class Convert(Func): """ def __init__(self, expression, transcoding_name, **extra): - super(Convert, self).__init__(expression=expression, transcoding_name=transcoding_name, **extra) + super(Convert, self).__init__( + expression=expression, transcoding_name=transcoding_name, **extra + ) def as_mysql(self, compiler, connection): - self.function = 'CONVERT' - self.template = '%(function)s(%(expression)s USING %(transcoding_name)s)' + self.function = "CONVERT" + self.template = "%(function)s(%(expression)s USING %(transcoding_name)s)" return super(Convert, self).as_sql(compiler, connection) diff --git a/common/utils/ding_api.py b/common/utils/ding_api.py index 58bdef4195..6160204dad 100644 --- a/common/utils/ding_api.py +++ b/common/utils/ding_api.py @@ -10,8 +10,8 @@ from sql.models import Users from sql.utils.tasks import add_sync_ding_user_schedule -logger = logging.getLogger('default') -rs = get_redis_connection('default') +logger = logging.getLogger("default") +rs = get_redis_connection("default") def get_access_token(): @@ -26,13 +26,13 @@ def get_access_token(): return access_token.decode() # 请求钉钉接口获取 sys_config = SysConfig() - app_key = sys_config.get('ding_app_key') - app_secret = sys_config.get('ding_app_secret') + app_key = sys_config.get("ding_app_key") + app_secret = sys_config.get("ding_app_secret") url = f"https://oapi.dingtalk.com/gettoken?appkey={app_key}&appsecret={app_secret}" resp = requests.get(url, timeout=3).json() - if resp.get('errcode') == 0: - access_token = resp.get('access_token') - expires_in = resp.get('expires_in') + if resp.get("errcode") == 0: + access_token = resp.get("access_token") + expires_in = resp.get("expires_in") rs.execute_command(f"SETEX ding_access_token {expires_in-60} {access_token}") return access_token else: @@ -43,12 +43,12 @@ def get_access_token(): def get_ding_user_id(username): """更新用户ding_user_id""" try: - ding_user_id = rs.execute_command('GET {}'.format(username.lower())) + ding_user_id = rs.execute_command("GET {}".format(username.lower())) if ding_user_id: user = Users.objects.get(username=username) if user.ding_user_id != str(ding_user_id, encoding="utf8"): user.ding_user_id = str(ding_user_id, encoding="utf8") - user.save(update_fields=['ding_user_id']) + user.save(update_fields=["ding_user_id"]) except Exception as e: logger.error(f"更新用户ding_user_id失败:{e}") @@ -56,9 +56,13 @@ def get_ding_user_id(username): def get_dept_list_id_fetch_child(token, parent_dept_id): """获取所有子部门列表""" ids = [int(parent_dept_id)] - url = 'https://oapi.dingtalk.com/department/list_ids?id={0}&access_token={1}'.format(parent_dept_id, token) + url = ( + "https://oapi.dingtalk.com/department/list_ids?id={0}&access_token={1}".format( + parent_dept_id, token + ) + ) resp = requests.get(url, timeout=3).json() - if resp.get('errcode') == 0: + if resp.get("errcode") == 0: for dept_id in resp.get("sub_dept_id_list"): ids.extend(get_dept_list_id_fetch_child(token, dept_id)) return list(set(ids)) @@ -70,40 +74,46 @@ def sync_ding_user_id(): 所以可根据钉钉中 jobnumber 查到该用户的 ding_user_id。 """ sys_config = SysConfig() - ding_dept_ids = sys_config.get('ding_dept_ids', '') - username2ding = sys_config.get('ding_archery_username') + ding_dept_ids = sys_config.get("ding_dept_ids", "") + username2ding = sys_config.get("ding_archery_username") token = get_access_token() if not token: return False # 获取全部部门列表 sub_dept_id_list = [] - for dept_id in list(set(ding_dept_ids.split(','))): + for dept_id in list(set(ding_dept_ids.split(","))): sub_dept_id_list.extend(get_dept_list_id_fetch_child(token, dept_id)) # 遍历部门下的用户 user_ids = [] for sdi in sub_dept_id_list: - url = f'https://oapi.dingtalk.com/user/getDeptMember?access_token={token}&deptId={sdi}' + url = f"https://oapi.dingtalk.com/user/getDeptMember?access_token={token}&deptId={sdi}" try: resp = requests.get(url, timeout=3).json() - if resp.get('errcode') == 0: - user_ids.extend(resp.get('userIds')) + if resp.get("errcode") == 0: + user_ids.extend(resp.get("userIds")) else: - raise Exception(f'获取部门用户出错:{resp}') + raise Exception(f"获取部门用户出错:{resp}") except Exception as e: - raise Exception(f'获取部门用户出错:{e}') + raise Exception(f"获取部门用户出错:{e}") # 获取所有用户信息并缓存 for user_id in list(set(user_ids)): - url = f'https://oapi.dingtalk.com/user/get?access_token={token}&userid={user_id}' + url = ( + f"https://oapi.dingtalk.com/user/get?access_token={token}&userid={user_id}" + ) try: resp = requests.get(url, timeout=3).json() - if resp.get('errcode') == 0: + if resp.get("errcode") == 0: if not resp.get(username2ding): - raise Exception(f'钉钉用户信息不包含{username2ding}字段,无法获取id信息,请确认ding_archery_username配置{resp}') - rs.execute_command(f"SETEX {resp.get(username2ding).lower()} 86400 {resp.get('userid')}") + raise Exception( + f"钉钉用户信息不包含{username2ding}字段,无法获取id信息,请确认ding_archery_username配置{resp}" + ) + rs.execute_command( + f"SETEX {resp.get(username2ding).lower()} 86400 {resp.get('userid')}" + ) else: - raise Exception(f'获取用户信息出错:{resp}') + raise Exception(f"获取用户信息出错:{resp}") except Exception as e: - raise Exception(f'获取用户信息出错:{e}') + raise Exception(f"获取用户信息出错:{e}") return True diff --git a/common/utils/extend_json_encoder.py b/common/utils/extend_json_encoder.py index 8f8837ad1c..3232ccd393 100644 --- a/common/utils/extend_json_encoder.py +++ b/common/utils/extend_json_encoder.py @@ -13,17 +13,17 @@ @singledispatch def convert(o): - raise TypeError('can not convert type') + raise TypeError("can not convert type") @convert.register(datetime) def _(o): - return o.strftime('%Y-%m-%d %H:%M:%S') + return o.strftime("%Y-%m-%d %H:%M:%S") @convert.register(date) def _(o): - return o.strftime('%Y-%m-%d') + return o.strftime("%Y-%m-%d") @convert.register(timedelta) @@ -80,11 +80,10 @@ def default(self, obj): class ExtendJSONEncoderFTime(json.JSONEncoder): - def default(self, obj): try: if isinstance(obj, datetime): - return obj.isoformat(' ') + return obj.isoformat(" ") else: return convert(obj) except TypeError: @@ -93,19 +92,20 @@ def default(self, obj): # 使用simplejson处理形如 b'\xaa' 的bytes类型数据会失败,但使用json模块构造这个对象时不能使用bigint_as_string方法 import json + + class ExtendJSONEncoderBytes(json.JSONEncoder): - def default(self, obj): + def default(self, obj): try: # 使用convert.register处理会报错 ValueError: Circular reference detected # 不是utf-8格式的bytes格式需要先进行base64编码转换 if isinstance(obj, bytes): try: - return o.decode('utf-8') + return o.decode("utf-8") except: - return base64.b64encode(obj).decode('utf-8') + return base64.b64encode(obj).decode("utf-8") else: return convert(obj) except TypeError: print(type(obj)) return super(ExtendJSONEncoderBytes, self).default(obj) - diff --git a/common/utils/feishu_api.py b/common/utils/feishu_api.py index 27e977070d..b765d8e696 100644 --- a/common/utils/feishu_api.py +++ b/common/utils/feishu_api.py @@ -5,13 +5,13 @@ from common.config import SysConfig from django.core.cache import cache -logger = logging.getLogger('default') +logger = logging.getLogger("default") def get_feishu_access_token(): # 优先获取缓存 try: - token = cache.get('feishu_access_token') + token = cache.get("feishu_access_token") except Exception as e: logger.error(f"获取飞书token缓存出错:{e}") token = None @@ -19,20 +19,17 @@ def get_feishu_access_token(): return token # 请求飞书接口获取 sys_config = SysConfig() - app_id = sys_config.get('feishu_appid') - app_secret = sys_config.get('feishu_app_secret') + app_id = sys_config.get("feishu_appid") + app_secret = sys_config.get("feishu_app_secret") url = f"https://open.feishu.cn/open-apis/auth/v3/app_access_token/internal/" - data = { - "app_id": app_id, - "app_secret": app_secret - } + data = {"app_id": app_id, "app_secret": app_secret} resp = requests.post(url, json=data, timeout=5).json() # resp = requests.get(url, timeout=3).json() logger.info(f"获取飞书access_token信息成功:{resp}") - if resp.get('code') == 0: - access_token = resp.get('app_access_token') - expires_in = resp.get('expire') - cache.set('feishu_access_token', access_token, timeout=expires_in - 60) + if resp.get("code") == 0: + access_token = resp.get("app_access_token") + expires_in = resp.get("expire") + cache.set("feishu_access_token", access_token, timeout=expires_in - 60) return access_token else: logger.error(f"获取飞书access_token出错:{resp}") @@ -40,11 +37,15 @@ def get_feishu_access_token(): def get_feishu_open_id(email): - url = 'https://open.feishu.cn/open-apis/user/v1/batch_get_id?' - resp = requests.get(url, timeout=3, headers={'Authorization': "Bearer " + get_feishu_access_token()}, - params={"emails": email}).json() + url = "https://open.feishu.cn/open-apis/user/v1/batch_get_id?" + resp = requests.get( + url, + timeout=3, + headers={"Authorization": "Bearer " + get_feishu_access_token()}, + params={"emails": email}, + ).json() result = [] - if resp.get('code') == 0: - for key in resp.get('data').get('email_users').values(): + if resp.get("code") == 0: + for key in resp.get("data").get("email_users").values(): result.append(key[0][f"open_id"]) return result diff --git a/common/utils/global_info.py b/common/utils/global_info.py index 1994f1af75..5dc400a6c8 100644 --- a/common/utils/global_info.py +++ b/common/utils/global_info.py @@ -8,7 +8,7 @@ def global_info(request): """存放用户,菜单信息等.""" user = request.user - twofa_type = 'disabled' + twofa_type = "disabled" if user and user.is_authenticated: # 获取待办数量 try: @@ -20,17 +20,15 @@ def global_info(request): if twofa_config: twofa_type = twofa_config[0].auth_type else: - twofa_type = 'disabled' + twofa_type = "disabled" else: todo = 0 - watermark_enabled = SysConfig().get('watermark_enabled', False) - - + watermark_enabled = SysConfig().get("watermark_enabled", False) return { - 'todo': todo, - 'archery_version': display_version, - 'watermark_enabled': watermark_enabled, - 'twofa_type': twofa_type + "todo": todo, + "archery_version": display_version, + "watermark_enabled": watermark_enabled, + "twofa_type": twofa_type, } diff --git a/common/utils/permission.py b/common/utils/permission.py index cc334be1ab..0b8c10228b 100644 --- a/common/utils/permission.py +++ b/common/utils/permission.py @@ -12,10 +12,10 @@ def wrapper(request, *args, **kw): if user.is_superuser is False: if request.is_ajax(): - result = {'status': 1, 'msg': '您无权操作,请联系管理员', 'data': []} - return HttpResponse(json.dumps(result), content_type='application/json') + result = {"status": 1, "msg": "您无权操作,请联系管理员", "data": []} + return HttpResponse(json.dumps(result), content_type="application/json") else: - context = {'errMsg': "您无权操作,请联系管理员"} + context = {"errMsg": "您无权操作,请联系管理员"} return render(request, "error.html", context) return func(request, *args, **kw) @@ -31,10 +31,12 @@ def wrapper(request, *args, **kw): user = request.user if user.role not in roles and user.is_superuser is False: if request.is_ajax(): - result = {'status': 1, 'msg': '您无权操作,请联系管理员', 'data': []} - return HttpResponse(json.dumps(result), content_type='application/json') + result = {"status": 1, "msg": "您无权操作,请联系管理员", "data": []} + return HttpResponse( + json.dumps(result), content_type="application/json" + ) else: - context = {'errMsg': "您无权操作,请联系管理员"} + context = {"errMsg": "您无权操作,请联系管理员"} return render(request, "error.html", context) return func(request, *args, **kw) diff --git a/common/utils/sendmsg.py b/common/utils/sendmsg.py index 4da8b3ebef..db551a2069 100755 --- a/common/utils/sendmsg.py +++ b/common/utils/sendmsg.py @@ -15,33 +15,32 @@ from common.utils.wx_api import get_wx_access_token from common.utils.feishu_api import * -logger = logging.getLogger('default') +logger = logging.getLogger("default") class MsgSender(object): - def __init__(self, **kwargs): if kwargs: - self.MAIL_REVIEW_SMTP_SERVER = kwargs.get('server') - self.MAIL_REVIEW_SMTP_PORT = kwargs.get('port', 0) - self.MAIL_REVIEW_FROM_ADDR = kwargs.get('user') - self.MAIL_REVIEW_FROM_PASSWORD = kwargs.get('password') - self.MAIL_SSL = kwargs.get('ssl') + self.MAIL_REVIEW_SMTP_SERVER = kwargs.get("server") + self.MAIL_REVIEW_SMTP_PORT = kwargs.get("port", 0) + self.MAIL_REVIEW_FROM_ADDR = kwargs.get("user") + self.MAIL_REVIEW_FROM_PASSWORD = kwargs.get("password") + self.MAIL_SSL = kwargs.get("ssl") else: sys_config = SysConfig() # email信息 - self.MAIL_REVIEW_SMTP_SERVER = sys_config.get('mail_smtp_server') - self.MAIL_REVIEW_SMTP_PORT = sys_config.get('mail_smtp_port', 0) - self.MAIL_SSL = sys_config.get('mail_ssl') - self.MAIL_REVIEW_FROM_ADDR = sys_config.get('mail_smtp_user') - self.MAIL_REVIEW_FROM_PASSWORD = sys_config.get('mail_smtp_password') + self.MAIL_REVIEW_SMTP_SERVER = sys_config.get("mail_smtp_server") + self.MAIL_REVIEW_SMTP_PORT = sys_config.get("mail_smtp_port", 0) + self.MAIL_SSL = sys_config.get("mail_ssl") + self.MAIL_REVIEW_FROM_ADDR = sys_config.get("mail_smtp_user") + self.MAIL_REVIEW_FROM_PASSWORD = sys_config.get("mail_smtp_password") # 钉钉信息 - self.ding_agent_id = sys_config.get('ding_agent_id') + self.ding_agent_id = sys_config.get("ding_agent_id") # 企业微信信息 - self.wx_agent_id = sys_config.get('wx_agent_id') + self.wx_agent_id = sys_config.get("wx_agent_id") # 飞书信息 - self.feishu_appid = sys_config.get('feishu_appid') - self.feishu_app_secret = sys_config.get('feishu_app_secret') + self.feishu_appid = sys_config.get("feishu_appid") + self.feishu_app_secret = sys_config.get("feishu_app_secret") if self.MAIL_REVIEW_SMTP_PORT: self.MAIL_REVIEW_SMTP_PORT = int(self.MAIL_REVIEW_SMTP_PORT) @@ -57,11 +56,14 @@ def _add_attachment(filename): :param filename: :return: """ - file_msg = email.mime.base.MIMEBase('application', 'octet-stream') - file_msg.set_payload(open(filename, 'rb').read()) + file_msg = email.mime.base.MIMEBase("application", "octet-stream") + file_msg.set_payload(open(filename, "rb").read()) # 附件如果有中文会出现乱码问题,加入gbk - file_msg.add_header('Content-Disposition', 'attachment', filename=('gbk', '', - filename.split('/')[-1])) + file_msg.add_header( + "Content-Disposition", + "attachment", + filename=("gbk", "", filename.split("/")[-1]), + ) encoders.encode_base64(file_msg) return file_msg @@ -79,49 +81,55 @@ def send_email(self, subject, body, to, **kwargs): try: if not to: - logger.warning('收件人为空,无法发送邮件') + logger.warning("收件人为空,无法发送邮件") return if not isinstance(to, list): - raise TypeError('收件人需要为列表') - list_cc = kwargs.get('list_cc_addr', []) + raise TypeError("收件人需要为列表") + list_cc = kwargs.get("list_cc_addr", []) if not isinstance(list_cc, list): - raise TypeError('抄送人需要为列表') + raise TypeError("抄送人需要为列表") # 构造MIMEMultipart对象做为根容器 main_msg = email.mime.multipart.MIMEMultipart() # 添加文本内容 - text_msg = email.mime.text.MIMEText(body, 'plain', 'utf-8') + text_msg = email.mime.text.MIMEText(body, "plain", "utf-8") main_msg.attach(text_msg) # 添加附件 - filename_list = kwargs.get('filename_list') + filename_list = kwargs.get("filename_list") if filename_list: - for filename in kwargs['filename_list']: + for filename in kwargs["filename_list"]: file_msg = self._add_attachment(filename) main_msg.attach(file_msg) # 消息内容: - main_msg['Subject'] = Header(subject, "utf-8").encode() - main_msg['From'] = formataddr(["Archery 通知", self.MAIL_REVIEW_FROM_ADDR]) - main_msg['To'] = ','.join(list(set(to))) - main_msg['Cc'] = ', '.join(str(cc) for cc in list(set(list_cc))) - main_msg['Date'] = email.utils.formatdate() + main_msg["Subject"] = Header(subject, "utf-8").encode() + main_msg["From"] = formataddr(["Archery 通知", self.MAIL_REVIEW_FROM_ADDR]) + main_msg["To"] = ",".join(list(set(to))) + main_msg["Cc"] = ", ".join(str(cc) for cc in list(set(list_cc))) + main_msg["Date"] = email.utils.formatdate() if self.MAIL_SSL: - server = smtplib.SMTP_SSL(self.MAIL_REVIEW_SMTP_SERVER, self.MAIL_REVIEW_SMTP_PORT, timeout=3) + server = smtplib.SMTP_SSL( + self.MAIL_REVIEW_SMTP_SERVER, self.MAIL_REVIEW_SMTP_PORT, timeout=3 + ) else: - server = smtplib.SMTP(self.MAIL_REVIEW_SMTP_SERVER, self.MAIL_REVIEW_SMTP_PORT, timeout=3) + server = smtplib.SMTP( + self.MAIL_REVIEW_SMTP_SERVER, self.MAIL_REVIEW_SMTP_PORT, timeout=3 + ) # 如果提供的密码为空,则不需要登录 if self.MAIL_REVIEW_FROM_PASSWORD: server.login(self.MAIL_REVIEW_FROM_ADDR, self.MAIL_REVIEW_FROM_PASSWORD) - server.sendmail(self.MAIL_REVIEW_FROM_ADDR, to + list_cc, main_msg.as_string()) + server.sendmail( + self.MAIL_REVIEW_FROM_ADDR, to + list_cc, main_msg.as_string() + ) server.quit() - logger.debug(f'邮件推送成功\n消息标题:{subject}\n通知对象:{to + list_cc}\n消息内容:{body}') - return 'success' + logger.debug(f"邮件推送成功\n消息标题:{subject}\n通知对象:{to + list_cc}\n消息内容:{body}") + return "success" except Exception: - errmsg = '邮件推送失败\n{}'.format(traceback.format_exc()) + errmsg = "邮件推送失败\n{}".format(traceback.format_exc()) logger.error(errmsg) return errmsg @@ -135,14 +143,12 @@ def send_ding(url, content): """ data = { "msgtype": "text", - "text": { - "content": "{}".format(content) - }, + "text": {"content": "{}".format(content)}, } r = requests.post(url=url, json=data) r_json = r.json() - if r_json['errcode'] == 0: - logger.debug(f'钉钉Webhook推送成功\n通知对象:{url}\n消息内容:{content}') + if r_json["errcode"] == 0: + logger.debug(f"钉钉Webhook推送成功\n通知对象:{url}\n消息内容:{content}") else: logger.error(f"钉钉Webhook推送失败错误码\n请求url:{url}\n请求data:{data}\n请求响应:{r_json}") @@ -156,89 +162,83 @@ def send_ding2user(self, userid_list, content): access_token = get_access_token() send_url = f"https://oapi.dingtalk.com/topapi/message/corpconversation/asyncsend_v2?access_token={access_token}" data = { - "userid_list": ','.join(list(set(userid_list))), + "userid_list": ",".join(list(set(userid_list))), "agent_id": self.ding_agent_id, "msg": {"msgtype": "text", "text": {"content": f"{content}"}}, } r = requests.post(url=send_url, json=data, timeout=5) r_json = r.json() - if r_json['errcode'] == 0: - logger.debug(f'钉钉推送成功\n通知对象:{userid_list}\n消息内容:{content}') + if r_json["errcode"] == 0: + logger.debug(f"钉钉推送成功\n通知对象:{userid_list}\n消息内容:{content}") else: - logger.error(f'钉钉推送失败\n请求连接:{send_url}\n请求参数:{data}\n请求响应:{r_json}') + logger.error(f"钉钉推送失败\n请求连接:{send_url}\n请求参数:{data}\n请求响应:{r_json}") def send_wx2user(self, msg, user_list): if not user_list: - logger.error(f'企业微信推送失败,无法获取到推送的用户.') + logger.error(f"企业微信推送失败,无法获取到推送的用户.") return - to_user = '|'.join(list(set(user_list))) + to_user = "|".join(list(set(user_list))) access_token = get_wx_access_token() - send_url = f'https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token={access_token}' + send_url = f"https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token={access_token}" data = { "touser": to_user, "msgtype": "text", "agentid": self.wx_agent_id, - "text": { - "content": msg - }, + "text": {"content": msg}, } res = requests.post(url=send_url, json=data, timeout=5) r_json = res.json() - if r_json['errcode'] == 0: - logger.debug(f'企业微信推送成功\n通知对象:{to_user}') + if r_json["errcode"] == 0: + logger.debug(f"企业微信推送成功\n通知对象:{to_user}") else: - logger.error(f'企业微信推送失败\n请求连接:{send_url}\n请求参数:{data}\n请求响应:{r_json}') + logger.error(f"企业微信推送失败\n请求连接:{send_url}\n请求参数:{data}\n请求响应:{r_json}") - def send_qywx_webhook(self,qywx_webhook, msg): + def send_qywx_webhook(self, qywx_webhook, msg): send_url = qywx_webhook # 对链接进行转换 - _msg = re.findall('https://.+(?=\n)|http://.+(?=\n)', msg) + _msg = re.findall("https://.+(?=\n)|http://.+(?=\n)", msg) for url in _msg: # 防止如 [xxx](http://www.a.com)\n 的字符串被再次替换 if url.strip()[-1] != ")": - msg=msg.replace(url,'[请点击链接](%s)' % url) + msg = msg.replace(url, "[请点击链接](%s)" % url) data = { "msgtype": "markdown", - "markdown": { - "content": msg - }, + "markdown": {"content": msg}, } res = requests.post(url=send_url, json=data, timeout=5) r_json = res.json() - if r_json['errcode'] == 0: - logger.debug(f'企业微信机器人推送成功\n通知对象:机器人') + if r_json["errcode"] == 0: + logger.debug(f"企业微信机器人推送成功\n通知对象:机器人") else: - logger.error(f'企业微信机器人推送失败\n请求连接:{send_url}\n请求参数:{data}\n请求响应:{r_json}') + logger.error(f"企业微信机器人推送失败\n请求连接:{send_url}\n请求参数:{data}\n请求响应:{r_json}") @staticmethod def send_feishu_webhook(url, title, content): - data = { - "title": title, - "text": content - } - if '/v2/' in url: + data = {"title": title, "text": content} + if "/v2/" in url: data = { - 'msg_type': 'post', - 'content': { - 'post': { - 'zh_cn': { - 'title': title, - 'content': [[{ - 'tag': 'text', - 'text': content - }]] + "msg_type": "post", + "content": { + "post": { + "zh_cn": { + "title": title, + "content": [[{"tag": "text", "text": content}]], } } - } + }, } r = requests.post(url=url, json=data) r_json = r.json() - if 'ok' in r_json or ('StatusCode' in r_json and r_json['StatusCode'] == 0) or ('code' in r_json and r_json['code'] == 0): - logger.debug(f'飞书Webhook推送成功\n通知对象:{url}\n消息内容:{content}') + if ( + "ok" in r_json + or ("StatusCode" in r_json and r_json["StatusCode"] == 0) + or ("code" in r_json and r_json["code"] == 0) + ): + logger.debug(f"飞书Webhook推送成功\n通知对象:{url}\n消息内容:{content}") else: logger.error(f"飞书Webhook推送失败错误码\n请求url:{url}\n请求data:{data}\n请求响应:{r_json}") @@ -252,12 +252,14 @@ def send_feishu_user(title, content, open_id, user_mail): data = { "open_ids": open_id, "msg_type": "text", - "content": { - "text": f'{title}\n{content}' - } + "content": {"text": f"{title}\n{content}"}, } - r = requests.post(url=url, json=data, headers={'Authorization': "Bearer " + get_feishu_access_token()}).json() - if r['code'] == 0: - logger.debug(f'飞书单推推送成功\n通知对象:{url}\n消息内容:{content}') + r = requests.post( + url=url, + json=data, + headers={"Authorization": "Bearer " + get_feishu_access_token()}, + ).json() + if r["code"] == 0: + logger.debug(f"飞书单推推送成功\n通知对象:{url}\n消息内容:{content}") else: logger.error(f"飞书单推推送失败错误码\n请求url:{url}\n请求data:{data}\n请求响应:{r}") diff --git a/common/utils/tencent_sms.py b/common/utils/tencent_sms.py index 9cf224aa9e..0cec8539e2 100644 --- a/common/utils/tencent_sms.py +++ b/common/utils/tencent_sms.py @@ -1,33 +1,32 @@ # -*- coding: utf-8 -*- from tencentcloud.common import credential -from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException +from tencentcloud.common.exception.tencent_cloud_sdk_exception import ( + TencentCloudSDKException, +) from tencentcloud.sms.v20210111 import sms_client, models from common.config import SysConfig import traceback import logging -logger = logging.getLogger('default') +logger = logging.getLogger("default") class TencentSMS: def __init__(self): all_config = SysConfig() - self.secret_id = all_config.get('tencent_secret_id', '') - self.secret_key = all_config.get('tencent_secret_key', '') - self.sign_name = all_config.get('tencent_sign_name', '') - self.template_id = all_config.get('tencent_template_id', '') - self.sdk_appid = all_config.get('tencent_sdk_appid', '') + self.secret_id = all_config.get("tencent_secret_id", "") + self.secret_key = all_config.get("tencent_secret_key", "") + self.sign_name = all_config.get("tencent_sign_name", "") + self.template_id = all_config.get("tencent_template_id", "") + self.sdk_appid = all_config.get("tencent_sdk_appid", "") def create_client(self): - cred = credential.Credential( - self.secret_id, - self.secret_key - ) + cred = credential.Credential(self.secret_id, self.secret_key) client = sms_client.SmsClient(cred, "ap-guangzhou") return client def send_code(self, **kwargs): - result = {'status': 0, 'msg': 'ok'} + result = {"status": 0, "msg": "ok"} client = TencentSMS.create_client(self) try: @@ -35,16 +34,16 @@ def send_code(self, **kwargs): req.SmsSdkAppId = self.sdk_appid req.SignName = self.sign_name req.TemplateId = self.template_id - req.TemplateParamSet = [kwargs['otp']] - req.PhoneNumberSet = [kwargs['phone']] + req.TemplateParamSet = [kwargs["otp"]] + req.PhoneNumberSet = [kwargs["phone"]] resp = client.SendSms(req) - if resp.SendStatusSet[0].Code != 'Ok': - result['status'] = 1 - result['msg'] = resp.SendStatusSet[0].Message + if resp.SendStatusSet[0].Code != "Ok": + result["status"] = 1 + result["msg"] = resp.SendStatusSet[0].Message except TencentCloudSDKException as e: - result['status'] = 1 - result['msg'] = str(e) + result["status"] = 1 + result["msg"] = str(e) logger.error(str(e)) logger.error(traceback.format_exc()) diff --git a/common/utils/timer.py b/common/utils/timer.py index cea0f91e5b..c5cc11c17d 100644 --- a/common/utils/timer.py +++ b/common/utils/timer.py @@ -7,7 +7,7 @@ """ import datetime -__author__ = 'hhyo' +__author__ = "hhyo" class FuncTimer(object): diff --git a/common/utils/wx_api.py b/common/utils/wx_api.py index 6f8bc5d5d6..3975004249 100644 --- a/common/utils/wx_api.py +++ b/common/utils/wx_api.py @@ -5,13 +5,13 @@ from common.config import SysConfig from django.core.cache import cache -logger = logging.getLogger('default') +logger = logging.getLogger("default") def get_wx_access_token(): # 优先获取缓存 try: - token = cache.get('wx_access_token') + token = cache.get("wx_access_token") except Exception as e: logger.error(f"获取企业微信token缓存出错:{e}") token = None @@ -19,14 +19,14 @@ def get_wx_access_token(): return token # 请求企业微信接口获取 sys_config = SysConfig() - corp_id = sys_config.get('wx_corpid') - corp_secret = sys_config.get('wx_app_secret') + corp_id = sys_config.get("wx_corpid") + corp_secret = sys_config.get("wx_app_secret") url = f"https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid={corp_id}&corpsecret={corp_secret}" resp = requests.get(url, timeout=3).json() - if resp.get('errcode') == 0: - access_token = resp.get('access_token') - expires_in = resp.get('expires_in') - cache.set('wx_access_token', access_token, timeout=expires_in - 60) + if resp.get("errcode") == 0: + access_token = resp.get("access_token") + expires_in = resp.get("expires_in") + cache.set("wx_access_token", access_token, timeout=expires_in - 60) return access_token else: logger.error(f"获取企业微信access_token出错:{resp}") diff --git a/common/views.py b/common/views.py index 5fdfac752e..cbf8a1d704 100644 --- a/common/views.py +++ b/common/views.py @@ -7,30 +7,32 @@ """ from django.shortcuts import render -__author__ = 'hhyo' +__author__ = "hhyo" from django.http import ( - HttpResponseBadRequest, HttpResponseForbidden, HttpResponseNotFound, + HttpResponseBadRequest, + HttpResponseForbidden, + HttpResponseNotFound, HttpResponseServerError, ) from django.views.decorators.csrf import requires_csrf_token @requires_csrf_token -def bad_request(request, exception, template_name='errors/400.html'): +def bad_request(request, exception, template_name="errors/400.html"): return HttpResponseBadRequest(render(request, template_name)) @requires_csrf_token -def permission_denied(request, exception, template_name='errors/403.html'): +def permission_denied(request, exception, template_name="errors/403.html"): return HttpResponseForbidden(render(request, template_name)) @requires_csrf_token -def page_not_found(request, exception, template_name='errors/404.html'): +def page_not_found(request, exception, template_name="errors/404.html"): return HttpResponseNotFound(render(request, template_name)) @requires_csrf_token -def server_error(request, template_name='errors/500.html'): +def server_error(request, template_name="errors/500.html"): return HttpResponseServerError(render(request, template_name)) diff --git a/common/workflow.py b/common/workflow.py index dd7988711a..2ca3edcfbc 100644 --- a/common/workflow.py +++ b/common/workflow.py @@ -13,11 +13,11 @@ def lists(request): # 获取用户信息 user = request.user - limit = int(request.POST.get('limit')) - offset = int(request.POST.get('offset')) - workflow_type = int(request.POST.get('workflow_type')) + limit = int(request.POST.get("limit")) + offset = int(request.POST.get("offset")) + workflow_type = int(request.POST.get("workflow_type")) limit = offset + limit - search = request.POST.get('search', '') + search = request.POST.get("search", "") # 先获取用户所在资源组列表 group_list = user_groups(user) @@ -31,43 +31,56 @@ def lists(request): # 只返回所在资源组当前待自己审核的数据 workflow_audit = WorkflowAudit.objects.filter( workflow_title__icontains=search, - current_status=WorkflowDict.workflow_status['audit_wait'], + current_status=WorkflowDict.workflow_status["audit_wait"], group_id__in=group_ids, - current_audit__in=auth_group_ids + current_audit__in=auth_group_ids, ) # 过滤工单类型 if workflow_type != 0: workflow_audit = workflow_audit.filter(workflow_type=workflow_type) audit_list_count = workflow_audit.count() - audit_list = workflow_audit.order_by('-audit_id')[offset:limit].values( - 'audit_id', 'workflow_type', - 'workflow_title', 'create_user_display', - 'create_time', 'current_status', - 'audit_auth_groups', - 'current_audit', - 'group_name') + audit_list = workflow_audit.order_by("-audit_id")[offset:limit].values( + "audit_id", + "workflow_type", + "workflow_title", + "create_user_display", + "create_time", + "current_status", + "audit_auth_groups", + "current_audit", + "group_name", + ) # QuerySet 序列化 rows = [row for row in audit_list] result = {"total": audit_list_count, "rows": rows} # 返回查询结果 - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) # 获取工单日志 def log(request): - workflow_id = request.POST.get('workflow_id') - workflow_type = request.POST.get('workflow_type') + workflow_id = request.POST.get("workflow_id") + workflow_type = request.POST.get("workflow_type") try: - audit_id = WorkflowAudit.objects.get(workflow_id=workflow_id, workflow_type=workflow_type).audit_id - workflow_logs = WorkflowLog.objects.filter(audit_id=audit_id).order_by('-id').values( - 'operation_type_desc', - 'operation_info', - 'operator_display', - 'operation_time') + audit_id = WorkflowAudit.objects.get( + workflow_id=workflow_id, workflow_type=workflow_type + ).audit_id + workflow_logs = ( + WorkflowLog.objects.filter(audit_id=audit_id) + .order_by("-id") + .values( + "operation_type_desc", + "operation_info", + "operator_display", + "operation_time", + ) + ) count = WorkflowLog.objects.filter(audit_id=audit_id).count() except Exception: workflow_logs = [] @@ -77,5 +90,7 @@ def log(request): rows = [row for row in workflow_logs] result = {"total": count, "rows": rows} # 返回查询结果 - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoderFTime, bigint_as_string=True), - content_type='application/json') + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoderFTime, bigint_as_string=True), + content_type="application/json", + ) diff --git a/sql/admin.py b/sql/admin.py index 36c9a6b713..13e80bb1c5 100755 --- a/sql/admin.py +++ b/sql/admin.py @@ -5,11 +5,31 @@ # Register your models here. from django.forms import PasswordInput -from .models import Users, Instance, SqlWorkflow, SqlWorkflowContent, QueryLog, DataMaskingColumns, DataMaskingRules, \ - AliyunRdsConfig, CloudAccessKey, ResourceGroup, QueryPrivilegesApply, \ - QueryPrivileges, InstanceAccount, InstanceDatabase, ArchiveConfig, \ - WorkflowAudit, WorkflowLog, ParamTemplate, ParamHistory, InstanceTag, \ - Tunnel, AuditEntry, TwoFactorAuthConfig +from .models import ( + Users, + Instance, + SqlWorkflow, + SqlWorkflowContent, + QueryLog, + DataMaskingColumns, + DataMaskingRules, + AliyunRdsConfig, + CloudAccessKey, + ResourceGroup, + QueryPrivilegesApply, + QueryPrivileges, + InstanceAccount, + InstanceDatabase, + ArchiveConfig, + WorkflowAudit, + WorkflowLog, + ParamTemplate, + ParamHistory, + InstanceTag, + Tunnel, + AuditEntry, + TwoFactorAuthConfig, +) from sql.form import TunnelForm, InstanceForm @@ -17,65 +37,150 @@ # 用户管理 @admin.register(Users) class UsersAdmin(UserAdmin): - list_display = ('id', 'username', 'display', 'email', 'is_superuser', 'is_staff', 'is_active') - search_fields = ('id', 'username', 'display', 'email') - list_display_links = ('id', 'username',) - ordering = ('id',) + list_display = ( + "id", + "username", + "display", + "email", + "is_superuser", + "is_staff", + "is_active", + ) + search_fields = ("id", "username", "display", "email") + list_display_links = ( + "id", + "username", + ) + ordering = ("id",) # 编辑页显示内容 fieldsets = ( - ('认证信息', {'fields': ('username', 'password')}), - ('个人信息', {'fields': ('display', 'email', 'ding_user_id', 'wx_user_id', 'feishu_open_id')}), - ('权限信息', {'fields': ('is_superuser', 'is_active', 'is_staff', 'groups', 'user_permissions')}), - ('资源组', {'fields': ('resource_group',)}), - ('其他信息', {'fields': ('date_joined',)}), + ("认证信息", {"fields": ("username", "password")}), + ( + "个人信息", + { + "fields": ( + "display", + "email", + "ding_user_id", + "wx_user_id", + "feishu_open_id", + ) + }, + ), + ( + "权限信息", + { + "fields": ( + "is_superuser", + "is_active", + "is_staff", + "groups", + "user_permissions", + ) + }, + ), + ("资源组", {"fields": ("resource_group",)}), + ("其他信息", {"fields": ("date_joined",)}), ) # 添加页显示内容 add_fieldsets = ( - ('认证信息', {'fields': ('username', 'password1', 'password2')}), - ('个人信息', {'fields': ('display', 'email', 'ding_user_id', 'wx_user_id', 'feishu_open_id')}), - ('权限信息', {'fields': ('is_superuser', 'is_active', 'is_staff', 'groups', 'user_permissions')}), - ('资源组', {'fields': ('resource_group',)}), + ("认证信息", {"fields": ("username", "password1", "password2")}), + ( + "个人信息", + { + "fields": ( + "display", + "email", + "ding_user_id", + "wx_user_id", + "feishu_open_id", + ) + }, + ), + ( + "权限信息", + { + "fields": ( + "is_superuser", + "is_active", + "is_staff", + "groups", + "user_permissions", + ) + }, + ), + ("资源组", {"fields": ("resource_group",)}), ) - filter_horizontal = ('groups', 'user_permissions', 'resource_group') - list_filter = ('is_staff', 'is_superuser', 'is_active', 'groups', 'resource_group') + filter_horizontal = ("groups", "user_permissions", "resource_group") + list_filter = ("is_staff", "is_superuser", "is_active", "groups", "resource_group") # 用户2fa管理 @admin.register(TwoFactorAuthConfig) class TwoFactorAuthConfigAdmin(admin.ModelAdmin): - list_display = ('id', 'username', 'auth_type', 'phone', 'secret_key', 'user_id') + list_display = ("id", "username", "auth_type", "phone", "secret_key", "user_id") # 资源组管理 @admin.register(ResourceGroup) class ResourceGroupAdmin(admin.ModelAdmin): - list_display = ('group_id', 'group_name', 'ding_webhook', 'feishu_webhook', 'qywx_webhook', 'is_deleted') - exclude = ('group_parent_id', 'group_sort', 'group_level',) + list_display = ( + "group_id", + "group_name", + "ding_webhook", + "feishu_webhook", + "qywx_webhook", + "is_deleted", + ) + exclude = ( + "group_parent_id", + "group_sort", + "group_level", + ) # 实例标签配置 @admin.register(InstanceTag) class InstanceTagAdmin(admin.ModelAdmin): - list_display = ('id', 'tag_code', 'tag_name', 'active', 'create_time') - list_display_links = ('id', 'tag_code',) - fieldsets = (None, {'fields': ('tag_code', 'tag_name', 'active'), }), + list_display = ("id", "tag_code", "tag_name", "active", "create_time") + list_display_links = ( + "id", + "tag_code", + ) + fieldsets = ( + ( + None, + { + "fields": ("tag_code", "tag_name", "active"), + }, + ), + ) # 不支持修改标签代码 def get_readonly_fields(self, request, obj=None): - return ('tag_code',) if obj else () + return ("tag_code",) if obj else () # 实例管理 @admin.register(Instance) class InstanceAdmin(admin.ModelAdmin): form = InstanceForm - list_display = ('id', 'instance_name', 'db_type', 'type', 'host', 'port', 'user', 'create_time') - search_fields = ['instance_name', 'host', 'port', 'user'] - list_filter = ('db_type', 'type', 'instance_tag') + list_display = ( + "id", + "instance_name", + "db_type", + "type", + "host", + "port", + "user", + "create_time", + ) + search_fields = ["instance_name", "host", "port", "user"] + list_filter = ("db_type", "type", "instance_tag") def formfield_for_dbfield(self, db_field, **kwargs): - if db_field.name == 'password': - kwargs['widget'] = PasswordInput(render_value=True) + if db_field.name == "password": + kwargs["widget"] = PasswordInput(render_value=True) return super(InstanceAdmin, self).formfield_for_dbfield(db_field, **kwargs) # 阿里云实例关系配置 @@ -83,7 +188,10 @@ class AliRdsConfigInline(admin.TabularInline): model = AliyunRdsConfig # 实例资源组关联配置 - filter_horizontal = ('resource_group', 'instance_tag',) + filter_horizontal = ( + "resource_group", + "instance_tag", + ) inlines = [AliRdsConfigInline] @@ -91,28 +199,48 @@ class AliRdsConfigInline(admin.TabularInline): # SSH隧道 @admin.register(Tunnel) class TunnelAdmin(admin.ModelAdmin): - list_display = ('id', 'tunnel_name', 'host', 'port', 'create_time') - list_display_links = ('id', 'tunnel_name',) - search_fields = ('id', 'tunnel_name') + list_display = ("id", "tunnel_name", "host", "port", "create_time") + list_display_links = ( + "id", + "tunnel_name", + ) + search_fields = ("id", "tunnel_name") fieldsets = ( - None, - {'fields': ('tunnel_name', 'host', 'port', 'user', 'password', 'pkey_path', 'pkey_password', 'pkey'), }), - ordering = ('id',) + ( + None, + { + "fields": ( + "tunnel_name", + "host", + "port", + "user", + "password", + "pkey_path", + "pkey_password", + "pkey", + ), + }, + ), + ) + ordering = ("id",) # 添加页显示内容 add_fieldsets = ( - ('隧道信息', {'fields': ('tunnel_name', 'host', 'port')}), - ('连接信息', {'fields': ('user', 'password', 'pkey_path', 'pkey_password', 'pkey')}), + ("隧道信息", {"fields": ("tunnel_name", "host", "port")}), + ( + "连接信息", + {"fields": ("user", "password", "pkey_path", "pkey_password", "pkey")}, + ), ) form = TunnelForm def formfield_for_dbfield(self, db_field, **kwargs): - if db_field.name in ['password', 'pkey_password']: - kwargs['widget'] = PasswordInput(render_value=True) + if db_field.name in ["password", "pkey_password"]: + kwargs["widget"] = PasswordInput(render_value=True) return super(TunnelAdmin, self).formfield_for_dbfield(db_field, **kwargs) # 不支持修改标签代码 def get_readonly_fields(self, request, obj=None): - return ('id',) if obj else () + return ("id",) if obj else () # SQL工单内容 @@ -124,10 +252,31 @@ class SqlWorkflowContentInline(admin.TabularInline): @admin.register(SqlWorkflow) class SqlWorkflowAdmin(admin.ModelAdmin): list_display = ( - 'id', 'workflow_name', 'group_name', 'instance', 'engineer_display', 'create_time', 'status', 'is_backup') - search_fields = ['id', 'workflow_name', 'engineer_display', 'sqlworkflowcontent__sql_content'] - list_filter = ('group_name', 'instance__instance_name', 'status', 'syntax_type',) - list_display_links = ('id', 'workflow_name',) + "id", + "workflow_name", + "group_name", + "instance", + "engineer_display", + "create_time", + "status", + "is_backup", + ) + search_fields = [ + "id", + "workflow_name", + "engineer_display", + "sqlworkflowcontent__sql_content", + ] + list_filter = ( + "group_name", + "instance__instance_name", + "status", + "syntax_type", + ) + list_display_links = ( + "id", + "workflow_name", + ) inlines = [SqlWorkflowContentInline] @@ -135,130 +284,247 @@ class SqlWorkflowAdmin(admin.ModelAdmin): @admin.register(QueryLog) class QueryLogAdmin(admin.ModelAdmin): list_display = ( - 'instance_name', 'db_name', 'sqllog', 'effect_row', 'cost_time', 'user_display', 'create_time') - search_fields = ['sqllog', 'user_display'] - list_filter = ('instance_name', 'db_name', 'user_display', 'priv_check', 'hit_rule', 'masking',) + "instance_name", + "db_name", + "sqllog", + "effect_row", + "cost_time", + "user_display", + "create_time", + ) + search_fields = ["sqllog", "user_display"] + list_filter = ( + "instance_name", + "db_name", + "user_display", + "priv_check", + "hit_rule", + "masking", + ) # 查询权限列表 @admin.register(QueryPrivileges) class QueryPrivilegesAdmin(admin.ModelAdmin): - list_display = ('privilege_id', 'user_display', 'instance', 'db_name', 'table_name', - 'valid_date', 'limit_num', 'create_time') - search_fields = ['user_display', 'instance__instance_name'] - list_filter = ('user_display', 'instance', 'db_name', 'table_name',) + list_display = ( + "privilege_id", + "user_display", + "instance", + "db_name", + "table_name", + "valid_date", + "limit_num", + "create_time", + ) + search_fields = ["user_display", "instance__instance_name"] + list_filter = ( + "user_display", + "instance", + "db_name", + "table_name", + ) # 查询权限申请记录 @admin.register(QueryPrivilegesApply) class QueryPrivilegesApplyAdmin(admin.ModelAdmin): - list_display = ('apply_id', 'user_display', 'group_name', 'instance', 'valid_date', 'limit_num', 'create_time') - search_fields = ['user_display', 'instance__instance_name', 'db_list', 'table_list'] - list_filter = ('user_display', 'group_name', 'instance') + list_display = ( + "apply_id", + "user_display", + "group_name", + "instance", + "valid_date", + "limit_num", + "create_time", + ) + search_fields = ["user_display", "instance__instance_name", "db_list", "table_list"] + list_filter = ("user_display", "group_name", "instance") # 脱敏字段页面定义 @admin.register(DataMaskingColumns) class DataMaskingColumnsAdmin(admin.ModelAdmin): list_display = ( - 'column_id', 'rule_type', 'active', 'instance', 'table_schema', 'table_name', 'column_name', 'column_comment', - 'create_time',) - search_fields = ['table_name', 'column_name'] - list_filter = ('rule_type', 'active', 'instance__instance_name') + "column_id", + "rule_type", + "active", + "instance", + "table_schema", + "table_name", + "column_name", + "column_comment", + "create_time", + ) + search_fields = ["table_name", "column_name"] + list_filter = ("rule_type", "active", "instance__instance_name") # 脱敏规则页面定义 @admin.register(DataMaskingRules) class DataMaskingRulesAdmin(admin.ModelAdmin): list_display = ( - 'rule_type', 'rule_regex', 'hide_group', 'rule_desc', 'sys_time',) + "rule_type", + "rule_regex", + "hide_group", + "rule_desc", + "sys_time", + ) # 工作流审批列表 @admin.register(WorkflowAudit) class WorkflowAuditAdmin(admin.ModelAdmin): list_display = ( - 'workflow_title', 'group_name', 'workflow_type', 'current_status', 'create_user_display', 'create_time') - search_fields = ['workflow_title', 'create_user_display'] - list_filter = ('create_user_display', 'group_name', 'workflow_type', 'current_status') + "workflow_title", + "group_name", + "workflow_type", + "current_status", + "create_user_display", + "create_time", + ) + search_fields = ["workflow_title", "create_user_display"] + list_filter = ( + "create_user_display", + "group_name", + "workflow_type", + "current_status", + ) # 工作流日志表 @admin.register(WorkflowLog) class WorkflowLogAdmin(admin.ModelAdmin): list_display = ( - 'operation_type_desc', 'operation_info', 'operator_display', 'operation_time',) - list_filter = ('operation_type_desc', 'operator_display') + "operation_type_desc", + "operation_info", + "operator_display", + "operation_time", + ) + list_filter = ("operation_type_desc", "operator_display") # 实例数据库列表 @admin.register(InstanceDatabase) class InstanceDatabaseAdmin(admin.ModelAdmin): - list_display = ('db_name', 'owner_display', 'instance', 'remark') - search_fields = ('db_name',) - list_filter = ('instance', 'owner_display') - list_display_links = ('db_name',) + list_display = ("db_name", "owner_display", "instance", "remark") + search_fields = ("db_name",) + list_filter = ("instance", "owner_display") + list_display_links = ("db_name",) # 仅支持修改备注 def get_readonly_fields(self, request, obj=None): - return ('instance', 'owner', 'owner_display') if obj else () + return ("instance", "owner", "owner_display") if obj else () # 实例用户列表 @admin.register(InstanceAccount) class InstanceAccountAdmin(admin.ModelAdmin): - list_display = ('user', 'host', 'password', 'instance', 'remark') - search_fields = ('user', 'host') - list_filter = ('instance', 'host') - list_display_links = ('user',) + list_display = ("user", "host", "password", "instance", "remark") + search_fields = ("user", "host") + list_filter = ("instance", "host") + list_display_links = ("user",) # 仅支持修改备注 def get_readonly_fields(self, request, obj=None): - return ('user', 'host', 'instance',) if obj else () + return ( + ( + "user", + "host", + "instance", + ) + if obj + else () + ) # 实例参数配置表 @admin.register(ParamTemplate) class ParamTemplateAdmin(admin.ModelAdmin): - list_display = ('variable_name', 'db_type', 'default_value', 'editable', 'valid_values') - search_fields = ('variable_name',) - list_filter = ('db_type', 'editable') - list_display_links = ('variable_name',) + list_display = ( + "variable_name", + "db_type", + "default_value", + "editable", + "valid_values", + ) + search_fields = ("variable_name",) + list_filter = ("db_type", "editable") + list_display_links = ("variable_name",) # 实例参数修改历史 @admin.register(ParamHistory) class ParamHistoryAdmin(admin.ModelAdmin): - list_display = ('variable_name', 'instance', 'old_var', 'new_var', 'user_display', 'create_time') - search_fields = ('variable_name',) - list_filter = ('instance', 'user_display') + list_display = ( + "variable_name", + "instance", + "old_var", + "new_var", + "user_display", + "create_time", + ) + search_fields = ("variable_name",) + list_filter = ("instance", "user_display") # 归档配置 @admin.register(ArchiveConfig) class ArchiveConfigAdmin(admin.ModelAdmin): list_display = ( - 'id', 'title', 'src_instance', 'src_db_name', 'src_table_name', - 'dest_instance', 'dest_db_name', 'dest_table_name', - 'mode', 'no_delete', 'status', 'state', 'user_display', 'create_time', 'resource_group') - search_fields = ('title', 'src_table_name') - list_display_links = ('id', 'title') - list_filter = ('src_instance', 'src_db_name', 'mode', 'no_delete', 'state') + "id", + "title", + "src_instance", + "src_db_name", + "src_table_name", + "dest_instance", + "dest_db_name", + "dest_table_name", + "mode", + "no_delete", + "status", + "state", + "user_display", + "create_time", + "resource_group", + ) + search_fields = ("title", "src_table_name") + list_display_links = ("id", "title") + list_filter = ("src_instance", "src_db_name", "mode", "no_delete", "state") # 编辑页显示内容 - fields = ('title', 'resource_group', 'src_instance', 'src_db_name', 'src_table_name', - 'dest_instance', 'dest_db_name', 'dest_table_name', - 'mode', 'condition', 'sleep', 'no_delete', 'state', 'user_name', 'user_display') + fields = ( + "title", + "resource_group", + "src_instance", + "src_db_name", + "src_table_name", + "dest_instance", + "dest_db_name", + "dest_table_name", + "mode", + "condition", + "sleep", + "no_delete", + "state", + "user_name", + "user_display", + ) # 云服务认证信息配置 @admin.register(CloudAccessKey) class CloudAccessKeyAdmin(admin.ModelAdmin): - list_display = ('type', 'key_id', 'key_secret', 'remark') + list_display = ("type", "key_id", "key_secret", "remark") # 登录审计日志 @admin.register(AuditEntry) class AuditEntryAdmin(admin.ModelAdmin): - list_display = ('user_id', 'user_name', 'user_display', 'action', 'extra_info', 'action_time') - list_filter = ('user_id', 'user_name', 'user_display', 'action', 'extra_info') - + list_display = ( + "user_id", + "user_name", + "user_display", + "action", + "extra_info", + "action_time", + ) + list_filter = ("user_id", "user_name", "user_display", "action", "extra_info") diff --git a/sql/aliyun_rds.py b/sql/aliyun_rds.py index d94552ae23..e2a9fd0416 100644 --- a/sql/aliyun_rds.py +++ b/sql/aliyun_rds.py @@ -9,23 +9,23 @@ # 获取SQL慢日志统计 def slowquery_review(request): - instance_name = request.POST.get('instance_name') - db_name = request.POST.get('db_name') - start_time = request.POST.get('StartTime') - end_time = request.POST.get('EndTime') - limit = request.POST.get('limit') - offset = request.POST.get('offset') + instance_name = request.POST.get("instance_name") + db_name = request.POST.get("db_name") + start_time = request.POST.get("StartTime") + end_time = request.POST.get("EndTime") + limit = request.POST.get("limit") + offset = request.POST.get("offset") # 计算页数 page_number = (int(offset) + int(limit)) / int(limit) values = {"PageSize": int(limit), "PageNumber": int(page_number)} # DBName非必传 if db_name: - values['DBName'] = db_name + values["DBName"] = db_name # UTC时间转化成阿里云需求的时间格式 - start_time = '%sZ' % start_time - end_time = '%sZ' % end_time + start_time = "%sZ" % start_time + end_time = "%sZ" % end_time # 通过实例名称获取关联的rds实例id instance_info = AliyunRdsConfig.objects.get(instance__instance_name=instance_name) @@ -33,55 +33,70 @@ def slowquery_review(request): slowsql = Aliyun(rds=instance_info).DescribeSlowLogs(start_time, end_time, **values) # 解决table数据丢失精度、格式化时间 - sql_slow_log = json.loads(slowsql)['Items']['SQLSlowLog'] + sql_slow_log = json.loads(slowsql)["Items"]["SQLSlowLog"] for SlowLog in sql_slow_log: - SlowLog['SQLId'] = str(SlowLog['SQLHASH']) - SlowLog['CreateTime'] = Aliyun.utc2local(SlowLog['CreateTime'], utc_format="%Y-%m-%dZ") - - result = {"total": json.loads(slowsql)['TotalRecordCount'], "rows": sql_slow_log, - "PageSize": json.loads(slowsql)['PageRecordCount'], "PageNumber": json.loads(slowsql)['PageNumber']} + SlowLog["SQLId"] = str(SlowLog["SQLHASH"]) + SlowLog["CreateTime"] = Aliyun.utc2local( + SlowLog["CreateTime"], utc_format="%Y-%m-%dZ" + ) + + result = { + "total": json.loads(slowsql)["TotalRecordCount"], + "rows": sql_slow_log, + "PageSize": json.loads(slowsql)["PageRecordCount"], + "PageNumber": json.loads(slowsql)["PageNumber"], + } # 返回查询结果 return result # 获取SQL慢日志明细 def slowquery_review_history(request): - instance_name = request.POST.get('instance_name') - start_time = request.POST.get('StartTime') - end_time = request.POST.get('EndTime') - db_name = request.POST.get('db_name') - sql_id = request.POST.get('SQLId') - limit = request.POST.get('limit') - offset = request.POST.get('offset') + instance_name = request.POST.get("instance_name") + start_time = request.POST.get("StartTime") + end_time = request.POST.get("EndTime") + db_name = request.POST.get("db_name") + sql_id = request.POST.get("SQLId") + limit = request.POST.get("limit") + offset = request.POST.get("offset") # 计算页数 page_number = (int(offset) + int(limit)) / int(limit) values = {"PageSize": int(limit), "PageNumber": int(page_number)} # SQLId、DBName非必传 if sql_id: - values['SQLHASH'] = sql_id + values["SQLHASH"] = sql_id if db_name: - values['DBName'] = db_name + values["DBName"] = db_name # UTC时间转化成阿里云需求的时间格式 - start_time = datetime.datetime.strptime(start_time, "%Y-%m-%d").date() - datetime.timedelta(days=1) - start_time = '%sT16:00Z' % start_time - end_time = '%sT15:59Z' % end_time + start_time = datetime.datetime.strptime( + start_time, "%Y-%m-%d" + ).date() - datetime.timedelta(days=1) + start_time = "%sT16:00Z" % start_time + end_time = "%sT15:59Z" % end_time # 通过实例名称获取关联的rds实例id instance_info = AliyunRdsConfig.objects.get(instance__instance_name=instance_name) # 调用aliyun接口获取SQL慢日志统计 - slowsql = Aliyun(rds=instance_info).DescribeSlowLogRecords(start_time, end_time, **values) + slowsql = Aliyun(rds=instance_info).DescribeSlowLogRecords( + start_time, end_time, **values + ) # 格式化时间\过滤HostAddress - sql_slow_record = json.loads(slowsql)['Items']['SQLSlowRecord'] + sql_slow_record = json.loads(slowsql)["Items"]["SQLSlowRecord"] for SlowRecord in sql_slow_record: - SlowRecord['ExecutionStartTime'] = Aliyun.utc2local(SlowRecord['ExecutionStartTime'], - utc_format='%Y-%m-%dT%H:%M:%SZ') - SlowRecord['HostAddress'] = SlowRecord['HostAddress'].split('[')[0] - - result = {"total": json.loads(slowsql)['TotalRecordCount'], "rows": sql_slow_record, - "PageSize": json.loads(slowsql)['PageRecordCount'], "PageNumber": json.loads(slowsql)['PageNumber']} + SlowRecord["ExecutionStartTime"] = Aliyun.utc2local( + SlowRecord["ExecutionStartTime"], utc_format="%Y-%m-%dT%H:%M:%SZ" + ) + SlowRecord["HostAddress"] = SlowRecord["HostAddress"].split("[")[0] + + result = { + "total": json.loads(slowsql)["TotalRecordCount"], + "rows": sql_slow_record, + "PageSize": json.loads(slowsql)["PageRecordCount"], + "PageNumber": json.loads(slowsql)["PageNumber"], + } # 返回查询结果 return result @@ -89,23 +104,24 @@ def slowquery_review_history(request): # 问题诊断--进程列表 def process_status(request): - instance_name = request.POST.get('instance_name') - command_type = request.POST.get('command_type') + instance_name = request.POST.get("instance_name") + command_type = request.POST.get("command_type") - if command_type is None or command_type == '': - command_type = 'Query' + if command_type is None or command_type == "": + command_type = "Query" # 通过实例名称获取关联的rds实例id instance_info = AliyunRdsConfig.objects.get(instance__instance_name=instance_name) # 调用aliyun接口获取进程数据 process_info = Aliyun(rds=instance_info).RequestServiceOfCloudDBA( - 'ShowProcessList', {"Language": "zh", "Command": command_type}) + "ShowProcessList", {"Language": "zh", "Command": command_type} + ) # 提取进程列表 - process_list = json.loads(process_info)['AttrData'] - process_list = json.loads(process_list)['ProcessList'] + process_list = json.loads(process_info)["AttrData"] + process_list = json.loads(process_list)["ProcessList"] - result = {'status': 0, 'msg': 'ok', 'rows': process_list} + result = {"status": 0, "msg": "ok", "rows": process_list} # 返回查询结果 return result @@ -113,20 +129,22 @@ def process_status(request): # 问题诊断--通过进程id构建请求id def create_kill_session(request): - instance_name = request.POST.get('instance_name') - thread_ids = request.POST.get('ThreadIDs') + instance_name = request.POST.get("instance_name") + thread_ids = request.POST.get("ThreadIDs") - result = {'status': 0, 'msg': 'ok', 'data': []} + result = {"status": 0, "msg": "ok", "data": []} # 通过实例名称获取关联的rds实例id instance_info = AliyunRdsConfig.objects.get(instance__instance_name=instance_name) # 调用aliyun接口获取进程数据 request_info = Aliyun(rds=instance_info).RequestServiceOfCloudDBA( - 'CreateKillSessionRequest', {"Language": "zh", "ThreadIDs": json.loads(thread_ids)}) + "CreateKillSessionRequest", + {"Language": "zh", "ThreadIDs": json.loads(thread_ids)}, + ) # 提取进程列表 - request_list = json.loads(request_info)['AttrData'] + request_list = json.loads(request_info)["AttrData"] - result['data'] = request_list + result["data"] = request_list # 返回处理结果 return result @@ -134,21 +152,23 @@ def create_kill_session(request): # 问题诊断--终止会话 def kill_session(request): - instance_name = request.POST.get('instance_name') - request_params = request.POST.get('request_params') + instance_name = request.POST.get("instance_name") + request_params = request.POST.get("request_params") - result = {'status': 0, 'msg': 'ok', 'data': []} + result = {"status": 0, "msg": "ok", "data": []} # 通过实例名称获取关联的rds实例id instance_info = AliyunRdsConfig.objects.get(instance__instance_name=instance_name) # 调用aliyun接口获取终止进程 request_params = json.loads(request_params) service_request_param = dict({"Language": "zh"}, **request_params) - kill_result = Aliyun(rds=instance_info).RequestServiceOfCloudDBA('ConfirmKillSessionRequest', service_request_param) + kill_result = Aliyun(rds=instance_info).RequestServiceOfCloudDBA( + "ConfirmKillSessionRequest", service_request_param + ) # 获取处理结果 - kill_result = json.loads(kill_result)['AttrData'] + kill_result = json.loads(kill_result)["AttrData"] - result['data'] = kill_result + result["data"] = kill_result # 返回查询结果 return result @@ -156,22 +176,23 @@ def kill_session(request): # 问题诊断--空间列表 def sapce_status(request): - instance_name = request.POST.get('instance_name') + instance_name = request.POST.get("instance_name") # 通过实例名称获取关联的rds实例id instance_info = AliyunRdsConfig.objects.get(instance__instance_name=instance_name) # 调用aliyun接口获取进程数据 space_info = Aliyun(rds=instance_info).RequestServiceOfCloudDBA( - 'GetSpaceStatForTables', {"Language": "zh", "OrderType": "Data"}) + "GetSpaceStatForTables", {"Language": "zh", "OrderType": "Data"} + ) # 提取进程列表 - space_list = json.loads(space_info)['ListData'] + space_list = json.loads(space_info)["ListData"] if space_list: space_list = json.loads(space_list) else: space_list = [] - result = {'status': 0, 'msg': 'ok', 'rows': space_list} + result = {"status": 0, "msg": "ok", "rows": space_list} # 返回查询结果 return result diff --git a/sql/archiver.py b/sql/archiver.py index 063aed1eaf..9189149e57 100644 --- a/sql/archiver.py +++ b/sql/archiver.py @@ -32,11 +32,11 @@ from sql.models import ArchiveConfig, ArchiveLog, Instance, ResourceGroup from sql.utils.workflow_audit import Audit -logger = logging.getLogger('default') -__author__ = 'hhyo' +logger = logging.getLogger("default") +__author__ = "hhyo" -@permission_required('sql.menu_archive', raise_exception=True) +@permission_required("sql.menu_archive", raise_exception=True) def archive_list(request): """ 获取归档申请列表 @@ -44,47 +44,62 @@ def archive_list(request): :return: """ user = request.user - filter_instance_id = request.GET.get('filter_instance_id') - state = request.GET.get('state') - limit = int(request.GET.get('limit', 0)) - offset = int(request.GET.get('offset', 0)) + filter_instance_id = request.GET.get("filter_instance_id") + state = request.GET.get("state") + limit = int(request.GET.get("limit", 0)) + offset = int(request.GET.get("offset", 0)) limit = offset + limit - search = request.GET.get('search', '') + search = request.GET.get("search", "") # 组合筛选项 filter_dict = dict() if filter_instance_id: - filter_dict['src_instance'] = filter_instance_id - if state == 'true': - filter_dict['state'] = True - elif state == 'false': - filter_dict['state'] = False + filter_dict["src_instance"] = filter_instance_id + if state == "true": + filter_dict["state"] = True + elif state == "false": + filter_dict["state"] = False # 管理员可以看到全部数据 if user.is_superuser: pass # 拥有审核权限、可以查看组内所有工单 - elif user.has_perm('sql.archive_review'): + elif user.has_perm("sql.archive_review"): # 先获取用户所在资源组列表 group_list = user_groups(user) group_ids = [group.group_id for group in group_list] - filter_dict['resource_group__in'] = group_ids + filter_dict["resource_group__in"] = group_ids # 其他人只能看到自己提交的工单 else: - filter_dict['user_name'] = user.username + filter_dict["user_name"] = user.username # 过滤组合筛选项 archive_config = ArchiveConfig.objects.filter(**filter_dict) # 过滤搜索项,支持模糊搜索标题、用户 if search: - archive_config = archive_config.filter(Q(title__icontains=search) | Q(user_display__icontains=search)) + archive_config = archive_config.filter( + Q(title__icontains=search) | Q(user_display__icontains=search) + ) count = archive_config.count() - lists = archive_config.order_by('-id')[offset:limit].values( - 'id', 'title', 'src_instance__instance_name', 'src_db_name', 'src_table_name', - 'dest_instance__instance_name', 'dest_db_name', 'dest_table_name', 'sleep', - 'mode', 'no_delete', 'status', 'state', 'user_display', 'create_time', 'resource_group__group_name' + lists = archive_config.order_by("-id")[offset:limit].values( + "id", + "title", + "src_instance__instance_name", + "src_db_name", + "src_table_name", + "dest_instance__instance_name", + "dest_db_name", + "dest_table_name", + "sleep", + "mode", + "no_delete", + "status", + "state", + "user_display", + "create_time", + "resource_group__group_name", ) # QuerySet 序列化 @@ -92,55 +107,75 @@ def archive_list(request): result = {"total": count, "rows": rows} # 返回查询结果 - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) -@permission_required('sql.archive_apply', raise_exception=True) +@permission_required("sql.archive_apply", raise_exception=True) def archive_apply(request): """申请归档实例数据""" user = request.user - title = request.POST.get('title') - group_name = request.POST.get('group_name') - src_instance_name = request.POST.get('src_instance_name') - src_db_name = request.POST.get('src_db_name') - src_table_name = request.POST.get('src_table_name') - mode = request.POST.get('mode') - dest_instance_name = request.POST.get('dest_instance_name') - dest_db_name = request.POST.get('dest_db_name') - dest_table_name = request.POST.get('dest_table_name') - condition = request.POST.get('condition') - no_delete = True if request.POST.get('no_delete') == 'true' else False - sleep = request.POST.get('sleep') or 0 - result = {'status': 0, 'msg': 'ok', 'data': {}} + title = request.POST.get("title") + group_name = request.POST.get("group_name") + src_instance_name = request.POST.get("src_instance_name") + src_db_name = request.POST.get("src_db_name") + src_table_name = request.POST.get("src_table_name") + mode = request.POST.get("mode") + dest_instance_name = request.POST.get("dest_instance_name") + dest_db_name = request.POST.get("dest_db_name") + dest_table_name = request.POST.get("dest_table_name") + condition = request.POST.get("condition") + no_delete = True if request.POST.get("no_delete") == "true" else False + sleep = request.POST.get("sleep") or 0 + result = {"status": 0, "msg": "ok", "data": {}} # 参数校验 - if not all( - [title, group_name, src_instance_name, src_db_name, src_table_name, mode, condition]) or no_delete is None: - return JsonResponse({'status': 1, 'msg': '请填写完整!', 'data': {}}) - if mode == 'dest' and not all([dest_instance_name, dest_db_name, dest_table_name]): - return JsonResponse({'status': 1, 'msg': '归档到实例时目标实例信息必选!', 'data': {}}) + if ( + not all( + [ + title, + group_name, + src_instance_name, + src_db_name, + src_table_name, + mode, + condition, + ] + ) + or no_delete is None + ): + return JsonResponse({"status": 1, "msg": "请填写完整!", "data": {}}) + if mode == "dest" and not all([dest_instance_name, dest_db_name, dest_table_name]): + return JsonResponse({"status": 1, "msg": "归档到实例时目标实例信息必选!", "data": {}}) # 获取源实例信息 try: - s_ins = user_instances(request.user, db_type=['mysql']).get(instance_name=src_instance_name) + s_ins = user_instances(request.user, db_type=["mysql"]).get( + instance_name=src_instance_name + ) except Instance.DoesNotExist: - return JsonResponse({'status': 1, 'msg': '你所在组未关联该实例!', 'data': {}}) + return JsonResponse({"status": 1, "msg": "你所在组未关联该实例!", "data": {}}) # 获取目标实例信息 - if mode == 'dest': + if mode == "dest": try: - d_ins = user_instances(request.user, db_type=['mysql']).get(instance_name=dest_instance_name) + d_ins = user_instances(request.user, db_type=["mysql"]).get( + instance_name=dest_instance_name + ) except Instance.DoesNotExist: - return JsonResponse({'status': 1, 'msg': '你所在组未关联该实例!', 'data': {}}) + return JsonResponse({"status": 1, "msg": "你所在组未关联该实例!", "data": {}}) else: d_ins = None # 获取资源组和审批信息 res_group = ResourceGroup.objects.get(group_name=group_name) - audit_auth_groups = Audit.settings(res_group.group_id, WorkflowDict.workflow_type['archive']) + audit_auth_groups = Audit.settings( + res_group.group_id, WorkflowDict.workflow_type["archive"] + ) if not audit_auth_groups: - return JsonResponse({'status': 1, 'msg': '审批流程不能为空,请先配置审批流程', 'data': {}}) + return JsonResponse({"status": 1, "msg": "审批流程不能为空,请先配置审批流程", "data": {}}) # 使用事务保持数据一致性 try: @@ -160,28 +195,34 @@ def archive_apply(request): mode=mode, no_delete=no_delete, sleep=sleep, - status=WorkflowDict.workflow_status['audit_wait'], + status=WorkflowDict.workflow_status["audit_wait"], state=False, user_name=user.username, user_display=user.display, ) archive_id = archive_info.id # 调用工作流插入审核信息 - audit_result = Audit.add(WorkflowDict.workflow_type['archive'], archive_id) + audit_result = Audit.add(WorkflowDict.workflow_type["archive"], archive_id) except Exception as msg: logger.error(traceback.format_exc()) - result['status'] = 1 - result['msg'] = str(msg) + result["status"] = 1 + result["msg"] = str(msg) else: result = audit_result # 消息通知 - audit_id = Audit.detail_by_workflow_id(workflow_id=archive_id, - workflow_type=WorkflowDict.workflow_type['archive']).audit_id - async_task(notify_for_audit, audit_id=audit_id, timeout=60, task_name=f'archive-apply-{archive_id}') - return HttpResponse(json.dumps(result), content_type='application/json') - - -@permission_required('sql.archive_review', raise_exception=True) + audit_id = Audit.detail_by_workflow_id( + workflow_id=archive_id, workflow_type=WorkflowDict.workflow_type["archive"] + ).audit_id + async_task( + notify_for_audit, + audit_id=audit_id, + timeout=60, + task_name=f"archive-apply-{archive_id}", + ) + return HttpResponse(json.dumps(result), content_type="application/json") + + +@permission_required("sql.archive_review", raise_exception=True) def archive_audit(request): """ 审核数据归档申请 @@ -190,39 +231,51 @@ def archive_audit(request): """ # 获取用户信息 user = request.user - archive_id = int(request.POST['archive_id']) - audit_status = int(request.POST['audit_status']) - audit_remark = request.POST.get('audit_remark') + archive_id = int(request.POST["archive_id"]) + audit_status = int(request.POST["audit_status"]) + audit_remark = request.POST.get("audit_remark") if audit_remark is None: - audit_remark = '' + audit_remark = "" if Audit.can_review(request.user, archive_id, 3) is False: - context = {'errMsg': '你无权操作当前工单!'} - return render(request, 'error.html', context) + context = {"errMsg": "你无权操作当前工单!"} + return render(request, "error.html", context) # 使用事务保持数据一致性 try: with transaction.atomic(): - audit_id = Audit.detail_by_workflow_id(workflow_id=archive_id, - workflow_type=WorkflowDict.workflow_type['archive']).audit_id + audit_id = Audit.detail_by_workflow_id( + workflow_id=archive_id, + workflow_type=WorkflowDict.workflow_type["archive"], + ).audit_id # 调用工作流插入审核信息,更新业务表审核状态 - audit_status = Audit.audit(audit_id, audit_status, user.username, audit_remark)['data']['workflow_status'] - ArchiveConfig(id=archive_id, - status=audit_status, - state=True if audit_status == WorkflowDict.workflow_status['audit_success'] else False - ).save(update_fields=['status', 'state']) + audit_status = Audit.audit( + audit_id, audit_status, user.username, audit_remark + )["data"]["workflow_status"] + ArchiveConfig( + id=archive_id, + status=audit_status, + state=True + if audit_status == WorkflowDict.workflow_status["audit_success"] + else False, + ).save(update_fields=["status", "state"]) except Exception as msg: logger.error(traceback.format_exc()) - context = {'errMsg': msg} - return render(request, 'error.html', context) + context = {"errMsg": msg} + return render(request, "error.html", context) else: # 消息通知 - async_task(notify_for_audit, audit_id=audit_id, audit_remark=audit_remark, timeout=60, - task_name=f'archive-audit-{archive_id}') + async_task( + notify_for_audit, + audit_id=audit_id, + audit_remark=audit_remark, + timeout=60, + task_name=f"archive-audit-{archive_id}", + ) - return HttpResponseRedirect(reverse('sql:archive_detail', args=(archive_id,))) + return HttpResponseRedirect(reverse("sql:archive_detail", args=(archive_id,))) def add_archive_task(archive_ids=None): @@ -237,18 +290,25 @@ def add_archive_task(archive_ids=None): # 没有传archive_id代表全部归档任务统一调度 if archive_ids: archive_cnf_list = ArchiveConfig.objects.filter( - id__in=archive_ids, state=True, status=WorkflowDict.workflow_status['audit_success']) + id__in=archive_ids, + state=True, + status=WorkflowDict.workflow_status["audit_success"], + ) else: archive_cnf_list = ArchiveConfig.objects.filter( - state=True, status=WorkflowDict.workflow_status['audit_success']) + state=True, status=WorkflowDict.workflow_status["audit_success"] + ) # 添加task任务 for archive_info in archive_cnf_list: archive_id = archive_info.id - async_task('sql.archiver.archive', - archive_id, - group=f'archive-{time.strftime("%Y-%m-%d %H:%M:%S ")}', timeout=-1, - task_name=f'archive-{archive_id}') + async_task( + "sql.archiver.archive", + archive_id, + group=f'archive-{time.strftime("%Y-%m-%d %H:%M:%S ")}', + timeout=-1, + task_name=f"archive-{archive_id}", + ) def archive(archive_id): @@ -269,28 +329,30 @@ def archive(archive_id): s_engine = get_engine(s_ins) s_db = s_engine.schema_object.databases[src_db_name] s_tb = s_db.tables[src_table_name] - s_charset = s_tb.options['charset'].value + s_charset = s_tb.options["charset"].value if s_charset is None: - s_charset = s_db.options['charset'].value + s_charset = s_db.options["charset"].value pt_archiver = PtArchiver() # 准备参数 - source = fr"h={s_ins.host},u={s_ins.user},p='{s_ins.password}'," \ - fr"P={s_ins.port},D={src_db_name},t={src_table_name},A={s_charset}" + source = ( + rf"h={s_ins.host},u={s_ins.user},p='{s_ins.password}'," + rf"P={s_ins.port},D={src_db_name},t={src_table_name},A={s_charset}" + ) args = { "no-version-check": True, "source": source, "where": condition, "progress": 5000, "statistics": True, - "charset": 'utf8', + "charset": "utf8", "limit": 10000, "txn-size": 1000, - "sleep": sleep + "sleep": sleep, } # 归档到目标实例 - if mode == 'dest': + if mode == "dest": d_ins = archive_info.dest_instance dest_db_name = archive_info.dest_db_name dest_table_name = archive_info.dest_table_name @@ -298,27 +360,31 @@ def archive(archive_id): d_engine = get_engine(d_ins) d_db = d_engine.schema_object.databases[dest_db_name] d_tb = d_db.tables[dest_table_name] - d_charset = d_tb.options['charset'].value + d_charset = d_tb.options["charset"].value if d_charset is None: - d_charset = d_db.options['charset'].value + d_charset = d_db.options["charset"].value # dest - dest = fr"h={d_ins.host},u={d_ins.user},p={d_ins.password},P={d_ins.port}," \ - fr"D={dest_db_name},t={dest_table_name},A={d_charset}" - args['dest'] = dest + dest = ( + rf"h={d_ins.host},u={d_ins.user},p={d_ins.password},P={d_ins.port}," + rf"D={dest_db_name},t={dest_table_name},A={d_charset}" + ) + args["dest"] = dest if no_delete: - args['no-delete'] = True - elif mode == 'file': - output_directory = os.path.join(settings.BASE_DIR, 'downloads/archiver') + args["no-delete"] = True + elif mode == "file": + output_directory = os.path.join(settings.BASE_DIR, "downloads/archiver") os.makedirs(output_directory, exist_ok=True) - args['file'] = f'{output_directory}/{s_ins.instance_name}-{src_db_name}-{src_table_name}.txt' + args[ + "file" + ] = f"{output_directory}/{s_ins.instance_name}-{src_db_name}-{src_table_name}.txt" if no_delete: - args['no-delete'] = True - elif mode == 'purge': - args['purge'] = True + args["no-delete"] = True + elif mode == "purge": + args["purge"] = True # 参数检查 args_check_result = pt_archiver.check_args(args) - if args_check_result['status'] == 1: + if args_check_result["status"] == 1: return JsonResponse(args_check_result) # 参数转换 cmd_args = pt_archiver.generate_args2cmd(args, shell=True) @@ -328,15 +394,15 @@ def archive(archive_id): delete_cnt = 0 with FuncTimer() as t: p = pt_archiver.execute_cmd(cmd_args, shell=True) - stdout = '' - for line in iter(p.stdout.readline, ''): - if re.match(r'^SELECT\s(\d+)$', line, re.I): - select_cnt = re.findall(r'^SELECT\s(\d+)$', line) - elif re.match(r'^INSERT\s(\d+)$', line, re.I): - insert_cnt = re.findall(r'^INSERT\s(\d+)$', line) - elif re.match(r'^DELETE\s(\d+)$', line, re.I): - delete_cnt = re.findall(r'^DELETE\s(\d+)$', line) - stdout += f'{line}\n' + stdout = "" + for line in iter(p.stdout.readline, ""): + if re.match(r"^SELECT\s(\d+)$", line, re.I): + select_cnt = re.findall(r"^SELECT\s(\d+)$", line) + elif re.match(r"^INSERT\s(\d+)$", line, re.I): + insert_cnt = re.findall(r"^INSERT\s(\d+)$", line) + elif re.match(r"^DELETE\s(\d+)$", line, re.I): + delete_cnt = re.findall(r"^DELETE\s(\d+)$", line) + stdout += f"{line}\n" statistics = stdout # 获取异常信息 stderr = p.stderr.read() @@ -347,22 +413,22 @@ def archive(archive_id): select_cnt = int(select_cnt[0]) if select_cnt else 0 insert_cnt = int(insert_cnt[0]) if insert_cnt else 0 delete_cnt = int(delete_cnt[0]) if delete_cnt else 0 - error_info = '' + error_info = "" success = True if stderr: - error_info = f'命令执行报错:{stderr}' + error_info = f"命令执行报错:{stderr}" success = False - if mode == 'dest': + if mode == "dest": # 删除源数据,判断删除数量和写入数量 if not no_delete and (insert_cnt != delete_cnt): error_info = f"删除和写入数量不一致:{insert_cnt}!={delete_cnt}" success = False - elif mode == 'file': + elif mode == "file": # 删除源数据,判断查询数量和删除数量 if not no_delete and (select_cnt != delete_cnt): error_info = f"查询和删除数量不一致:{select_cnt}!={delete_cnt}" success = False - elif mode == 'purge': + elif mode == "purge": # 直接删除。判断查询数量和删除数量 if select_cnt != delete_cnt: error_info = f"查询和删除数量不一致:{select_cnt}!={delete_cnt}" @@ -372,12 +438,15 @@ def archive(archive_id): if connection.connection and not connection.is_usable(): close_old_connections() # 更新最后归档时间 - ArchiveConfig(id=archive_id, last_archive_time=t.end).save(update_fields=['last_archive_time']) + ArchiveConfig(id=archive_id, last_archive_time=t.end).save( + update_fields=["last_archive_time"] + ) # 替换密码信息后保存 ArchiveLog.objects.create( archive=archive_info, - cmd=cmd_args.replace(s_ins.password, '***').replace( - d_ins.password, '***') if mode == 'dest' else cmd_args.replace(s_ins.password, '***'), + cmd=cmd_args.replace(s_ins.password, "***").replace(d_ins.password, "***") + if mode == "dest" + else cmd_args.replace(s_ins.password, "***"), condition=condition, mode=mode, no_delete=no_delete, @@ -389,51 +458,69 @@ def archive(archive_id): success=success, error_info=error_info, start_time=t.start, - end_time=t.end + end_time=t.end, ) if not success: - raise Exception(f'{error_info}\n{statistics}') + raise Exception(f"{error_info}\n{statistics}") -@permission_required('sql.menu_archive', raise_exception=True) +@permission_required("sql.menu_archive", raise_exception=True) def archive_log(request): """获取归档日志列表""" - limit = int(request.GET.get('limit', 0)) - offset = int(request.GET.get('offset', 0)) + limit = int(request.GET.get("limit", 0)) + offset = int(request.GET.get("offset", 0)) limit = offset + limit - archive_id = request.GET.get('archive_id') + archive_id = request.GET.get("archive_id") archive_logs = ArchiveLog.objects.filter(archive=archive_id).annotate( - info=Concat('cmd', V('\n'), 'statistics', output_field=TextField())) + info=Concat("cmd", V("\n"), "statistics", output_field=TextField()) + ) count = archive_logs.count() - lists = archive_logs.order_by('-id')[offset:limit].values( - 'cmd', 'info', 'condition', 'mode', 'no_delete', 'select_cnt', - 'insert_cnt', 'delete_cnt', 'success', 'error_info', 'start_time', 'end_time' + lists = archive_logs.order_by("-id")[offset:limit].values( + "cmd", + "info", + "condition", + "mode", + "no_delete", + "select_cnt", + "insert_cnt", + "delete_cnt", + "success", + "error_info", + "start_time", + "end_time", ) # QuerySet 序列化 rows = [row for row in lists] result = {"total": count, "rows": rows} # 返回查询结果 - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) -@permission_required('sql.archive_mgt', raise_exception=True) +@permission_required("sql.archive_mgt", raise_exception=True) def archive_switch(request): """开启关闭归档任务""" - archive_id = request.POST.get('archive_id') - state = True if request.POST.get('state') == 'true' else False + archive_id = request.POST.get("archive_id") + state = True if request.POST.get("state") == "true" else False # 更新启用状态 try: - ArchiveConfig(id=archive_id, state=state).save(update_fields=['state']) - return JsonResponse({'status': 0, 'msg': 'ok', 'data': {}}) + ArchiveConfig(id=archive_id, state=state).save(update_fields=["state"]) + return JsonResponse({"status": 0, "msg": "ok", "data": {}}) except Exception as msg: - return JsonResponse({'status': 1, 'msg': f'{msg}', 'data': {}}) + return JsonResponse({"status": 1, "msg": f"{msg}", "data": {}}) -@permission_required('sql.archive_mgt', raise_exception=True) +@permission_required("sql.archive_mgt", raise_exception=True) def archive_once(request): """单次立即调用归档任务""" - archive_id = request.GET.get('archive_id') - async_task('sql.archiver.archive', archive_id, timeout=-1, task_name=f'archive-{archive_id}') - return JsonResponse({'status': 0, 'msg': 'ok', 'data': {}}) + archive_id = request.GET.get("archive_id") + async_task( + "sql.archiver.archive", + archive_id, + timeout=-1, + task_name=f"archive-{archive_id}", + ) + return JsonResponse({"status": 0, "msg": "ok", "data": {}}) diff --git a/sql/audit_log.py b/sql/audit_log.py index daaccf217d..f4e29ec326 100644 --- a/sql/audit_log.py +++ b/sql/audit_log.py @@ -7,77 +7,93 @@ from django.utils import timezone from django.http import HttpResponse from django.db.models import Q -from django.contrib.auth.decorators import login_required,permission_required -from django.contrib.auth.signals import user_logged_in, user_logged_out, user_login_failed +from django.contrib.auth.decorators import login_required, permission_required +from django.contrib.auth.signals import ( + user_logged_in, + user_logged_out, + user_login_failed, +) from .models import AuditEntry, Users from common.utils.permission import superuser_required from common.utils.extend_json_encoder import ExtendJSONEncoder -log = logging.getLogger('default') +log = logging.getLogger("default") @login_required def audit_input(request): """用户提交的操作信息""" result = {} - action = request.POST.get('action') - extra_info = request.POST.get('extra_info','') + action = request.POST.get("action") + extra_info = request.POST.get("extra_info", "") - result['user_id'] = request.user.id - result['user_name'] = request.user.username - result['user_display'] = request.user.display - result['action'] = action - result['extra_info'] = extra_info + result["user_id"] = request.user.id + result["user_name"] = request.user.username + result["user_display"] = request.user.display + result["action"] = action + result["extra_info"] = extra_info audit = AuditEntry(**result) audit.save() - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) -@permission_required('sql.audit_user', raise_exception=True) +@permission_required("sql.audit_user", raise_exception=True) def audit_log(request): """获取审计日志列表""" - limit = int(request.POST.get('limit',0)) - offset = int(request.POST.get('offset',0)) + limit = int(request.POST.get("limit", 0)) + offset = int(request.POST.get("offset", 0)) limit = offset + limit limit = limit if limit else None - search = request.POST.get('search', '') - action = request.POST.get('action','') - start_date = request.POST.get('start_date') - end_date = request.POST.get('end_date') + search = request.POST.get("search", "") + action = request.POST.get("action", "") + start_date = request.POST.get("start_date") + end_date = request.POST.get("end_date") filter_dict = dict() if start_date and end_date: - end_date = datetime.datetime.strptime(end_date, '%Y-%m-%d') + datetime.timedelta(days=1) - filter_dict['action_time__range'] = (start_date, end_date) + end_date = datetime.datetime.strptime( + end_date, "%Y-%m-%d" + ) + datetime.timedelta(days=1) + filter_dict["action_time__range"] = (start_date, end_date) if action: - filter_dict['action'] = action - + filter_dict["action"] = action + audit_log_obj = AuditEntry.objects.filter(**filter_dict) if search: - audit_log_obj = audit_log_obj.filter(Q(user_name__icontains=search) | Q(action__icontains=search)| Q(extra_info__icontains=search)) + audit_log_obj = audit_log_obj.filter( + Q(user_name__icontains=search) + | Q(action__icontains=search) + | Q(extra_info__icontains=search) + ) audit_log_count = audit_log_obj.count() - audit_log_list = audit_log_obj.order_by('-action_time')[offset:limit].values('user_id', 'user_name', 'user_display', 'action', 'extra_info', 'action_time') + audit_log_list = audit_log_obj.order_by("-action_time")[offset:limit].values( + "user_id", "user_name", "user_display", "action", "extra_info", "action_time" + ) # QuerySet 序列化 rows = [row for row in audit_log_list] result = {"total": audit_log_count, "rows": rows} # 返回查询结果 - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) def get_client_ip(request): - x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR') + x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR") if x_forwarded_for: - ip = x_forwarded_for.split(',')[0] + ip = x_forwarded_for.split(",")[0] else: - ip = request.META.get('REMOTE_ADDR') + ip = request.META.get("REMOTE_ADDR") return ip @@ -85,26 +101,45 @@ def get_client_ip(request): def user_logged_in_callback(sender, request, user, **kwargs): ip = get_client_ip(request) now = timezone.now() - AuditEntry.objects.create(action=u'登入', extra_info=ip, user_id=user.id, user_name=user.username, user_display=user.display, action_time=now) + AuditEntry.objects.create( + action="登入", + extra_info=ip, + user_id=user.id, + user_name=user.username, + user_display=user.display, + action_time=now, + ) @receiver(user_logged_out) def user_logged_out_callback(sender, request, user, **kwargs): ip = get_client_ip(request) now = timezone.now() - AuditEntry.objects.create(action=u'登出', extra_info=ip, user_id=user.id, user_name=user.username, user_display=user.display, action_time=now) + AuditEntry.objects.create( + action="登出", + extra_info=ip, + user_id=user.id, + user_name=user.username, + user_display=user.display, + action_time=now, + ) @receiver(user_login_failed) def user_login_failed_callback(sender, credentials, **kwargs): now = timezone.now() - user_name = credentials.get('username', None) + user_name = credentials.get("username", None) user_obj = Users.objects.filter(username=user_name)[0:1] user_count = user_obj.count() user_id = 0 - user_display = '' + user_display = "" if user_count > 0: user_id = user_obj[0].id user_display = user_obj[0].display - AuditEntry.objects.create(action=u'登入失败', user_id=user_id, user_name=user_name, user_display=user_display, action_time=now) - + AuditEntry.objects.create( + action="登入失败", + user_id=user_id, + user_name=user_name, + user_display=user_display, + action_time=now, + ) diff --git a/sql/binlog.py b/sql/binlog.py index c16ec02f87..2f1b70b33e 100644 --- a/sql/binlog.py +++ b/sql/binlog.py @@ -19,24 +19,24 @@ from sql.notify import notify_for_my2sql from .models import Instance -logger = logging.getLogger('default') +logger = logging.getLogger("default") -@permission_required('sql.menu_my2sql', raise_exception=True) +@permission_required("sql.menu_my2sql", raise_exception=True) def binlog_list(request): """ 获取binlog列表 :param request: :return: """ - instance_name = request.POST.get('instance_name') + instance_name = request.POST.get("instance_name") try: instance = Instance.objects.get(instance_name=instance_name) except Instance.DoesNotExist: - result = {'status': 1, 'msg': '实例不存在', 'data': []} - return HttpResponse(json.dumps(result), content_type='application/json') + result = {"status": 1, "msg": "实例不存在", "data": []} + return HttpResponse(json.dumps(result), content_type="application/json") query_engine = get_engine(instance=instance) - query_result = query_engine.query('information_schema', 'show binary logs;') + query_result = query_engine.query("information_schema", "show binary logs;") if not query_result.error: column_list = query_result.column_list rows = [] @@ -45,103 +45,129 @@ def binlog_list(request): for row_index, row_item in enumerate(row): row_info[column_list[row_index]] = row_item rows.append(row_info) - result = {'status': 0, 'msg': 'ok', 'data': rows} + result = {"status": 0, "msg": "ok", "data": rows} else: - result = {'status': 1, 'msg': query_result.error} + result = {"status": 1, "msg": query_result.error} # 返回查询结果 - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) -@permission_required('sql.binlog_del', raise_exception=True) +@permission_required("sql.binlog_del", raise_exception=True) def del_binlog(request): - instance_id = request.POST.get('instance_id') - binlog = request.POST.get('binlog', '') + instance_id = request.POST.get("instance_id") + binlog = request.POST.get("binlog", "") try: instance = Instance.objects.get(id=instance_id) except Instance.DoesNotExist: - result = {'status': 1, 'msg': '实例不存在', 'data': []} - return HttpResponse(json.dumps(result), content_type='application/json') + result = {"status": 1, "msg": "实例不存在", "data": []} + return HttpResponse(json.dumps(result), content_type="application/json") # escape - binlog = MySQLdb.escape_string(binlog).decode('utf-8') + binlog = MySQLdb.escape_string(binlog).decode("utf-8") if binlog: query_engine = get_engine(instance=instance) - query_result = query_engine.query(sql=fr"purge master logs to '{binlog}';") + query_result = query_engine.query(sql=rf"purge master logs to '{binlog}';") if query_result.error is None: - result = {'status': 0, 'msg': '清理成功', 'data': ''} + result = {"status": 0, "msg": "清理成功", "data": ""} else: - result = {'status': 2, 'msg': f'清理失败,Error:{query_result.error}', 'data': ''} + result = { + "status": 2, + "msg": f"清理失败,Error:{query_result.error}", + "data": "", + } else: - result = {'status': 1, 'msg': 'Error:未选择binlog!', 'data': ''} - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + result = {"status": 1, "msg": "Error:未选择binlog!", "data": ""} + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) -@permission_required('sql.menu_my2sql', raise_exception=True) +@permission_required("sql.menu_my2sql", raise_exception=True) def my2sql(request): """ 通过解析binlog获取SQL--使用my2sql :param request: :return: """ - instance_name = request.POST.get('instance_name') - save_sql = True if request.POST.get('save_sql') == 'true' else False + instance_name = request.POST.get("instance_name") + save_sql = True if request.POST.get("save_sql") == "true" else False instance = Instance.objects.get(instance_name=instance_name) - work_type = 'rollback' if request.POST.get('rollback') == 'true' else '2sql' - num = 30 if request.POST.get('num') == '' else int(request.POST.get('num')) - threads = 4 if request.POST.get('threads') == '' else int(request.POST.get('threads')) - start_file = request.POST.get('start_file') - start_pos = request.POST.get('start_pos') if request.POST.get('start_pos') == '' else int( - request.POST.get('start_pos')) - end_file = request.POST.get('end_file') - end_pos = request.POST.get('end_pos') if request.POST.get('end_pos') == '' else int(request.POST.get('end_pos')) - stop_time = request.POST.get('stop_time') - start_time = request.POST.get('start_time') - only_schemas = request.POST.getlist('only_schemas') - only_tables = request.POST.getlist('only_tables[]') - sql_type = [] if request.POST.getlist('sql_type[]') == [] else request.POST.getlist('sql_type[]') - extra_info = True if request.POST.get('extra_info') == 'true' else False - ignore_primary_key = True if request.POST.get('ignore_primary_key') == 'true' else False - full_columns = True if request.POST.get('full_columns') == 'true' else False - no_db_prefix = True if request.POST.get('no_db_prefix') == 'true' else False - file_per_table = True if request.POST.get('file_per_table') == 'true' else False - - result = {'status': 0, 'msg': 'ok', 'data': []} + work_type = "rollback" if request.POST.get("rollback") == "true" else "2sql" + num = 30 if request.POST.get("num") == "" else int(request.POST.get("num")) + threads = ( + 4 if request.POST.get("threads") == "" else int(request.POST.get("threads")) + ) + start_file = request.POST.get("start_file") + start_pos = ( + request.POST.get("start_pos") + if request.POST.get("start_pos") == "" + else int(request.POST.get("start_pos")) + ) + end_file = request.POST.get("end_file") + end_pos = ( + request.POST.get("end_pos") + if request.POST.get("end_pos") == "" + else int(request.POST.get("end_pos")) + ) + stop_time = request.POST.get("stop_time") + start_time = request.POST.get("start_time") + only_schemas = request.POST.getlist("only_schemas") + only_tables = request.POST.getlist("only_tables[]") + sql_type = ( + [] + if request.POST.getlist("sql_type[]") == [] + else request.POST.getlist("sql_type[]") + ) + extra_info = True if request.POST.get("extra_info") == "true" else False + ignore_primary_key = ( + True if request.POST.get("ignore_primary_key") == "true" else False + ) + full_columns = True if request.POST.get("full_columns") == "true" else False + no_db_prefix = True if request.POST.get("no_db_prefix") == "true" else False + file_per_table = True if request.POST.get("file_per_table") == "true" else False + + result = {"status": 0, "msg": "ok", "data": []} # 提交给my2sql进行解析 my2sql = My2SQL() # 准备参数 instance_password = shlex.quote(f'"{str(instance.password)}"') - args = {"conn_options": fr"-host {shlex.quote(str(instance.host))} -user {shlex.quote(str(instance.user))} \ + args = { + "conn_options": rf"-host {shlex.quote(str(instance.host))} -user {shlex.quote(str(instance.user))} \ -password '{instance_password}' -port {shlex.quote(str(instance.port))} ", - "work-type": work_type, - "start-file": start_file, - "start-pos": start_pos, - "stop-file": end_file, - "stop-pos": end_pos, - "start-datetime": '"'+start_time+'"', - "stop-datetime": '"'+stop_time+'"', - "databases": ' '.join(only_schemas), - "tables": ','.join(only_tables), - "sql": ','.join(sql_type), - "instance": instance, - "threads": threads, - "add-extraInfo": extra_info, - "ignore-primaryKey-forInsert": ignore_primary_key, - "full-columns": full_columns, - "do-not-add-prifixDb": no_db_prefix, - "file-per-table": file_per_table, - "output-toScreen": True - } + "work-type": work_type, + "start-file": start_file, + "start-pos": start_pos, + "stop-file": end_file, + "stop-pos": end_pos, + "start-datetime": '"' + start_time + '"', + "stop-datetime": '"' + stop_time + '"', + "databases": " ".join(only_schemas), + "tables": ",".join(only_tables), + "sql": ",".join(sql_type), + "instance": instance, + "threads": threads, + "add-extraInfo": extra_info, + "ignore-primaryKey-forInsert": ignore_primary_key, + "full-columns": full_columns, + "do-not-add-prifixDb": no_db_prefix, + "file-per-table": file_per_table, + "output-toScreen": True, + } # 参数检查 args_check_result = my2sql.check_args(args) - if args_check_result['status'] == 1: - return HttpResponse(json.dumps(args_check_result), content_type='application/json') + if args_check_result["status"] == 1: + return HttpResponse( + json.dumps(args_check_result), content_type="application/json" + ) # 参数转换 cmd_args = my2sql.generate_args2cmd(args, shell=True) @@ -151,15 +177,15 @@ def my2sql(request): # 读取前num行后结束 rows = [] n = 1 - for line in iter(p.stdout.readline, ''): + for line in iter(p.stdout.readline, ""): if n <= num and isinstance(line, str): - if line[0:6].upper() in ('INSERT', 'DELETE', 'UPDATE'): + if line[0:6].upper() in ("INSERT", "DELETE", "UPDATE"): n = n + 1 row_info = {} try: - row_info['sql'] = line + ';' + row_info["sql"] = line + ";" except IndexError: - row_info['sql'] = line + ';' + row_info["sql"] = line + ";" rows.append(row_info) else: break @@ -168,27 +194,35 @@ def my2sql(request): # 判断是否有异常 stderr = p.stderr.read() if stderr and isinstance(stderr, str): - result['status'] = 1 - result['msg'] = stderr - return HttpResponse(json.dumps(result), content_type='application/json') + result["status"] = 1 + result["msg"] = stderr + return HttpResponse(json.dumps(result), content_type="application/json") # 终止子进程 p.kill() - result['data'] = rows + result["data"] = rows except Exception as e: logger.error(traceback.format_exc()) - result['status'] = 1 - result['msg'] = str(e) + result["status"] = 1 + result["msg"] = str(e) # 异步保存到文件 if save_sql: - args.pop('conn_options') - args.pop('output-toScreen') - async_task(my2sql_file, args=args, user=request.user, hook=notify_for_my2sql, timeout=-1, - task_name=f'my2sql-{time.time()}') + args.pop("conn_options") + args.pop("output-toScreen") + async_task( + my2sql_file, + args=args, + user=request.user, + hook=notify_for_my2sql, + timeout=-1, + task_name=f"my2sql-{time.time()}", + ) # 返回查询结果 - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) def my2sql_file(args, user): @@ -199,12 +233,12 @@ def my2sql_file(args, user): :return: """ my2sql = My2SQL() - instance = args.get('instance') + instance = args.get("instance") instance_password = shlex.quote(f'"{str(instance.password)}"') - conn_options = fr"-host {shlex.quote(str(instance.host))} -user {shlex.quote(str(instance.user))} \ + conn_options = rf"-host {shlex.quote(str(instance.host))} -user {shlex.quote(str(instance.user))} \ -password '{instance_password}' -port {shlex.quote(str(instance.port))} " - args['conn_options'] = conn_options - path = os.path.join(settings.BASE_DIR, 'downloads/my2sql/') + args["conn_options"] = conn_options + path = os.path.join(settings.BASE_DIR, "downloads/my2sql/") os.makedirs(path, exist_ok=True) # 参数转换 diff --git a/sql/data_dictionary.py b/sql/data_dictionary.py index 13aa3816e0..2ee48b8d5e 100644 --- a/sql/data_dictionary.py +++ b/sql/data_dictionary.py @@ -16,73 +16,91 @@ from .models import Instance -@permission_required('sql.menu_data_dictionary', raise_exception=True) +@permission_required("sql.menu_data_dictionary", raise_exception=True) def table_list(request): """数据字典获取表列表""" - instance_name = request.GET.get('instance_name', '') - db_name = request.GET.get('db_name', '') - db_type = request.GET.get('db_type', '') + instance_name = request.GET.get("instance_name", "") + db_name = request.GET.get("db_name", "") + db_type = request.GET.get("db_type", "") if instance_name and db_name: try: - instance = Instance.objects.get(instance_name=instance_name, db_type=db_type) + instance = Instance.objects.get( + instance_name=instance_name, db_type=db_type + ) query_engine = get_engine(instance=instance) data = query_engine.get_group_tables_by_db(db_name=db_name) - res = {'status': 0, 'data': data} + res = {"status": 0, "data": data} except Instance.DoesNotExist: - res = {'status': 1, 'msg': 'Instance.DoesNotExist'} + res = {"status": 1, "msg": "Instance.DoesNotExist"} except Exception as e: - res = {'status': 1, 'msg': str(e)} + res = {"status": 1, "msg": str(e)} else: - res = {'status': 1, 'msg': '非法调用!'} - return HttpResponse(json.dumps(res, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + res = {"status": 1, "msg": "非法调用!"} + return HttpResponse( + json.dumps(res, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) -@permission_required('sql.menu_data_dictionary', raise_exception=True) +@permission_required("sql.menu_data_dictionary", raise_exception=True) def table_info(request): """数据字典获取表信息""" - instance_name = request.GET.get('instance_name', '') - db_name = request.GET.get('db_name', '') - tb_name = request.GET.get('tb_name', '') - db_type = request.GET.get('db_type', '') + instance_name = request.GET.get("instance_name", "") + db_name = request.GET.get("db_name", "") + tb_name = request.GET.get("tb_name", "") + db_type = request.GET.get("db_type", "") if instance_name and db_name and tb_name: data = {} try: - instance = Instance.objects.get(instance_name=instance_name, db_type=db_type) + instance = Instance.objects.get( + instance_name=instance_name, db_type=db_type + ) query_engine = get_engine(instance=instance) - data['meta_data'] = query_engine.get_table_meta_data(db_name=db_name, tb_name=tb_name) - data['desc'] = query_engine.get_table_desc_data(db_name=db_name, tb_name=tb_name) - data['index'] = query_engine.get_table_index_data(db_name=db_name, tb_name=tb_name) + data["meta_data"] = query_engine.get_table_meta_data( + db_name=db_name, tb_name=tb_name + ) + data["desc"] = query_engine.get_table_desc_data( + db_name=db_name, tb_name=tb_name + ) + data["index"] = query_engine.get_table_index_data( + db_name=db_name, tb_name=tb_name + ) # mysql数据库可以获取创建表格的SQL语句,mssql暂无找到生成创建表格的SQL语句 - if instance.db_type == 'mysql': - _create_sql = query_engine.query(db_name, "show create table `%s`;" % tb_name) - data['create_sql'] = _create_sql.rows - res = {'status': 0, 'data': data} + if instance.db_type == "mysql": + _create_sql = query_engine.query( + db_name, "show create table `%s`;" % tb_name + ) + data["create_sql"] = _create_sql.rows + res = {"status": 0, "data": data} except Instance.DoesNotExist: - res = {'status': 1, 'msg': 'Instance.DoesNotExist'} + res = {"status": 1, "msg": "Instance.DoesNotExist"} except Exception as e: - res = {'status': 1, 'msg': str(e)} + res = {"status": 1, "msg": str(e)} else: - res = {'status': 1, 'msg': '非法调用!'} - return HttpResponse(json.dumps(res, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + res = {"status": 1, "msg": "非法调用!"} + return HttpResponse( + json.dumps(res, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) -@permission_required('sql.data_dictionary_export', raise_exception=True) +@permission_required("sql.data_dictionary_export", raise_exception=True) def export(request): """导出数据字典""" - instance_name = request.GET.get('instance_name', '') - db_name = request.GET.get('db_name', '') + instance_name = request.GET.get("instance_name", "") + db_name = request.GET.get("db_name", "") # escape - db_name = MySQLdb.escape_string(db_name).decode('utf-8') + db_name = MySQLdb.escape_string(db_name).decode("utf-8") try: - instance = user_instances(request.user, db_type=['mysql', 'mssql', 'oracle']).get(instance_name=instance_name) + instance = user_instances( + request.user, db_type=["mysql", "mssql", "oracle"] + ).get(instance_name=instance_name) query_engine = get_engine(instance=instance) except Instance.DoesNotExist: - return JsonResponse({'status': 1, 'msg': '你所在组未关联该实例!', 'data': []}) + return JsonResponse({"status": 1, "msg": "你所在组未关联该实例!", "data": []}) # 普通用户仅可以获取指定数据库的字典信息 if db_name: @@ -91,24 +109,38 @@ def export(request): elif request.user.is_superuser: dbs = query_engine.get_all_databases().rows else: - return JsonResponse({'status': 1, 'msg': f'仅管理员可以导出整个实例的字典信息!', 'data': []}) + return JsonResponse({"status": 1, "msg": f"仅管理员可以导出整个实例的字典信息!", "data": []}) # 获取数据,存入目录 - path = os.path.join(settings.BASE_DIR, 'downloads/dictionary') + path = os.path.join(settings.BASE_DIR, "downloads/dictionary") os.makedirs(path, exist_ok=True) for db in dbs: table_metas = query_engine.get_tables_metas_data(db_name=db) - context = {"db_name": db_name, "tables": table_metas, "export_time": datetime.datetime.now()} - data = loader.render_to_string(template_name="dictionaryexport.html", context=context, request=request) - with open(f'{path}/{instance_name}_{db}.html', 'w') as f: + context = { + "db_name": db_name, + "tables": table_metas, + "export_time": datetime.datetime.now(), + } + data = loader.render_to_string( + template_name="dictionaryexport.html", context=context, request=request + ) + with open(f"{path}/{instance_name}_{db}.html", "w") as f: f.write(data) # 关闭连接 query_engine.close() if db_name: - response = FileResponse(open(f'{path}/{instance_name}_{db_name}.html', 'rb')) - response['Content-Type'] = 'application/octet-stream' - response['Content-Disposition'] = f'attachment;filename="{quote(instance_name)}_{quote(db_name)}.html"' + response = FileResponse(open(f"{path}/{instance_name}_{db_name}.html", "rb")) + response["Content-Type"] = "application/octet-stream" + response[ + "Content-Disposition" + ] = f'attachment;filename="{quote(instance_name)}_{quote(db_name)}.html"' return response else: - return JsonResponse({'status': 0, 'msg': f'实例{instance_name}数据字典导出成功,请到downloads目录下载!', 'data': []}) + return JsonResponse( + { + "status": 0, + "msg": f"实例{instance_name}数据字典导出成功,请到downloads目录下载!", + "data": [], + } + ) diff --git a/sql/db_diagnostic.py b/sql/db_diagnostic.py index 8f57fde337..be8fdfc12b 100644 --- a/sql/db_diagnostic.py +++ b/sql/db_diagnostic.py @@ -1,7 +1,8 @@ import logging import traceback import MySQLdb -#import simplejson as json + +# import simplejson as json import json from django.contrib.auth.decorators import permission_required @@ -12,206 +13,245 @@ from sql.utils.resource_group import user_instances from .models import AliyunRdsConfig, Instance -from .aliyun_rds import process_status as aliyun_process_status, create_kill_session as aliyun_create_kill_session, \ - kill_session as aliyun_kill_session, sapce_status as aliyun_sapce_status +from .aliyun_rds import ( + process_status as aliyun_process_status, + create_kill_session as aliyun_create_kill_session, + kill_session as aliyun_kill_session, + sapce_status as aliyun_sapce_status, +) -logger = logging.getLogger('default') +logger = logging.getLogger("default") # 问题诊断--进程列表 -@permission_required('sql.process_view', raise_exception=True) +@permission_required("sql.process_view", raise_exception=True) def process(request): - instance_name = request.POST.get('instance_name') - command_type = request.POST.get('command_type') + instance_name = request.POST.get("instance_name") + command_type = request.POST.get("command_type") try: instance = user_instances(request.user).get(instance_name=instance_name) except Instance.DoesNotExist: - result = {'status': 1, 'msg': '你所在组未关联该实例', 'data': []} - return HttpResponse(json.dumps(result), content_type='application/json') + result = {"status": 1, "msg": "你所在组未关联该实例", "data": []} + return HttpResponse(json.dumps(result), content_type="application/json") query_engine = get_engine(instance=instance) query_result = None - if instance.db_type == 'mysql': + if instance.db_type == "mysql": # 判断是RDS还是其他实例 if AliyunRdsConfig.objects.filter(instance=instance, is_enable=True).exists(): result = aliyun_process_status(request) else: query_result = query_engine.processlist(command_type) - elif instance.db_type == 'mongo': - query_result = query_engine.current_op(command_type) + elif instance.db_type == "mongo": + query_result = query_engine.current_op(command_type) else: - result = {'status': 1, 'msg': '暂时不支持{}类型数据库的进程列表查询'.format(instance.db_type), 'data': []} - return HttpResponse(json.dumps(result), content_type='application/json') - - if query_result: + result = { + "status": 1, + "msg": "暂时不支持{}类型数据库的进程列表查询".format(instance.db_type), + "data": [], + } + return HttpResponse(json.dumps(result), content_type="application/json") + + if query_result: if not query_result.error: processlist = query_result.to_dict() - result = {'status': 0, 'msg': 'ok', 'rows': processlist} + result = {"status": 0, "msg": "ok", "rows": processlist} else: - result = {'status': 1, 'msg': query_result.error} - + result = {"status": 1, "msg": query_result.error} + # 返回查询结果 # ExtendJSONEncoderBytes 使用json模块,bigint_as_string只支持simplejson - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoderBytes), - content_type='application/json') + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoderBytes), content_type="application/json" + ) # 问题诊断--通过线程id构建请求 这里只是用于确定将要kill的线程id还在运行 -@permission_required('sql.process_kill', raise_exception=True) +@permission_required("sql.process_kill", raise_exception=True) def create_kill_session(request): - instance_name = request.POST.get('instance_name') - thread_ids = request.POST.get('ThreadIDs') + instance_name = request.POST.get("instance_name") + thread_ids = request.POST.get("ThreadIDs") try: instance = user_instances(request.user).get(instance_name=instance_name) except Instance.DoesNotExist: - result = {'status': 1, 'msg': '你所在组未关联该实例', 'data': []} - return HttpResponse(json.dumps(result), content_type='application/json') + result = {"status": 1, "msg": "你所在组未关联该实例", "data": []} + return HttpResponse(json.dumps(result), content_type="application/json") - result = {'status': 0, 'msg': 'ok', 'data': []} + result = {"status": 0, "msg": "ok", "data": []} query_engine = get_engine(instance=instance) - if instance.db_type == 'mysql': + if instance.db_type == "mysql": # 判断是RDS还是其他实例 if AliyunRdsConfig.objects.filter(instance=instance, is_enable=True).exists(): result = aliyun_create_kill_session(request) else: - result['data'] = query_engine.get_kill_command(json.loads(thread_ids)) - elif instance.db_type == 'mongo': + result["data"] = query_engine.get_kill_command(json.loads(thread_ids)) + elif instance.db_type == "mongo": kill_command = query_engine.get_kill_command(json.loads(thread_ids)) - result['data'] = kill_command + result["data"] = kill_command else: - result = {'status': 1, 'msg': '暂时不支持{}类型数据库通过进程id构建请求'.format(instance.db_type), 'data': []} - return HttpResponse(json.dumps(result), content_type='application/json') + result = { + "status": 1, + "msg": "暂时不支持{}类型数据库通过进程id构建请求".format(instance.db_type), + "data": [], + } + return HttpResponse(json.dumps(result), content_type="application/json") # 返回查询结果 - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) # 问题诊断--终止会话 这里是实际执行kill的操作 -@permission_required('sql.process_kill', raise_exception=True) +@permission_required("sql.process_kill", raise_exception=True) def kill_session(request): - instance_name = request.POST.get('instance_name') - thread_ids = request.POST.get('ThreadIDs') - result = {'status': 0, 'msg': 'ok', 'data': []} + instance_name = request.POST.get("instance_name") + thread_ids = request.POST.get("ThreadIDs") + result = {"status": 0, "msg": "ok", "data": []} try: instance = user_instances(request.user).get(instance_name=instance_name) except Instance.DoesNotExist: - result = {'status': 1, 'msg': '你所在组未关联该实例', 'data': []} - return HttpResponse(json.dumps(result), content_type='application/json') + result = {"status": 1, "msg": "你所在组未关联该实例", "data": []} + return HttpResponse(json.dumps(result), content_type="application/json") engine = get_engine(instance=instance) r = None - if instance.db_type == 'mysql': + if instance.db_type == "mysql": # 判断是RDS还是其他实例 if AliyunRdsConfig.objects.filter(instance=instance, is_enable=True).exists(): result = aliyun_kill_session(request) else: r = engine.kill(json.loads(thread_ids)) - elif instance.db_type == 'mongo': + elif instance.db_type == "mongo": r = engine.kill_op(json.loads(thread_ids)) else: - result = {'status': 1, 'msg': '暂时不支持{}类型数据库终止会话'.format(instance.db_type), 'data': []} - return HttpResponse(json.dumps(result), content_type='application/json') + result = { + "status": 1, + "msg": "暂时不支持{}类型数据库终止会话".format(instance.db_type), + "data": [], + } + return HttpResponse(json.dumps(result), content_type="application/json") if r and r.error: - result = {'status': 1, 'msg': r.error, 'data': []} + result = {"status": 1, "msg": r.error, "data": []} # 返回查询结果 - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) # 问题诊断--表空间信息 -@permission_required('sql.tablespace_view', raise_exception=True) +@permission_required("sql.tablespace_view", raise_exception=True) def tablesapce(request): - instance_name = request.POST.get('instance_name') - offset = int(request.POST.get('offset',0)) - limit = int(request.POST.get('limit',14)) + instance_name = request.POST.get("instance_name") + offset = int(request.POST.get("offset", 0)) + limit = int(request.POST.get("limit", 14)) try: instance = user_instances(request.user).get(instance_name=instance_name) except Instance.DoesNotExist: - result = {'status': 1, 'msg': '你所在组未关联该实例', 'data': []} - return HttpResponse(json.dumps(result), content_type='application/json') - + result = {"status": 1, "msg": "你所在组未关联该实例", "data": []} + return HttpResponse(json.dumps(result), content_type="application/json") + query_engine = get_engine(instance=instance) - if instance.db_type == 'mysql': + if instance.db_type == "mysql": # 判断是RDS还是其他实例 if AliyunRdsConfig.objects.filter(instance=instance, is_enable=True).exists(): result = aliyun_sapce_status(request) else: - query_result = query_engine.tablesapce(offset,limit) + query_result = query_engine.tablesapce(offset, limit) r = query_engine.tablesapce_num() total = r.rows[0][0] else: - result = {'status': 1, 'msg': '暂时不支持{}类型数据库的表空间信息查询'.format(instance.db_type), 'data': []} - return HttpResponse(json.dumps(result), content_type='application/json') - - if query_result: + result = { + "status": 1, + "msg": "暂时不支持{}类型数据库的表空间信息查询".format(instance.db_type), + "data": [], + } + return HttpResponse(json.dumps(result), content_type="application/json") + + if query_result: if not query_result.error: table_space = query_result.to_dict() - result = {'status': 0, 'msg': 'ok', 'rows': table_space, 'total':total} + result = {"status": 0, "msg": "ok", "rows": table_space, "total": total} else: - result = {'status': 1, 'msg': query_result.error} + result = {"status": 1, "msg": query_result.error} # 返回查询结果 - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) # 问题诊断--锁等待 -@permission_required('sql.trxandlocks_view', raise_exception=True) +@permission_required("sql.trxandlocks_view", raise_exception=True) def trxandlocks(request): - instance_name = request.POST.get('instance_name') + instance_name = request.POST.get("instance_name") try: instance = user_instances(request.user).get(instance_name=instance_name) except Instance.DoesNotExist: - result = {'status': 1, 'msg': '你所在组未关联该实例', 'data': []} - return HttpResponse(json.dumps(result), content_type='application/json') - + result = {"status": 1, "msg": "你所在组未关联该实例", "data": []} + return HttpResponse(json.dumps(result), content_type="application/json") + query_engine = get_engine(instance=instance) - if instance.db_type == 'mysql': + if instance.db_type == "mysql": query_result = query_engine.trxandlocks() - + else: - result = {'status': 1, 'msg': '暂时不支持{}类型数据库的锁等待查询'.format(instance.db_type), 'data': []} - return HttpResponse(json.dumps(result), content_type='application/json') - + result = { + "status": 1, + "msg": "暂时不支持{}类型数据库的锁等待查询".format(instance.db_type), + "data": [], + } + return HttpResponse(json.dumps(result), content_type="application/json") + if not query_result.error: trxandlocks = query_result.to_dict() - result = {'status': 0, 'msg': 'ok', 'rows': trxandlocks} + result = {"status": 0, "msg": "ok", "rows": trxandlocks} else: - result = {'status': 1, 'msg': query_result.error} + result = {"status": 1, "msg": query_result.error} # 返回查询结果 - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) # 问题诊断--长事务 -@permission_required('sql.trx_view', raise_exception=True) +@permission_required("sql.trx_view", raise_exception=True) def innodb_trx(request): - instance_name = request.POST.get('instance_name') + instance_name = request.POST.get("instance_name") try: instance = user_instances(request.user).get(instance_name=instance_name) except Instance.DoesNotExist: - result = {'status': 1, 'msg': '你所在组未关联该实例', 'data': []} - return HttpResponse(json.dumps(result), content_type='application/json') - + result = {"status": 1, "msg": "你所在组未关联该实例", "data": []} + return HttpResponse(json.dumps(result), content_type="application/json") + query_engine = get_engine(instance=instance) - if instance.db_type == 'mysql': + if instance.db_type == "mysql": query_result = query_engine.get_long_transaction() else: - result = {'status': 1, 'msg': '暂时不支持{}类型数据库的长事务查询'.format(instance.db_type), 'data': []} - return HttpResponse(json.dumps(result), content_type='application/json') - + result = { + "status": 1, + "msg": "暂时不支持{}类型数据库的长事务查询".format(instance.db_type), + "data": [], + } + return HttpResponse(json.dumps(result), content_type="application/json") + if not query_result.error: trx = query_result.to_dict() - result = {'status': 0, 'msg': 'ok', 'rows': trx} + result = {"status": 0, "msg": "ok", "rows": trx} else: - result = {'status': 1, 'msg': query_result.error} + result = {"status": 1, "msg": query_result.error} # 返回查询结果 - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) diff --git a/sql/engines/__init__.py b/sql/engines/__init__.py index 5b9e108915..a101abed13 100644 --- a/sql/engines/__init__.py +++ b/sql/engines/__init__.py @@ -5,6 +5,7 @@ class EngineBase: """enginebase 只定义了init函数和若干方法的名字, 具体实现用mysql.py pg.py等实现""" + test_query = None def __init__(self, instance=None): @@ -35,14 +36,14 @@ def __init__(self, instance=None): self.host, self.port = self.ssh.get_ssh() def __del__(self): - if hasattr(self, 'ssh'): + if hasattr(self, "ssh"): del self.ssh - if hasattr(self, 'remotessh'): + if hasattr(self, "remotessh"): del self.remotessh def remote_instance_conn(self, instance=None): # 判断如果配置了隧道则连接隧道 - if not hasattr(self, 'remotessh') and instance.tunnel: + if not hasattr(self, "remotessh") and instance.tunnel: self.remotessh = SSHConnection( instance.host, instance.port, @@ -61,7 +62,12 @@ def remote_instance_conn(self, instance=None): self.remote_port = instance.port self.remote_user = instance.user self.remote_password = instance.password - return self.remote_host, self.remote_port, self.remote_user, self.remote_password + return ( + self.remote_host, + self.remote_port, + self.remote_user, + self.remote_password, + ) def get_connection(self, db_name=None): """返回一个conn实例""" @@ -137,20 +143,20 @@ def describe_table(self, db_name, tb_name, **kwargs): def query_check(self, db_name=None, sql=""): """查询语句的检查、注释去除、切分, 返回一个字典 {'bad_query': bool, 'filtered_sql': str}""" - def filter_sql(self, sql='', limit_num=0): + def filter_sql(self, sql="", limit_num=0): """给查询语句增加结果级限制或者改写语句, 返回修改后的语句""" return sql.strip() - def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs): + def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs): """实际查询 返回一个ResultSet""" return ResultSet() - def query_masking(self, db_name=None, sql='', resultset=None): + def query_masking(self, db_name=None, sql="", resultset=None): """传入 sql语句, db名, 结果集, 返回一个脱敏后的结果集""" return resultset - def execute_check(self, db_name=None, sql=''): + def execute_check(self, db_name=None, sql=""): """执行语句的检查 返回一个ReviewSet""" return ReviewSet() @@ -209,12 +215,12 @@ def get_engine(instance=None): # pragma: no cover return PhoenixEngine(instance=instance) - elif instance.db_type == 'odps': + elif instance.db_type == "odps": from .odps import ODPSEngine return ODPSEngine(instance=instance) - elif instance.db_type == 'clickhouse': + elif instance.db_type == "clickhouse": from .clickhouse import ClickHouseEngine return ClickHouseEngine(instance=instance) diff --git a/sql/engines/clickhouse.py b/sql/engines/clickhouse.py index 3d991cd9c2..8944d27b89 100644 --- a/sql/engines/clickhouse.py +++ b/sql/engines/clickhouse.py @@ -9,7 +9,7 @@ import logging import re -logger = logging.getLogger('default') +logger = logging.getLogger("default") class ClickHouseEngine(EngineBase): @@ -23,20 +23,31 @@ def get_connection(self, db_name=None): if self.conn: return self.conn if db_name: - self.conn = connect(host=self.host, port=self.port, user=self.user, password=self.password, - database=db_name, connect_timeout=10) + self.conn = connect( + host=self.host, + port=self.port, + user=self.user, + password=self.password, + database=db_name, + connect_timeout=10, + ) else: - self.conn = connect(host=self.host, port=self.port, user=self.user, password=self.password, - connect_timeout=10) + self.conn = connect( + host=self.host, + port=self.port, + user=self.user, + password=self.password, + connect_timeout=10, + ) return self.conn @property def name(self): - return 'ClickHouse' + return "ClickHouse" @property def info(self): - return 'ClickHouse engine' + return "ClickHouse engine" @property def auto_backup(self): @@ -47,8 +58,8 @@ def auto_backup(self): def server_version(self): sql = "select value from system.build_options where name = 'VERSION_FULL';" result = self.query(sql=sql) - version = result.rows[0][0].split(' ')[1] - return tuple([int(n) for n in version.split('.')[:3]]) + version = result.rows[0][0].split(" ")[1] + return tuple([int(n) for n in version.split(".")[:3]]) def get_table_engine(self, tb_name): """获取某个table的engine type""" @@ -58,17 +69,21 @@ def get_table_engine(self, tb_name): and name='{tb_name.split('.')[1]}'""" query_result = self.query(sql=sql) if query_result.rows: - result = {'status': 1, 'engine': query_result.rows[0][0]} + result = {"status": 1, "engine": query_result.rows[0][0]} else: - result = {'status': 0, 'engine': 'None'} + result = {"status": 0, "engine": "None"} return result def get_all_databases(self): """获取数据库列表, 返回一个ResultSet""" sql = "show databases" result = self.query(sql=sql) - db_list = [row[0] for row in result.rows - if row[0] not in ('system', 'INFORMATION_SCHEMA', 'information_schema', 'datasets')] + db_list = [ + row[0] + for row in result.rows + if row[0] + not in ("system", "INFORMATION_SCHEMA", "information_schema", "datasets") + ] result.rows = db_list return result @@ -101,11 +116,13 @@ def describe_table(self, db_name, tb_name, **kwargs): sql = f"show create table `{tb_name}`;" result = self.query(db_name=db_name, sql=sql) - result.rows[0] = (tb_name,) + (result.rows[0][0].replace('(', '(\n ').replace(',', ',\n '),) + result.rows[0] = (tb_name,) + ( + result.rows[0][0].replace("(", "(\n ").replace(",", ",\n "), + ) return result - def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs): - """返回 ResultSet """ + def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs): + """返回 ResultSet""" result_set = ResultSet(full_sql=sql) try: conn = self.get_connection(db_name=db_name) @@ -122,91 +139,97 @@ def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs): result_set.affected_rows = len(rows) except Exception as e: logger.warning(f"ClickHouse语句执行报错,语句:{sql},错误信息{e}") - result_set.error = str(e).split('Stack trace')[0] + result_set.error = str(e).split("Stack trace")[0] finally: if close_conn: self.close() return result_set - def query_check(self, db_name=None, sql=''): + def query_check(self, db_name=None, sql=""): # 查询语句的检查、注释去除、切分 - result = {'msg': '', 'bad_query': False, 'filtered_sql': sql, 'has_star': False} + result = {"msg": "", "bad_query": False, "filtered_sql": sql, "has_star": False} # 删除注释语句,进行语法判断,执行第一条有效sql try: sql = sqlparse.format(sql, strip_comments=True) sql = sqlparse.split(sql)[0] - result['filtered_sql'] = sql.strip() + result["filtered_sql"] = sql.strip() except IndexError: - result['bad_query'] = True - result['msg'] = '没有有效的SQL语句' + result["bad_query"] = True + result["msg"] = "没有有效的SQL语句" if re.match(r"^select|^show|^explain", sql, re.I) is None: - result['bad_query'] = True - result['msg'] = '不支持的查询语法类型!' - if '*' in sql: - result['has_star'] = True - result['msg'] = 'SQL语句中含有 * ' + result["bad_query"] = True + result["msg"] = "不支持的查询语法类型!" + if "*" in sql: + result["has_star"] = True + result["msg"] = "SQL语句中含有 * " # clickhouse 20.6.3版本开始正式支持explain语法 if re.match(r"^explain", sql, re.I) and self.server_version < (20, 6, 3): - result['bad_query'] = True - result['msg'] = f"当前ClickHouse实例版本低于20.6.3,不支持explain!" + result["bad_query"] = True + result["msg"] = f"当前ClickHouse实例版本低于20.6.3,不支持explain!" # select语句先使用Explain判断语法是否正确 if re.match(r"^select", sql, re.I) and self.server_version >= (20, 6, 3): explain_result = self.query(db_name=db_name, sql=f"explain {sql}") if explain_result.error: - result['bad_query'] = True - result['msg'] = explain_result.error + result["bad_query"] = True + result["msg"] = explain_result.error return result - def filter_sql(self, sql='', limit_num=0): + def filter_sql(self, sql="", limit_num=0): # 对查询sql增加limit限制,limit n 或 limit n,n 或 limit n offset n统一改写成limit n - sql = sql.rstrip(';').strip() + sql = sql.rstrip(";").strip() if re.match(r"^select", sql, re.I): # LIMIT N - limit_n = re.compile(r'limit\s+(\d+)\s*$', re.I) + limit_n = re.compile(r"limit\s+(\d+)\s*$", re.I) # LIMIT M OFFSET N - limit_offset = re.compile(r'limit\s+(\d+)\s+offset\s+(\d+)\s*$', re.I) + limit_offset = re.compile(r"limit\s+(\d+)\s+offset\s+(\d+)\s*$", re.I) # LIMIT M,N - offset_comma_limit = re.compile(r'limit\s+(\d+)\s*,\s*(\d+)\s*$', re.I) + offset_comma_limit = re.compile(r"limit\s+(\d+)\s*,\s*(\d+)\s*$", re.I) if limit_n.search(sql): sql_limit = limit_n.search(sql).group(1) limit_num = min(int(limit_num), int(sql_limit)) - sql = limit_n.sub(f'limit {limit_num};', sql) + sql = limit_n.sub(f"limit {limit_num};", sql) elif limit_offset.search(sql): sql_limit = limit_offset.search(sql).group(1) sql_offset = limit_offset.search(sql).group(2) limit_num = min(int(limit_num), int(sql_limit)) - sql = limit_offset.sub(f'limit {limit_num} offset {sql_offset};', sql) + sql = limit_offset.sub(f"limit {limit_num} offset {sql_offset};", sql) elif offset_comma_limit.search(sql): sql_offset = offset_comma_limit.search(sql).group(1) sql_limit = offset_comma_limit.search(sql).group(2) limit_num = min(int(limit_num), int(sql_limit)) - sql = offset_comma_limit.sub(f'limit {sql_offset},{limit_num};', sql) + sql = offset_comma_limit.sub(f"limit {sql_offset},{limit_num};", sql) else: - sql = f'{sql} limit {limit_num};' + sql = f"{sql} limit {limit_num};" else: - sql = f'{sql};' + sql = f"{sql};" return sql - def explain_check(self, check_result, db_name=None, line=0, statement=''): + def explain_check(self, check_result, db_name=None, line=0, statement=""): """使用explain ast检查sql语法, 返回Review set""" - result = ReviewResult(id=line, errlevel=0, - stagestatus='Audit completed', - errormessage='None', - sql=statement, - affected_rows=0, - execute_time=0, ) + result = ReviewResult( + id=line, + errlevel=0, + stagestatus="Audit completed", + errormessage="None", + sql=statement, + affected_rows=0, + execute_time=0, + ) # clickhouse版本>=21.1.2 explain ast才支持非select语句检查 if self.server_version >= (21, 1, 2): explain_result = self.query(db_name=db_name, sql=f"explain ast {statement}") if explain_result.error: - result = ReviewResult(id=line, errlevel=2, - stagestatus='驳回未通过检查SQL', - errormessage=f'explain语法检查错误:{explain_result.error}', - sql=statement) + result = ReviewResult( + id=line, + errlevel=2, + stagestatus="驳回未通过检查SQL", + errormessage=f"explain语法检查错误:{explain_result.error}", + sql=statement, + ) return result - def execute_check(self, db_name=None, sql=''): + def execute_check(self, db_name=None, sql=""): """上线单执行前的检查, 返回Review set""" sql = sqlparse.format(sql, strip_comments=True) sql_list = sqlparse.split(sql) @@ -214,111 +237,162 @@ def execute_check(self, db_name=None, sql=''): # 禁用/高危语句检查 check_result = ReviewSet(full_sql=sql) line = 1 - critical_ddl_regex = self.config.get('critical_ddl_regex', '') + critical_ddl_regex = self.config.get("critical_ddl_regex", "") p = re.compile(critical_ddl_regex) check_result.syntax_type = 2 # TODO 工单类型 0、其他 1、DDL,2、DML for statement in sql_list: - statement = statement.rstrip(';') + statement = statement.rstrip(";") # 禁用语句 if re.match(r"^select|^show", statement.lower()): - result = ReviewResult(id=line, errlevel=2, - stagestatus='驳回不支持语句', - errormessage='仅支持DML和DDL语句,查询语句请使用SQL查询功能!', - sql=statement) + result = ReviewResult( + id=line, + errlevel=2, + stagestatus="驳回不支持语句", + errormessage="仅支持DML和DDL语句,查询语句请使用SQL查询功能!", + sql=statement, + ) # 高危语句 elif critical_ddl_regex and p.match(statement.strip().lower()): - result = ReviewResult(id=line, errlevel=2, - stagestatus='驳回高危SQL', - errormessage='禁止提交匹配' + critical_ddl_regex + '条件的语句!', - sql=statement) + result = ReviewResult( + id=line, + errlevel=2, + stagestatus="驳回高危SQL", + errormessage="禁止提交匹配" + critical_ddl_regex + "条件的语句!", + sql=statement, + ) # alter语句 elif re.match(r"^alter", statement.lower()): # alter table语句 if re.match(r"^alter\s+table\s+(.+?)\s+", statement.lower()): - table_name = re.match(r"^alter\s+table\s+(.+?)\s+", statement.lower(), re.M).group(1) - if '.' not in table_name: + table_name = re.match( + r"^alter\s+table\s+(.+?)\s+", statement.lower(), re.M + ).group(1) + if "." not in table_name: table_name = f"{db_name}.{table_name}" - table_engine = self.get_table_engine(table_name)['engine'] - table_exist = self.get_table_engine(table_name)['status'] + table_engine = self.get_table_engine(table_name)["engine"] + table_exist = self.get_table_engine(table_name)["status"] if table_exist == 1: - if not table_engine.endswith('MergeTree') and table_engine not in ('Merge', 'Distributed'): - result = ReviewResult(id=line, errlevel=2, - stagestatus='驳回不支持SQL', - errormessage='ALTER TABLE仅支持*MergeTree,Merge以及Distributed等引擎表!', - sql=statement) + if not table_engine.endswith( + "MergeTree" + ) and table_engine not in ("Merge", "Distributed"): + result = ReviewResult( + id=line, + errlevel=2, + stagestatus="驳回不支持SQL", + errormessage="ALTER TABLE仅支持*MergeTree,Merge以及Distributed等引擎表!", + sql=statement, + ) else: # delete与update语句,实际是alter语句的变种 - if re.match(r"^alter\s+table\s+(.+?)\s+(delete|update)\s+", statement.lower()): - if not table_engine.endswith('MergeTree'): - result = ReviewResult(id=line, errlevel=2, - stagestatus='驳回不支持SQL', - errormessage='DELETE与UPDATE仅支持*MergeTree引擎表!', - sql=statement) + if re.match( + r"^alter\s+table\s+(.+?)\s+(delete|update)\s+", + statement.lower(), + ): + if not table_engine.endswith("MergeTree"): + result = ReviewResult( + id=line, + errlevel=2, + stagestatus="驳回不支持SQL", + errormessage="DELETE与UPDATE仅支持*MergeTree引擎表!", + sql=statement, + ) else: - result = self.explain_check(check_result, db_name, line, statement) + result = self.explain_check( + check_result, db_name, line, statement + ) else: - result = self.explain_check(check_result, db_name, line, statement) + result = self.explain_check( + check_result, db_name, line, statement + ) else: - result = ReviewResult(id=line, errlevel=2, - stagestatus='表不存在', - errormessage=f'表 {table_name} 不存在!', - sql=statement) + result = ReviewResult( + id=line, + errlevel=2, + stagestatus="表不存在", + errormessage=f"表 {table_name} 不存在!", + sql=statement, + ) # 其他alter语句 else: result = self.explain_check(check_result, db_name, line, statement) # truncate语句 elif re.match(r"^truncate\s+table\s+(.+?)(\s|$)", statement.lower()): - table_name = re.match(r"^truncate\s+table\s+(.+?)(\s|$)", statement.lower(), re.M).group(1) - if '.' not in table_name: + table_name = re.match( + r"^truncate\s+table\s+(.+?)(\s|$)", statement.lower(), re.M + ).group(1) + if "." not in table_name: table_name = f"{db_name}.{table_name}" - table_engine = self.get_table_engine(table_name)['engine'] - table_exist = self.get_table_engine(table_name)['status'] + table_engine = self.get_table_engine(table_name)["engine"] + table_exist = self.get_table_engine(table_name)["status"] if table_exist == 1: - if table_engine in ('View', 'File,', 'URL', 'Buffer', 'Null'): - result = ReviewResult(id=line, errlevel=2, - stagestatus='驳回不支持SQL', - errormessage='TRUNCATE不支持View,File,URL,Buffer和Null表引擎!', - sql=statement) + if table_engine in ("View", "File,", "URL", "Buffer", "Null"): + result = ReviewResult( + id=line, + errlevel=2, + stagestatus="驳回不支持SQL", + errormessage="TRUNCATE不支持View,File,URL,Buffer和Null表引擎!", + sql=statement, + ) else: - result = self.explain_check(check_result, db_name, line, statement) + result = self.explain_check( + check_result, db_name, line, statement + ) else: - result = ReviewResult(id=line, errlevel=2, - stagestatus='表不存在', - errormessage=f'表 {table_name} 不存在!', - sql=statement) + result = ReviewResult( + id=line, + errlevel=2, + stagestatus="表不存在", + errormessage=f"表 {table_name} 不存在!", + sql=statement, + ) # insert语句,explain无法正确判断,暂时只做表存在性检查与简单关键字匹配 elif re.match(r"^insert", statement.lower()): - if re.match(r"^insert\s+into\s+(.+?)(\s+|\s*\(.+?)(values|format|select)(\s+|\()", statement.lower()): - table_name = re.match(r"^insert\s+into\s+(.+?)(\s+|\s*\(.+?)(values|format|select)(\s+|\()", - statement.lower(), re.M).group(1) - if '.' not in table_name: + if re.match( + r"^insert\s+into\s+(.+?)(\s+|\s*\(.+?)(values|format|select)(\s+|\()", + statement.lower(), + ): + table_name = re.match( + r"^insert\s+into\s+(.+?)(\s+|\s*\(.+?)(values|format|select)(\s+|\()", + statement.lower(), + re.M, + ).group(1) + if "." not in table_name: table_name = f"{db_name}.{table_name}" - table_exist = self.get_table_engine(table_name)['status'] + table_exist = self.get_table_engine(table_name)["status"] if table_exist == 1: - result = ReviewResult(id=line, errlevel=0, - stagestatus='Audit completed', - errormessage='None', - sql=statement, - affected_rows=0, - execute_time=0, ) + result = ReviewResult( + id=line, + errlevel=0, + stagestatus="Audit completed", + errormessage="None", + sql=statement, + affected_rows=0, + execute_time=0, + ) else: - result = ReviewResult(id=line, errlevel=2, - stagestatus='表不存在', - errormessage=f'表 {table_name} 不存在!', - sql=statement) + result = ReviewResult( + id=line, + errlevel=2, + stagestatus="表不存在", + errormessage=f"表 {table_name} 不存在!", + sql=statement, + ) else: - result = ReviewResult(id=line, errlevel=2, - stagestatus='驳回不支持SQL', - errormessage='INSERT语法不正确!', - sql=statement) + result = ReviewResult( + id=line, + errlevel=2, + stagestatus="驳回不支持SQL", + errormessage="INSERT语法不正确!", + sql=statement, + ) # 其他语句使用explain ast简单检查 else: result = self.explain_check(check_result, db_name, line, statement) # 没有找出DDL语句的才继续执行此判断 if check_result.syntax_type == 2: - if get_syntax_type(statement, parser=False, db_type='mysql') == 'DDL': + if get_syntax_type(statement, parser=False, db_type="mysql") == "DDL": check_result.syntax_type = 1 check_result.rows += [result] line += 1 @@ -340,47 +414,55 @@ def execute_workflow(self, workflow): line = 1 for statement in sql_list: with FuncTimer() as t: - result = self.execute(db_name=workflow.db_name, sql=statement, close_conn=True) + result = self.execute( + db_name=workflow.db_name, sql=statement, close_conn=True + ) if not result.error: - execute_result.rows.append(ReviewResult( - id=line, - errlevel=0, - stagestatus='Execute Successfully', - errormessage='None', - sql=statement, - affected_rows=0, - execute_time=t.cost, - )) + execute_result.rows.append( + ReviewResult( + id=line, + errlevel=0, + stagestatus="Execute Successfully", + errormessage="None", + sql=statement, + affected_rows=0, + execute_time=t.cost, + ) + ) line += 1 else: # 追加当前报错语句信息到执行结果中 execute_result.error = result.error - execute_result.rows.append(ReviewResult( - id=line, - errlevel=2, - stagestatus='Execute Failed', - errormessage=f'异常信息:{result.error}', - sql=statement, - affected_rows=0, - execute_time=0, - )) - line += 1 - # 报错语句后面的语句标记为审核通过、未执行,追加到执行结果中 - for statement in sql_list[line - 1:]: - execute_result.rows.append(ReviewResult( + execute_result.rows.append( + ReviewResult( id=line, - errlevel=0, - stagestatus='Audit completed', - errormessage=f'前序语句失败, 未执行', + errlevel=2, + stagestatus="Execute Failed", + errormessage=f"异常信息:{result.error}", sql=statement, affected_rows=0, execute_time=0, - )) + ) + ) + line += 1 + # 报错语句后面的语句标记为审核通过、未执行,追加到执行结果中 + for statement in sql_list[line - 1 :]: + execute_result.rows.append( + ReviewResult( + id=line, + errlevel=0, + stagestatus="Audit completed", + errormessage=f"前序语句失败, 未执行", + sql=statement, + affected_rows=0, + execute_time=0, + ) + ) line += 1 break return execute_result - def execute(self, db_name=None, sql='', close_conn=True): + def execute(self, db_name=None, sql="", close_conn=True): """原生执行语句""" result = ResultSet(full_sql=sql) conn = self.get_connection(db_name=db_name) @@ -391,7 +473,7 @@ def execute(self, db_name=None, sql='', close_conn=True): cursor.close() except Exception as e: logger.warning(f"ClickHouse语句执行报错,语句:{sql},错误信息{e}") - result.error = str(e).split('Stack trace')[0] + result.error = str(e).split("Stack trace")[0] if close_conn: self.close() return result diff --git a/sql/engines/goinception.py b/sql/engines/goinception.py index eca7c525ed..2ed75ced42 100644 --- a/sql/engines/goinception.py +++ b/sql/engines/goinception.py @@ -11,7 +11,7 @@ from . import EngineBase from .models import ResultSet, ReviewSet, ReviewResult -logger = logging.getLogger('default') +logger = logging.getLogger("default") class GoInceptionEngine(EngineBase): @@ -19,42 +19,51 @@ class GoInceptionEngine(EngineBase): @property def name(self): - return 'GoInception' + return "GoInception" @property def info(self): - return 'GoInception engine' + return "GoInception engine" def get_connection(self, db_name=None): if self.conn: return self.conn - if hasattr(self, 'instance'): - self.conn = MySQLdb.connect(host=self.host, port=self.port, charset=self.instance.charset or 'utf8mb4', - connect_timeout=10) + if hasattr(self, "instance"): + self.conn = MySQLdb.connect( + host=self.host, + port=self.port, + charset=self.instance.charset or "utf8mb4", + connect_timeout=10, + ) return self.conn archer_config = SysConfig() - go_inception_host = archer_config.get('go_inception_host') - go_inception_port = int(archer_config.get('go_inception_port', 4000)) - self.conn = MySQLdb.connect(host=go_inception_host, port=go_inception_port, charset='utf8mb4', - connect_timeout=10) + go_inception_host = archer_config.get("go_inception_host") + go_inception_port = int(archer_config.get("go_inception_port", 4000)) + self.conn = MySQLdb.connect( + host=go_inception_host, + port=go_inception_port, + charset="utf8mb4", + connect_timeout=10, + ) return self.conn @staticmethod def get_backup_connection(): archer_config = SysConfig() - backup_host = archer_config.get('inception_remote_backup_host') - backup_port = int(archer_config.get('inception_remote_backup_port', 3306)) - backup_user = archer_config.get('inception_remote_backup_user') - backup_password = archer_config.get('inception_remote_backup_password', '') - return MySQLdb.connect(host=backup_host, - port=backup_port, - user=backup_user, - passwd=backup_password, - charset='utf8mb4', - autocommit=True - ) + backup_host = archer_config.get("inception_remote_backup_host") + backup_port = int(archer_config.get("inception_remote_backup_port", 3306)) + backup_user = archer_config.get("inception_remote_backup_user") + backup_password = archer_config.get("inception_remote_backup_password", "") + return MySQLdb.connect( + host=backup_host, + port=backup_port, + user=backup_user, + passwd=backup_password, + charset="utf8mb4", + autocommit=True, + ) - def execute_check(self, instance=None, db_name=None, sql=''): + def execute_check(self, instance=None, db_name=None, sql=""): """inception check""" # 判断如果配置了隧道则连接隧道 host, port, user, password = self.remote_instance_conn(instance) @@ -78,7 +87,7 @@ def execute_check(self, instance=None, db_name=None, sql=''): check_result.error_count += 1 # 没有找出DDL语句的才继续执行此判断 if check_result.syntax_type == 2: - if get_syntax_type(r[5], parser=False, db_type='mysql') == 'DDL': + if get_syntax_type(r[5], parser=False, db_type="mysql") == "DDL": check_result.syntax_type = 1 check_result.column_list = inception_result.column_list check_result.checked = True @@ -109,12 +118,15 @@ def execute(self, workflow=None): # 执行报错,inception crash或者执行中连接异常的场景 if inception_result.error and not execute_result.rows: execute_result.error = inception_result.error - execute_result.rows = [ReviewResult( - stage='Execute failed', - errlevel=2, - stagestatus='异常终止', - errormessage=f'goInception Error: {inception_result.error}', - sql=workflow.sqlworkflowcontent.sql_content)] + execute_result.rows = [ + ReviewResult( + stage="Execute failed", + errlevel=2, + stagestatus="异常终止", + errormessage=f"goInception Error: {inception_result.error}", + sql=workflow.sqlworkflowcontent.sql_content, + ) + ] return execute_result # 把结果转换为ReviewSet @@ -123,13 +135,17 @@ def execute(self, workflow=None): # 如果发现任何一个行执行结果里有errLevel为1或2,并且状态列没有包含Execute Successfully,则最终执行结果为有异常. for r in execute_result.rows: - if r.errlevel in (1, 2) and not re.search(r"Execute Successfully", r.stagestatus): - execute_result.error = "Line {0} has error/warning: {1}".format(r.id, r.errormessage) + if r.errlevel in (1, 2) and not re.search( + r"Execute Successfully", r.stagestatus + ): + execute_result.error = "Line {0} has error/warning: {1}".format( + r.id, r.errormessage + ) break return execute_result - def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs): - """返回 ResultSet """ + def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs): + """返回 ResultSet""" result_set = ResultSet(full_sql=sql) conn = self.get_connection() try: @@ -145,13 +161,13 @@ def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs): result_set.rows = rows result_set.affected_rows = effect_row except Exception as e: - logger.warning(f'goInception语句执行报错,错误信息{traceback.format_exc()}') + logger.warning(f"goInception语句执行报错,错误信息{traceback.format_exc()}") result_set.error = str(e) if close_conn: self.close() return result_set - def query_print(self, instance, db_name=None, sql=''): + def query_print(self, instance, db_name=None, sql=""): """ 打印语法树。 """ @@ -163,11 +179,11 @@ def query_print(self, instance, db_name=None, sql=''): {sql.rstrip(';')}; inception_magic_commit;""" print_info = self.query(db_name=db_name, sql=sql).to_dict()[1] - if print_info.get('errmsg'): - raise RuntimeError(print_info.get('errmsg')) + if print_info.get("errmsg"): + raise RuntimeError(print_info.get("errmsg")) return print_info - def query_data_masking(self, instance, db_name=None, sql=''): + def query_data_masking(self, instance, db_name=None, sql=""): """ 将sql交给goInception打印语法树,获取select list 使用 masking 参数,可参考 https://github.com/hanchuanchuan/goInception/pull/355 @@ -182,13 +198,13 @@ def query_data_masking(self, instance, db_name=None, sql=''): query_result = self.query(db_name=db_name, sql=sql) # 有异常时主动抛出 if query_result.error: - raise RuntimeError(f'Inception Error: {query_result.error}') + raise RuntimeError(f"Inception Error: {query_result.error}") if not query_result.rows: - raise RuntimeError(f'Inception Error: 未获取到语法信息') + raise RuntimeError(f"Inception Error: 未获取到语法信息") # 兼容语法错误时errlevel=0的场景 print_info = query_result.to_dict()[0] - if print_info['errlevel'] == 0 and print_info['errmsg'] is None: - return json.loads(print_info['query_tree']) + if print_info["errlevel"] == 0 and print_info["errmsg"] is None: + return json.loads(print_info["query_tree"]) else: raise RuntimeError(f"Inception Error: {print_info['errmsg']}") @@ -196,7 +212,9 @@ def get_rollback(self, workflow): """ 获取回滚语句,并且按照执行顺序倒序展示,return ['源语句','回滚语句'] """ - list_execute_result = json.loads(workflow.sqlworkflowcontent.execute_result or '[]') + list_execute_result = json.loads( + workflow.sqlworkflowcontent.execute_result or "[]" + ) # 回滚语句倒序展示 list_execute_result.reverse() list_backup_sql = [] @@ -207,18 +225,18 @@ def get_rollback(self, workflow): try: # 获取backup_db_name, 兼容旧数据'[[]]'格式 if isinstance(row, list): - if row[8] == 'None': + if row[8] == "None": continue backup_db_name = row[8] sequence = row[7] sql = row[5] # 新数据 else: - if row.get('backup_dbname') in ('None', ''): + if row.get("backup_dbname") in ("None", ""): continue - backup_db_name = row.get('backup_dbname') - sequence = row.get('sequence') - sql = row.get('sql') + backup_db_name = row.get("backup_dbname") + sequence = row.get("sequence") + sql = row.get("sql") # 获取备份表名 opid_time = sequence.replace("'", "") sql_table = f"""select tablename @@ -236,7 +254,9 @@ def get_rollback(self, workflow): cur.execute(sql_back) list_backup = cur.fetchall() # 拼接成回滚语句列表,['源语句','回滚语句'] - list_backup_sql.append([sql, '\n'.join([back_info[0] for back_info in list_backup])]) + list_backup_sql.append( + [sql, "\n".join([back_info[0] for back_info in list_backup])] + ) except Exception as e: logger.error(f"获取回滚语句报错,异常信息{traceback.format_exc()}") raise Exception(e) @@ -260,9 +280,9 @@ def set_variable(self, variable_name, variable_value): def osc_control(self, **kwargs): """控制osc执行,获取进度、终止、暂停、恢复等""" - sqlsha1 = kwargs.get('sqlsha1') - command = kwargs.get('command') - if command == 'get': + sqlsha1 = kwargs.get("sqlsha1") + command = kwargs.get("command") + if command == "get": sql = f"inception get osc_percent '{sqlsha1}';" else: sql = f"inception {command} osc '{sqlsha1}';" @@ -270,7 +290,7 @@ def osc_control(self, **kwargs): @staticmethod def get_table_ref(query_tree, db_name=None): - __author__ = 'xxlrr' + __author__ = "xxlrr" """ * 从goInception解析后的语法树里解析出兼容Inception格式的引用表信息。 * 目前的逻辑是在SQL语法树中通过递归查找选中最小的 TableRefs 子树(可能有多个), @@ -294,12 +314,15 @@ def get_table_ref(query_tree, db_name=None): else: snodes = tree.find_max_tree("Source") if snodes: - table_ref.extend([ - { - "schema": snode['Source']['Schema']['O'] or db_name, - "name": snode['Source']['Name']['O'] - } for snode in snodes - ]) + table_ref.extend( + [ + { + "schema": snode["Source"]["Schema"]["O"] or db_name, + "name": snode["Source"]["Name"]["O"], + } + for snode in snodes + ] + ) # assert: source node must exists if table_refs node exists. # else: # raise Exception("GoInception Error: not found source node") @@ -313,7 +336,7 @@ def close(self): class DictTree(dict): def find_max_tree(self, *keys): - __author__ = 'xxlrr' + __author__ = "xxlrr" """通过广度优先搜索算法查找满足条件的最大子树(不找叶子节点)""" fit = [] find_queue = [self] @@ -331,14 +354,15 @@ def find_max_tree(self, *keys): def get_session_variables(instance): """按照目标实例动态设置goInception的会话参数,可用于按照业务组自定义审核规则等场景""" variables = {} - set_session_sql = '' + set_session_sql = "" if AliyunRdsConfig.objects.filter(instance=instance, is_enable=True).exists(): - variables.update({ - "ghost_aliyun_rds": "on", - "ghost_allow_on_master": "true", - "ghost_assume_rbr": "true", - - }) + variables.update( + { + "ghost_aliyun_rds": "on", + "ghost_allow_on_master": "true", + "ghost_assume_rbr": "true", + } + ) # 转换成SQL语句 for k, v in variables.items(): set_session_sql += f"inception set session {k} = '{v}';\n" diff --git a/sql/engines/models.py b/sql/engines/models.py index fc377590eb..34e73d6652 100644 --- a/sql/engines/models.py +++ b/sql/engines/models.py @@ -4,16 +4,23 @@ class SqlItem: - - def __init__(self, id=0, statement='', stmt_type='SQL', object_owner='', object_type='', object_name=''): - ''' + def __init__( + self, + id=0, + statement="", + stmt_type="SQL", + object_owner="", + object_type="", + object_name="", + ): + """ :param id: SQL序号,从0开始 :param statement: SQL Statement :param stmt_type: SQL类型(SQL, PLSQL), 默认为SQL :param object_owner: PLSQL Object Owner :param object_type: PLSQL Object Type :param object_name: PLSQL Object Name - ''' + """ self.id = id self.statement = statement self.stmt_type = stmt_type @@ -34,32 +41,34 @@ def __init__(self, inception_result=None, **kwargs): """ if inception_result: self.id = inception_result[0] or 0 - self.stage = inception_result[1] or '' + self.stage = inception_result[1] or "" self.errlevel = inception_result[2] or 0 - self.stagestatus = inception_result[3] or '' - self.errormessage = inception_result[4] or '' - self.sql = inception_result[5] or '' + self.stagestatus = inception_result[3] or "" + self.errormessage = inception_result[4] or "" + self.sql = inception_result[5] or "" self.affected_rows = inception_result[6] or 0 - self.sequence = inception_result[7] or '' - self.backup_dbname = inception_result[8] or '' - self.execute_time = inception_result[9] or '' - self.sqlsha1 = inception_result[10] or '' - self.backup_time = inception_result[11] if len(inception_result) >= 12 else '' - self.actual_affected_rows = '' + self.sequence = inception_result[7] or "" + self.backup_dbname = inception_result[8] or "" + self.execute_time = inception_result[9] or "" + self.sqlsha1 = inception_result[10] or "" + self.backup_time = ( + inception_result[11] if len(inception_result) >= 12 else "" + ) + self.actual_affected_rows = "" else: - self.id = kwargs.get('id', 0) - self.stage = kwargs.get('stage', '') - self.errlevel = kwargs.get('errlevel', 0) - self.stagestatus = kwargs.get('stagestatus', '') - self.errormessage = kwargs.get('errormessage', '') - self.sql = kwargs.get('sql', '') - self.affected_rows = kwargs.get('affected_rows', 0) - self.sequence = kwargs.get('sequence', '') - self.backup_dbname = kwargs.get('backup_dbname', '') - self.execute_time = kwargs.get('execute_time', '') - self.sqlsha1 = kwargs.get('sqlsha1', '') - self.backup_time = kwargs.get('backup_time', '') - self.actual_affected_rows = kwargs.get('actual_affected_rows', '') + self.id = kwargs.get("id", 0) + self.stage = kwargs.get("stage", "") + self.errlevel = kwargs.get("errlevel", 0) + self.stagestatus = kwargs.get("stagestatus", "") + self.errormessage = kwargs.get("errormessage", "") + self.sql = kwargs.get("sql", "") + self.affected_rows = kwargs.get("affected_rows", 0) + self.sequence = kwargs.get("sequence", "") + self.backup_dbname = kwargs.get("backup_dbname", "") + self.execute_time = kwargs.get("execute_time", "") + self.sqlsha1 = kwargs.get("sqlsha1", "") + self.backup_time = kwargs.get("backup_time", "") + self.actual_affected_rows = kwargs.get("actual_affected_rows", "") # 自定义属性 for key, value in kwargs.items(): @@ -70,8 +79,15 @@ def __init__(self, inception_result=None, **kwargs): class ReviewSet: """review和执行后的结果集, rows中是review result, 有设定好的字段""" - def __init__(self, full_sql='', rows=None, status=None, - affected_rows=0, column_list=None, **kwargs): + def __init__( + self, + full_sql="", + rows=None, + status=None, + affected_rows=0, + column_list=None, + **kwargs + ): self.full_sql = full_sql self.is_execute = False self.checked = None @@ -107,15 +123,22 @@ def to_dict(self): class ResultSet: """查询的结果集, rows 内只有值, column_list 中的是key""" - def __init__(self, full_sql='', rows=None, status=None, - affected_rows=0, column_list=None, **kwargs): + def __init__( + self, + full_sql="", + rows=None, + status=None, + affected_rows=0, + column_list=None, + **kwargs + ): self.full_sql = full_sql self.is_execute = False self.checked = None self.is_masked = False - self.query_time = '' + self.query_time = "" self.mask_rule_hit = False - self.mask_time = '' + self.mask_time = "" self.warning = None self.error = None self.is_critical = False @@ -134,11 +157,11 @@ def json(self): def to_dict(self): tmp_list = [] for r in self.rows: - if isinstance(r,dict): + if isinstance(r, dict): tmp_list += [r] else: tmp_list += [dict(zip(self.column_list, r))] return tmp_list def to_sep_dict(self): - return {'column_list': self.column_list, 'rows': self.rows} + return {"column_list": self.column_list, "rows": self.rows} diff --git a/sql/engines/mongo.py b/sql/engines/mongo.py index dec1bbc349..e198624537 100644 --- a/sql/engines/mongo.py +++ b/sql/engines/mongo.py @@ -18,10 +18,10 @@ from . import EngineBase from .models import ResultSet, ReviewSet, ReviewResult -logger = logging.getLogger('default') +logger = logging.getLogger("default") # mongo客户端安装在本机的位置 -mongo = 'mongo' +mongo = "mongo" # 自定义异常 @@ -35,7 +35,7 @@ def __str__(self): class JsonDecoder: - '''处理传入mongodb语句中的条件,并转换成pymongo可识别的字典格式''' + """处理传入mongodb语句中的条件,并转换成pymongo可识别的字典格式""" def __init__(self): pass @@ -43,35 +43,37 @@ def __init__(self): def __json_object(self, tokener): # obj = collections.OrderedDict() obj = {} - if tokener.cur_token() != '{': + if tokener.cur_token() != "{": raise Exception('Json must start with "{"') while True: tokener.next() tk_temp = tokener.cur_token() - if tk_temp == '}': + if tk_temp == "}": return {} # 限制key的格式 - if not isinstance(tk_temp, str): # or (not tk_temp.isidentifier() and not tk_temp.startswith("$")) - raise Exception('invalid key %s' % tk_temp) + if not isinstance( + tk_temp, str + ): # or (not tk_temp.isidentifier() and not tk_temp.startswith("$")) + raise Exception("invalid key %s" % tk_temp) key = tk_temp.strip() tokener.next() - if tokener.cur_token() != ':': + if tokener.cur_token() != ":": raise Exception('expect ":" after "%s"' % key) tokener.next() val = tokener.cur_token() - if val == '[': + if val == "[": val = self.__json_array(tokener) - elif val == '{': + elif val == "{": val = self.__json_object(tokener) obj[key] = val tokener.next() tk_split = tokener.cur_token() - if tk_split == ',': + if tk_split == ",": continue - elif tk_split == '}': + elif tk_split == "}": break else: if tk_split is None: @@ -80,20 +82,20 @@ def __json_object(self, tokener): return obj def __json_array(self, tokener): - if tokener.cur_token() != '[': + if tokener.cur_token() != "[": raise Exception('Json array must start with "["') arr = [] while True: tokener.next() tk_temp = tokener.cur_token() - if tk_temp == ']': + if tk_temp == "]": return [] - if tk_temp == '{': + if tk_temp == "{": val = self.__json_object(tokener) - elif tk_temp == '[': + elif tk_temp == "[": val = self.__json_array(tokener) - elif tk_temp in (',', ':', '}'): + elif tk_temp in (",", ":", "}"): raise Exception('unexpected token "%s"' % tk_temp) else: val = tk_temp @@ -101,9 +103,9 @@ def __json_array(self, tokener): tokener.next() tk_end = tokener.cur_token() - if tk_end == ',': + if tk_end == ",": continue - if tk_end == ']': + if tk_end == "]": break else: if tk_end is None: @@ -116,9 +118,9 @@ def decode(self, json_str): return None first_token = tokener.cur_token() - if first_token == '{': + if first_token == "{": decode_val = self.__json_object(tokener) - elif first_token == '[': + elif first_token == "[": decode_val = self.__json_array(tokener) else: raise Exception('Json must start with "{"') @@ -135,7 +137,7 @@ def __init__(self, json_str): def __cur_char(self): if self.__i < len(self.__str): return self.__str[self.__i] - return '' + return "" def __previous_char(self): if self.__i < len(self.__str): @@ -143,26 +145,29 @@ def __previous_char(self): def __remain_str(self): if self.__i < len(self.__str): - return self.__str[self.__i:] + return self.__str[self.__i :] def __move_i(self, step=1): if self.__i < len(self.__str): self.__i += step def __next_string(self): - '''当出现了"和'后就进入这个方法解析,直到出现与之对应的结束字符''' - outstr = '' + """当出现了"和'后就进入这个方法解析,直到出现与之对应的结束字符""" + outstr = "" trans_flag = False start_ch = "" self.__move_i() - while self.__cur_char() != '': + while self.__cur_char() != "": ch = self.__cur_char() - if start_ch == "": start_ch = self.__previous_char() + if start_ch == "": + start_ch = self.__previous_char() if ch == '\\"': # 判断是否是转义 trans_flag = True else: if not trans_flag: - if (ch == '"' and start_ch == '"') or (ch == "'" and start_ch == "'"): + if (ch == '"' and start_ch == '"') or ( + ch == "'" and start_ch == "'" + ): break else: trans_flag = False @@ -171,25 +176,29 @@ def __next_string(self): return outstr def __next_number(self): - expr = '' - while self.__cur_char().isdigit() or self.__cur_char() in ('.', '+', '-'): + expr = "" + while self.__cur_char().isdigit() or self.__cur_char() in (".", "+", "-"): expr += self.__cur_char() self.__move_i() self.__move_i(-1) - if '.' in expr: + if "." in expr: return float(expr) else: return int(expr) def __next_const(self): - '''处理没有被''和""包含的字符,如true和ObjectId''' + """处理没有被''和""包含的字符,如true和ObjectId""" outstr = "" data_type = "" while self.__cur_char().isalpha() or self.__cur_char() in ("$", "_", " "): outstr += self.__cur_char() self.__move_i() if outstr.replace(" ", "") in ( - "ObjectId", "newDate", "ISODate", "newISODate"): # ======类似的类型比较多还需单独处理,如int()等 + "ObjectId", + "newDate", + "ISODate", + "newISODate", + ): # ======类似的类型比较多还需单独处理,如int()等 data_type = outstr for c in self.__remain_str(): outstr += c @@ -199,8 +208,8 @@ def __next_const(self): self.__move_i(-1) - if outstr in ('true', 'false', 'null'): - return {'true': True, 'false': False, 'null': None}[outstr] + if outstr in ("true", "false", "null"): + return {"true": True, "false": False, "null": None}[outstr] elif data_type == "ObjectId": ojStr = re.findall(r"ObjectId\(.*?\)", outstr) # 单独处理ObjectId if len(ojStr) > 0: @@ -208,10 +217,16 @@ def __next_const(self): id_str = re.findall(r"\(.*?\)", ojStr[0]) oid = id_str[0].replace(" ", "")[2:-2] return ObjectId(oid) - elif data_type.replace(" ", "") in ("newDate", "ISODate", "newISODate"): # 处理时间格式 + elif data_type.replace(" ", "") in ( + "newDate", + "ISODate", + "newISODate", + ): # 处理时间格式 tmp_type = "%s()" % data_type if outstr.replace(" ", "") == tmp_type.replace(" ", ""): - return datetime.datetime.now() + datetime.timedelta(hours=-8) # mongodb默认时区为utc + return datetime.datetime.now() + datetime.timedelta( + hours=-8 + ) # mongodb默认时区为utc date_regex = re.compile(r'%s\("(.*)"\)' % data_type, re.IGNORECASE) date_content = date_regex.findall(outstr) if len(date_content) > 0: @@ -221,21 +236,26 @@ def __next_const(self): raise Exception('Invalid symbol "%s"' % outstr) def next(self): - is_white_space = lambda a_char: a_char in ('\x20', '\n', '\r', '\t') # 定义一个匿名函数 + is_white_space = lambda a_char: a_char in ( + "\x20", + "\n", + "\r", + "\t", + ) # 定义一个匿名函数 while is_white_space(self.__cur_char()): self.__move_i() ch = self.__cur_char() - if ch == '': + if ch == "": cur_token = None - elif ch in ('{', '}', '[', ']', ',', ':'): + elif ch in ("{", "}", "[", "]", ",", ":"): cur_token = ch elif ch in ('"', "'"): # 当字符为" ' cur_token = self.__next_string() elif ch.isalpha() or ch in ("$", "_"): # 字符串是否只由字母和"$","_"组成 cur_token = self.__next_const() - elif ch.isdigit() or ch in ('.', '-', '+'): # 检测字符串是否只由数字组成 + elif ch.isdigit() or ch in (".", "-", "+"): # 检测字符串是否只由数字组成 cur_token = self.__next_number() else: raise Exception('Invalid symbol "%s"' % ch) @@ -256,52 +276,84 @@ class MongoEngine(EngineBase): def test_connection(self): return self.get_all_databases() - def exec_cmd(self, sql, db_name=None, slave_ok=''): + def exec_cmd(self, sql, db_name=None, slave_ok=""): """审核时执行的语句""" if self.user and self.password and self.port and self.host: msg = "" - auth_db = self.instance.db_name or 'admin' + auth_db = self.instance.db_name or "admin" sql_len = len(sql) is_load = False # 默认不使用load方法执行mongodb sql语句 try: - if not sql.startswith('var host=') and sql_len > 4000: # 在master节点执行的情况,如果sql长度大于4000,就采取load js的方法 + if ( + not sql.startswith("var host=") and sql_len > 4000 + ): # 在master节点执行的情况,如果sql长度大于4000,就采取load js的方法 # 因为用mongo load方法执行js脚本,所以需要重新改写一下sql,以便回显js执行结果 - sql = 'var result = ' + sql + '\nprintjson(result);' + sql = "var result = " + sql + "\nprintjson(result);" # 因为要知道具体的临时文件位置,所以用了NamedTemporaryFile模块 - fp = tempfile.NamedTemporaryFile(suffix=".js", prefix="mongo_", dir='/tmp/', delete=True) - fp.write(sql.encode('utf-8')) + fp = tempfile.NamedTemporaryFile( + suffix=".js", prefix="mongo_", dir="/tmp/", delete=True + ) + fp.write(sql.encode("utf-8")) fp.seek(0) # 把文件指针指向开始,这样写的sql内容才能落到磁盘文件上 cmd = "{mongo} --quiet -u {uname} -p '{password}' {host}:{port}/{auth_db} <<\\EOF\ndb=db.getSiblingDB(\"{db_name}\");{slave_ok}load('{tempfile_}')\nEOF".format( - mongo=mongo, uname=self.user, password=self.password, host=self.host, port=self.port, - db_name=db_name, sql=sql, auth_db=auth_db, slave_ok=slave_ok, tempfile_=fp.name) + mongo=mongo, + uname=self.user, + password=self.password, + host=self.host, + port=self.port, + db_name=db_name, + sql=sql, + auth_db=auth_db, + slave_ok=slave_ok, + tempfile_=fp.name, + ) is_load = True # 标记使用了load方法,用来在finally里面判断是否需要强制删除临时文件 - elif not sql.startswith( - 'var host=') and sql_len < 4000: # 在master节点执行的情况, 如果sql长度小于4000,就直接用mongo shell执行,减少磁盘交换,节省性能 + elif ( + not sql.startswith("var host=") and sql_len < 4000 + ): # 在master节点执行的情况, 如果sql长度小于4000,就直接用mongo shell执行,减少磁盘交换,节省性能 cmd = "{mongo} --quiet -u {uname} -p '{password}' {host}:{port}/{auth_db} <<\\EOF\ndb=db.getSiblingDB(\"{db_name}\");{slave_ok}printjson({sql})\nEOF".format( - mongo=mongo, uname=self.user, password=self.password, host=self.host, port=self.port, - db_name=db_name, sql=sql, auth_db=auth_db, slave_ok=slave_ok) + mongo=mongo, + uname=self.user, + password=self.password, + host=self.host, + port=self.port, + db_name=db_name, + sql=sql, + auth_db=auth_db, + slave_ok=slave_ok, + ) else: cmd = "{mongo} --quiet -u {user} -p '{password}' {host}:{port}/{auth_db} <<\\EOF\nrs.slaveOk();{sql}\nEOF".format( - mongo=mongo, user=self.user, password=self.password, host=self.host, port=self.port, - db_name=db_name, sql=sql, auth_db=auth_db) - p = subprocess.Popen(cmd, shell=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - universal_newlines=True) + mongo=mongo, + user=self.user, + password=self.password, + host=self.host, + port=self.port, + db_name=db_name, + sql=sql, + auth_db=auth_db, + ) + p = subprocess.Popen( + cmd, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + ) re_msg = [] - for line in iter(p.stdout.read, ''): + for line in iter(p.stdout.read, ""): re_msg.append(line) # 因为返回的line中也有可能带有换行符,因此需要先全部转换成字符串 - __msg = '\n'.join(re_msg) + __msg = "\n".join(re_msg) _re_msg = [] - for _line in __msg.split('\n'): - if not _re_msg and re.match('WARNING.*', _line): + for _line in __msg.split("\n"): + if not _re_msg and re.match("WARNING.*", _line): # 第一行可能是WARNING语句,因此跳过 continue _re_msg.append(_line) - msg = '\n'.join(_re_msg) + msg = "\n".join(_re_msg) except Exception as e: logger.warning(f"mongo语句执行报错,语句:{sql},{e}错误信息{traceback.format_exc()}") finally: @@ -314,8 +366,8 @@ def get_master(self): sql = "rs.isMaster().primary" master = self.exec_cmd(sql) - if master != 'undefined' and master.find("TypeError") >= 0: - sp_host = master.replace("\"", "").split(":") + if master != "undefined" and master.find("TypeError") >= 0: + sp_host = master.replace('"', "").split(":") self.host = sp_host[0] self.port = int(sp_host[1]) # return master @@ -323,11 +375,11 @@ def get_master(self): def get_slave(self): """获得从节点的port和host""" - sql = '''var host=""; rs.status().members.forEach(function(item) {i=1; if (item.stateStr =="SECONDARY") \ - {host=item.name } }); print(host);''' + sql = """var host=""; rs.status().members.forEach(function(item) {i=1; if (item.stateStr =="SECONDARY") \ + {host=item.name } }); print(host);""" slave_msg = self.exec_cmd(sql) - if slave_msg.lower().find('undefined') < 0: - sp_host = slave_msg.replace("\"", "").split(":") + if slave_msg.lower().find("undefined") < 0: + sp_host = slave_msg.replace('"', "").split(":") self.host = sp_host[0] self.port = int(sp_host[1]) return True @@ -339,7 +391,7 @@ def get_table_conut(self, table_name, db_name): count_sql = f"db.{table_name}.count()" status = self.get_slave() # 查询总数据要求在slave节点执行 if self.host and self.port and status: - count = int(self.exec_cmd(count_sql, db_name, slave_ok='rs.slaveOk();')) + count = int(self.exec_cmd(count_sql, db_name, slave_ok="rs.slaveOk();")) else: count = int(self.exec_cmd(count_sql, db_name)) return count @@ -349,9 +401,11 @@ def get_table_conut(self, table_name, db_name): def execute_workflow(self, workflow): """执行上线单,返回Review set""" - return self.execute(db_name=workflow.db_name, sql=workflow.sqlworkflowcontent.sql_content) + return self.execute( + db_name=workflow.db_name, sql=workflow.sqlworkflowcontent.sql_content + ) - def execute(self, db_name=None, sql=''): + def execute(self, db_name=None, sql=""): """mongo命令执行语句""" self.get_master() execute_result = ReviewSet(full_sql=sql) @@ -360,7 +414,7 @@ def execute(self, db_name=None, sql=''): sp_sql = sql.split(";") line = 0 for exec_sql in sp_sql: - if not exec_sql == '': + if not exec_sql == "": exec_sql = exec_sql.strip() try: # DeprecationWarning: time.clock has been deprecated in Python 3.3 and will be removed from Python 3.8: use time.perf_counter or time.process_time instead @@ -370,162 +424,252 @@ def execute(self, db_name=None, sql=''): line += 1 logger.debug("执行结果:" + r) # 如果执行中有错误 - rz = r.replace(' ', '').replace('"', '').lower() + rz = r.replace(" ", "").replace('"', "").lower() tr = 1 - if r.lower().find("syntaxerror") >= 0 or rz.find('ok:0') >= 0 or rz.find( - "error:invalid") >= 0 or rz.find("ReferenceError") >= 0 \ - or rz.find("getErrorWithCode") >= 0 or rz.find("failedtoconnect") >= 0 or rz.find( - "Error: field") >= 0: + if ( + r.lower().find("syntaxerror") >= 0 + or rz.find("ok:0") >= 0 + or rz.find("error:invalid") >= 0 + or rz.find("ReferenceError") >= 0 + or rz.find("getErrorWithCode") >= 0 + or rz.find("failedtoconnect") >= 0 + or rz.find("Error: field") >= 0 + ): tr = 0 - if (rz.find("errmsg") >= 0 or tr == 0) and (r.lower().find("already exist") < 0): + if (rz.find("errmsg") >= 0 or tr == 0) and ( + r.lower().find("already exist") < 0 + ): execute_result.error = r result = ReviewResult( id=line, - stage='Execute failed', + stage="Execute failed", errlevel=2, - stagestatus='异常终止', - errormessage=f'mongo语句执行报错: {r}', - sql=exec_sql) + stagestatus="异常终止", + errormessage=f"mongo语句执行报错: {r}", + sql=exec_sql, + ) else: # 把结果转换为ReviewSet result = ReviewResult( - id=line, errlevel=0, - stagestatus='执行结束', + id=line, + errlevel=0, + stagestatus="执行结束", errormessage=r, execute_time=round(end - start, 6), actual_affected_rows=0, # todo============这个值需要优化 - sql=exec_sql) + sql=exec_sql, + ) execute_result.rows += [result] except Exception as e: - logger.warning(f"mongo语句执行报错,语句:{exec_sql},错误信息{traceback.format_exc()}") + logger.warning( + f"mongo语句执行报错,语句:{exec_sql},错误信息{traceback.format_exc()}" + ) execute_result.error = str(e) # result_set.column_list = [i[0] for i in fields] if fields else [] return execute_result - def execute_check(self, db_name=None, sql=''): + def execute_check(self, db_name=None, sql=""): """上线单执行前的检查, 返回Review set""" line = 1 count = 0 check_result = ReviewSet(full_sql=sql) sql = sql.strip() - if (sql.find(";") < 0): + if sql.find(";") < 0: raise Exception("提交的语句请以分号结尾") # 以;切分语句,逐句执行 sp_sql = sql.split(";") # 执行语句 for check_sql in sp_sql: - alert = '' # 警告信息 - if not check_sql == '' and check_sql != '\n': + alert = "" # 警告信息 + if not check_sql == "" and check_sql != "\n": check_sql = check_sql.strip() # check_sql = f'''{check_sql}''' # check_sql = check_sql.replace('\n', '') #处理成一行 # 支持的命令列表 - supportMethodList = ["explain", "bulkWrite", "convertToCapped", "createIndex", "createIndexes", - "deleteOne", - "deleteMany", "drop", "dropIndex", "dropIndexes", "ensureIndex", "insert", - "insertOne", - "insertMany", "remove", "replaceOne", "renameCollection", "update", "updateOne", - "updateMany", "createCollection", "renameCollection"] + supportMethodList = [ + "explain", + "bulkWrite", + "convertToCapped", + "createIndex", + "createIndexes", + "deleteOne", + "deleteMany", + "drop", + "dropIndex", + "dropIndexes", + "ensureIndex", + "insert", + "insertOne", + "insertMany", + "remove", + "replaceOne", + "renameCollection", + "update", + "updateOne", + "updateMany", + "createCollection", + "renameCollection", + ] # 需要有表存在为前提的操作 - is_exist_premise_method = ["convertToCapped", "deleteOne", "deleteMany", "drop", "dropIndex", - "dropIndexes", - "remove", "replaceOne", "renameCollection", "update", "updateOne", - "updateMany", "renameCollection"] + is_exist_premise_method = [ + "convertToCapped", + "deleteOne", + "deleteMany", + "drop", + "dropIndex", + "dropIndexes", + "remove", + "replaceOne", + "renameCollection", + "update", + "updateOne", + "updateMany", + "renameCollection", + ] pattern = re.compile( - r'''^db\.createCollection\(([\s\S]*)\)$|^db\.([\w\.-]+)\.(?:[A-Za-z]+)(?:\([\s\S]*\)$)|^db\.getCollection\((?:\s*)(?:'|")([\w-]*)('|")(\s*)\)\.([A-Za-z]+)(\([\s\S]*\)$)''') + r"""^db\.createCollection\(([\s\S]*)\)$|^db\.([\w\.-]+)\.(?:[A-Za-z]+)(?:\([\s\S]*\)$)|^db\.getCollection\((?:\s*)(?:'|")([\w-]*)('|")(\s*)\)\.([A-Za-z]+)(\([\s\S]*\)$)""" + ) m = pattern.match(check_sql) - if m is not None and (re.search(re.compile(r'}(?:\s*){'), check_sql) is None) and check_sql.count( - '{') == check_sql.count('}') and check_sql.count('(') == check_sql.count(')'): + if ( + m is not None + and (re.search(re.compile(r"}(?:\s*){"), check_sql) is None) + and check_sql.count("{") == check_sql.count("}") + and check_sql.count("(") == check_sql.count(")") + ): sql_str = m.group() - table_name = (m.group(1) or m.group(2) or m.group(3)).strip() # 通过正则的组拿到表名 - table_name = table_name.replace('"', '').replace("'", "") + table_name = ( + m.group(1) or m.group(2) or m.group(3) + ).strip() # 通过正则的组拿到表名 + table_name = table_name.replace('"', "").replace("'", "") table_names = self.get_all_tables(db_name).rows is_in = table_name in table_names # 检查表是否存在 if not is_in: alert = f"\n提示:{table_name}文档不存在!" if sql_str: count = 0 - if sql_str.find('createCollection') > 0: # 如果是db.createCollection() + if ( + sql_str.find("createCollection") > 0 + ): # 如果是db.createCollection() methodStr = "createCollection" alert = "" if is_in: check_result.error = "文档已经存在" - result = ReviewResult(id=line, errlevel=2, - stagestatus='文档已经存在', - errormessage='文档已经存在!', - affected_rows=count, - sql=check_sql) + result = ReviewResult( + id=line, + errlevel=2, + stagestatus="文档已经存在", + errormessage="文档已经存在!", + affected_rows=count, + sql=check_sql, + ) check_result.rows += [result] continue else: # method = sql_str.split('.')[2] # methodStr = method.split('(')[0].strip() - methodStr = sql_str.split('(')[0].split('.')[-1].strip() # 最后一个.和括号(之间的字符串作为方法 + methodStr = ( + sql_str.split("(")[0].split(".")[-1].strip() + ) # 最后一个.和括号(之间的字符串作为方法 if methodStr in is_exist_premise_method and not is_in: check_result.error = "文档不存在" - result = ReviewResult(id=line, errlevel=2, - stagestatus='文档不存在', - errormessage=f'文档不存在,不能进行{methodStr}操作!', - sql=check_sql) + result = ReviewResult( + id=line, + errlevel=2, + stagestatus="文档不存在", + errormessage=f"文档不存在,不能进行{methodStr}操作!", + sql=check_sql, + ) check_result.rows += [result] continue if methodStr in supportMethodList: # 检查方法是否支持 - if methodStr == "createIndex" or methodStr == "createIndexes" or methodStr == "ensureIndex": # 判断是否创建索引,如果大于500万,提醒不能在高峰期创建 + if ( + methodStr == "createIndex" + or methodStr == "createIndexes" + or methodStr == "ensureIndex" + ): # 判断是否创建索引,如果大于500万,提醒不能在高峰期创建 p_back = re.compile( - r'''(['"])(?:(?!\1)background)\1(?:\s*):(?:\s*)true|background\s*:\s*true|(['"])(?:(?!\1)background)\1(?:\s*):(?:\s*)(['"])(?:(?!\2)true)\2''', - re.M) + r"""(['"])(?:(?!\1)background)\1(?:\s*):(?:\s*)true|background\s*:\s*true|(['"])(?:(?!\1)background)\1(?:\s*):(?:\s*)(['"])(?:(?!\2)true)\2""", + re.M, + ) m_back = re.search(p_back, check_sql) if m_back is None: count = 5555555 - check_result.warning = '创建索引请加background:true' + check_result.warning = "创建索引请加background:true" check_result.warning_count += 1 - result = ReviewResult(id=line, errlevel=2, - stagestatus='后台创建索引', - errormessage='创建索引没有加 background:true' + alert, - sql=check_sql) + result = ReviewResult( + id=line, + errlevel=2, + stagestatus="后台创建索引", + errormessage="创建索引没有加 background:true" + alert, + sql=check_sql, + ) elif not is_in: count = 0 else: - count = self.get_table_conut(table_name, db_name) # 获得表的总条数 + count = self.get_table_conut( + table_name, db_name + ) # 获得表的总条数 if count >= 5000000: - check_result.warning = alert + '大于500万条,请在业务低谷期创建索引' + check_result.warning = ( + alert + "大于500万条,请在业务低谷期创建索引" + ) check_result.warning_count += 1 - result = ReviewResult(id=line, errlevel=1, - stagestatus='大表创建索引', - errormessage='大于500万条,请在业务低谷期创建索引!', - affected_rows=count, - sql=check_sql) + result = ReviewResult( + id=line, + errlevel=1, + stagestatus="大表创建索引", + errormessage="大于500万条,请在业务低谷期创建索引!", + affected_rows=count, + sql=check_sql, + ) if count < 5000000: # 检测通过 - affected_all_row_method = ["drop", "dropIndex", "dropIndexes", "createIndex", - "createIndexes", "ensureIndex"] + affected_all_row_method = [ + "drop", + "dropIndex", + "dropIndexes", + "createIndex", + "createIndexes", + "ensureIndex", + ] if methodStr not in affected_all_row_method: count = 0 else: - count = self.get_table_conut(table_name, db_name) # 获得表的总条数 - result = ReviewResult(id=line, errlevel=0, - stagestatus='Audit completed', - errormessage='检测通过', - affected_rows=count, - sql=check_sql, - execute_time=0) + count = self.get_table_conut( + table_name, db_name + ) # 获得表的总条数 + result = ReviewResult( + id=line, + errlevel=0, + stagestatus="Audit completed", + errormessage="检测通过", + affected_rows=count, + sql=check_sql, + execute_time=0, + ) else: - result = ReviewResult(id=line, errlevel=2, - stagestatus='驳回不支持语句', - errormessage='仅支持DML和DDL语句,如需查询请使用数据库查询功能!', - sql=check_sql) + result = ReviewResult( + id=line, + errlevel=2, + stagestatus="驳回不支持语句", + errormessage="仅支持DML和DDL语句,如需查询请使用数据库查询功能!", + sql=check_sql, + ) else: check_result.error = "语法错误" - result = ReviewResult(id=line, errlevel=2, - stagestatus='语法错误', - errormessage='请检查语句的正确性或(){} },{是否正确匹配!', - sql=check_sql) + result = ReviewResult( + id=line, + errlevel=2, + stagestatus="语法错误", + errormessage="请检查语句的正确性或(){} },{是否正确匹配!", + sql=check_sql, + ) check_result.rows += [result] line += 1 count = 0 - check_result.column_list = ['Result'] # 审核结果的列名 + check_result.column_list = ["Result"] # 审核结果的列名 check_result.checked = True check_result.warning = self.warning # 统计警告和错误数量 @@ -537,10 +681,15 @@ def execute_check(self, db_name=None, sql=''): return check_result def get_connection(self, db_name=None): - self.db_name = db_name or self.instance.db_name or 'admin' - auth_db = self.instance.db_name or 'admin' - self.conn = pymongo.MongoClient(self.host, self.port, authSource=auth_db, connect=True, - connectTimeoutMS=10000) + self.db_name = db_name or self.instance.db_name or "admin" + auth_db = self.instance.db_name or "admin" + self.conn = pymongo.MongoClient( + self.host, + self.port, + authSource=auth_db, + connect=True, + connectTimeoutMS=10000, + ) if self.user and self.password: self.conn[self.db_name].authenticate(self.user, self.password, auth_db) return self.conn @@ -552,11 +701,11 @@ def close(self): @property def name(self): # pragma: no cover - return 'Mongo' + return "Mongo" @property def info(self): # pragma: no cover - return 'Mongo engine' + return "Mongo engine" def get_roles(self): sql_get_roles = "db.system.roles.find({},{_id:1})" @@ -605,14 +754,19 @@ def get_all_columns_by_tb(self, db_name, tb_name, **kwargs): for prop in document: if prop not in columns: columns.append(prop) - result.column_list = ['COLUMN_NAME'] + result.column_list = ["COLUMN_NAME"] result.rows = columns return result def describe_table(self, db_name, tb_name, **kwargs): """return ResultSet 类似查询""" result = self.get_all_columns_by_tb(db_name=db_name, tb_name=tb_name) - result.rows = [[[r], ] for r in result.rows] + result.rows = [ + [ + [r], + ] + for r in result.rows + ] return result @staticmethod @@ -625,7 +779,7 @@ def dispose_str(parse_sql, start_flag, index): return index index += 1 stop_flag = start_flag - raise Exception('near column %s,\' or \" has no close' % index) + raise Exception("near column %s,' or \" has no close" % index) def dispose_pair(self, parse_sql, index, begin, end): """解析处理需要配对的字符{}[]() 检索一个左括号计数器加1,右括号计数器减1""" @@ -648,7 +802,9 @@ def dispose_pair(self, parse_sql, index, begin, end): index = self.dispose_str(parse_sql, char, index) index += 1 if count > 0: - raise Exception("near column %s, The symbol %s has no closed" % (index, begin)) + raise Exception( + "near column %s, The symbol %s has no closed" % (index, begin) + ) re_char = parse_sql[start_pos:stop_pos] # 截取 return index, re_char @@ -686,7 +842,9 @@ def parse_query_sentence(self, parse_sql): pipeline = [] agg_index = 0 while agg_index < len(re_char): - p_index, condition = self.dispose_pair(re_char, agg_index, "{", "}") + p_index, condition = self.dispose_pair( + re_char, agg_index, "{", "}" + ) agg_index = p_index + 1 if condition: de = JsonDecoder() @@ -700,7 +858,7 @@ def parse_query_sentence(self, parse_sql): query_dict["condition"] = pipeline query_dict["method"] = method elif method.lower() == "getcollection": # 获得表名 - collection = re_char.strip().replace("'", "").replace('"', '') + collection = re_char.strip().replace("'", "").replace('"', "") query_dict["collection"] = collection elif method.lower() == "getindexes": query_dict["method"] = "index_information" @@ -712,7 +870,7 @@ def parse_query_sentence(self, parse_sql): if query_dict: return query_dict - def filter_sql(self, sql='', limit_num=0): + def filter_sql(self, sql="", limit_num=0): """给查询语句改写语句, 返回修改后的语句""" sql = sql.split(";")[0].strip() # 执行计划 @@ -720,37 +878,38 @@ def filter_sql(self, sql='', limit_num=0): sql = sql.replace("explain", "") + ".explain()" return sql.strip() - def query_check(self, db_name=None, sql=''): + def query_check(self, db_name=None, sql=""): """提交查询前的检查""" sql = sql.strip() if sql.startswith("explain"): sql = sql[7:] + ".explain()" sql = re.sub("[;\s]*.explain\(\)$", ".explain()", sql).strip() - result = {'msg': '', 'bad_query': False, 'filtered_sql': sql, 'has_star': False} + result = {"msg": "", "bad_query": False, "filtered_sql": sql, "has_star": False} pattern = re.compile( - r'''^db\.(\w+\.?)+(?:\([\s\S]*\)(\s*;*)$)|^db\.getCollection\((?:\s*)(?:'|")(\w+\.?)+('|")(\s*)\)\.([A-Za-z]+)(\([\s\S]*\)(\s*;*)$)''') + r"""^db\.(\w+\.?)+(?:\([\s\S]*\)(\s*;*)$)|^db\.getCollection\((?:\s*)(?:'|")(\w+\.?)+('|")(\s*)\)\.([A-Za-z]+)(\([\s\S]*\)(\s*;*)$)""" + ) m = pattern.match(sql) if m is not None: logger.debug(sql) query_dict = self.parse_query_sentence(sql) if "method" not in query_dict: - result['msg'] += "错误:对不起,只支持查询相关方法" - result['bad_query'] = True + result["msg"] += "错误:对不起,只支持查询相关方法" + result["bad_query"] = True return result collection_name = query_dict["collection"] collection_names = self.get_all_tables(db_name).rows is_in = collection_name in collection_names # 检查表是否存在 if not is_in: - result['msg'] += f"\n错误: {collection_name} 文档不存在!" - result['bad_query'] = True + result["msg"] += f"\n错误: {collection_name} 文档不存在!" + result["bad_query"] = True return result else: - result['msg'] += '请检查语句的正确性! 请使用原生查询语句' - result['bad_query'] = True + result["msg"] += "请检查语句的正确性! 请使用原生查询语句" + result["bad_query"] = True return result - def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs): + def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs): """执行查询""" result_set = ResultSet(full_sql=sql) @@ -781,7 +940,11 @@ def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs): for k, v in de.decode(query_dict["sort"]).items(): sorting.append((k, v)) find_cmd += ".sort(sorting)" - if method == "find" and "limit" not in query_dict and "explain" not in query_dict: + if ( + method == "find" + and "limit" not in query_dict + and "explain" not in query_dict + ): find_cmd += ".limit(limit_num)" if "limit" in query_dict and query_dict["limit"]: query_limit = int(query_dict["limit"]) @@ -820,7 +983,9 @@ def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs): row = [] columns.insert(0, "mongodballdata") for ro in cursor: - json_col = json.dumps(ro, ensure_ascii=False, indent=2, separators=(",", ":")) + json_col = json.dumps( + ro, ensure_ascii=False, indent=2, separators=(",", ":") + ) row.insert(0, json_col) for k, v in ro.items(): if k not in columns: @@ -832,7 +997,7 @@ def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs): result_set.rows = rows else: cursor = json.loads(json_util.dumps(cursor)) - cols = projection if 'projection' in dir() else None + cols = projection if "projection" in dir() else None rows, columns = self.parse_tuple(cursor, db_name, collection_name, cols) result_set.rows = rows result_set.column_list = columns @@ -840,7 +1005,9 @@ def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs): if isinstance(rows, list): logger.debug(rows) result_set.rows = tuple( - [json.dumps(x, ensure_ascii=False, indent=2, separators=(",", ":"))] for x in rows) + [json.dumps(x, ensure_ascii=False, indent=2, separators=(",", ":"))] + for x in rows + ) except Exception as e: logger.warning(f"Mongo命令执行报错,语句:{sql}, 错误信息:{traceback.format_exc()}") @@ -865,7 +1032,9 @@ def parse_tuple(self, cursor, db_name, tb_name, projection=None): columns = self.fill_query_columns(cursor, columns) for ro in cursor: - json_col = json.dumps(ro, ensure_ascii=False, indent=2, separators=(",", ":")) + json_col = json.dumps( + ro, ensure_ascii=False, indent=2, separators=(",", ":") + ) row.insert(0, json_col) for key in columns[1:]: if key in ro: @@ -877,13 +1046,17 @@ def parse_tuple(self, cursor, db_name, tb_name, projection=None): # 转换$oid ff = re.findall(re_oid, str(value)) for ii in ff: - value = str(value).replace(ii, "ObjectId(" + ii.split(":")[1].strip()[:-1] + ")") + value = str(value).replace( + ii, "ObjectId(" + ii.split(":")[1].strip()[:-1] + ")" + ) # 转换时间戳$date dd = re.findall(re_date, str(value)) for d in dd: t = int(d.split(":")[1].strip()[:-1]) e = datetime.fromtimestamp(t / 1000) - value = str(value).replace(d, e.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]) + value = str(value).replace( + d, e.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] + ) row.append(str(value)) else: row.append("(N/A)") @@ -904,47 +1077,60 @@ def fill_query_columns(cursor, columns): def current_op(self, command_type): """ 获取当前连接信息 - + command_type: - Full 包含活跃与不活跃的连接,包含内部的连接,即全部的连接状态 + Full 包含活跃与不活跃的连接,包含内部的连接,即全部的连接状态 All 包含活跃与不活跃的连接,不包含内部的连接 Active 包含活跃 Inner 内部连接 """ - result_set = ResultSet(full_sql='db.aggregate([{"$currentOp": {"allUsers":true, "idleConnections":true}}])') + result_set = ResultSet( + full_sql='db.aggregate([{"$currentOp": {"allUsers":true, "idleConnections":true}}])' + ) try: conn = self.get_connection() processlists = [] if not command_type: - command_type = 'Active' - if command_type in ['Full', 'All', 'Inner']: + command_type = "Active" + if command_type in ["Full", "All", "Inner"]: idle_connections = True else: idle_connections = False # conn.admin.current_op() 这个方法已经被pymongo废除,但mongodb3.6+才支持aggregate with conn.admin.aggregate( - [{'$currentOp': {'allUsers': True, 'idleConnections': idle_connections}}]) as cursor: + [ + { + "$currentOp": { + "allUsers": True, + "idleConnections": idle_connections, + } + } + ] + ) as cursor: for operation in cursor: # 对sharding集群的特殊处理 - if 'client' not in operation and \ - operation.get('clientMetadata', {}).get('mongos', {}).get('client', {}): - operation['client'] = operation['clientMetadata']['mongos']['client'] + if "client" not in operation and operation.get( + "clientMetadata", {} + ).get("mongos", {}).get("client", {}): + operation["client"] = operation["clientMetadata"]["mongos"][ + "client" + ] # client_s 只是处理的mongos,并不是实际客户端 # client 在sharding获取不到? - if command_type in ['Full']: + if command_type in ["Full"]: processlists.append(operation) - elif command_type in ['All', 'Active']: - if 'clientMetadata' in operation: + elif command_type in ["All", "Active"]: + if "clientMetadata" in operation: processlists.append(operation) - elif command_type in ['Inner']: - if not 'clientMetadata' in operation: + elif command_type in ["Inner"]: + if not "clientMetadata" in operation: processlists.append(operation) result_set.rows = processlists except Exception as e: - logger.warning(f'mongodb获取连接信息错误,错误信息{traceback.format_exc()}') + logger.warning(f"mongodb获取连接信息错误,错误信息{traceback.format_exc()}") result_set.error = str(e) return result_set @@ -953,15 +1139,17 @@ def get_kill_command(self, opids): """由传入的opid列表生成kill字符串""" conn = self.get_connection() active_opid = [] - with conn.admin.aggregate([{'$currentOp': {'allUsers': True, 'idleConnections': False}}]) as cursor: + with conn.admin.aggregate( + [{"$currentOp": {"allUsers": True, "idleConnections": False}}] + ) as cursor: for operation in cursor: - if 'opid' in operation and operation['opid'] in opids: - active_opid.append(operation['opid']) + if "opid" in operation and operation["opid"] in opids: + active_opid.append(operation["opid"]) - kill_command = '' + kill_command = "" for opid in active_opid: if isinstance(opid, int): - kill_command = kill_command + 'db.killOp({});'.format(opid) + kill_command = kill_command + "db.killOp({});".format(opid) else: kill_command = kill_command + 'db.killOp("{}");'.format(opid) @@ -974,11 +1162,13 @@ def kill_op(self, opids): conn = self.get_connection() db = conn.admin for opid in opids: - conn.admin.command({'killOp': 1, 'op': opid}) + conn.admin.command({"killOp": 1, "op": opid}) except Exception as e: try: - sql = {'killOp': 1, 'op': _opid} + sql = {"killOp": 1, "op": _opid} except: - sql = {'killOp': 1, 'op': ''} - logger.warning(f"mongodb语句执行killOp报错,语句:db.runCommand({sql}) ,错误信息{traceback.format_exc()}") + sql = {"killOp": 1, "op": ""} + logger.warning( + f"mongodb语句执行killOp报错,语句:db.runCommand({sql}) ,错误信息{traceback.format_exc()}" + ) result.error = str(e) diff --git a/sql/engines/mssql.py b/sql/engines/mssql.py index af7af070ee..712ba07224 100644 --- a/sql/engines/mssql.py +++ b/sql/engines/mssql.py @@ -9,7 +9,7 @@ from .models import ResultSet, ReviewSet, ReviewResult from sql.utils.data_masking import brute_mask -logger = logging.getLogger('default') +logger = logging.getLogger("default") class MssqlEngine(EngineBase): @@ -17,8 +17,13 @@ class MssqlEngine(EngineBase): def get_connection(self, db_name=None): connstr = """DRIVER=ODBC Driver 17 for SQL Server;SERVER={0},{1};UID={2};PWD={3}; -client charset = UTF-8;connect timeout=10;CHARSET={4};""".format(self.host, self.port, self.user, self.password, - self.instance.charset or 'UTF8') +client charset = UTF-8;connect timeout=10;CHARSET={4};""".format( + self.host, + self.port, + self.user, + self.password, + self.instance.charset or "UTF8", + ) if self.conn: return self.conn self.conn = pyodbc.connect(connstr) @@ -28,8 +33,11 @@ def get_all_databases(self): """获取数据库列表, 返回一个ResultSet""" sql = "SELECT name FROM master.sys.databases order by name" result = self.query(sql=sql) - db_list = [row[0] for row in result.rows - if row[0] not in ('master', 'msdb', 'tempdb', 'model')] + db_list = [ + row[0] + for row in result.rows + if row[0] not in ("master", "msdb", "tempdb", "model") + ] result.rows = db_list return result @@ -37,9 +45,11 @@ def get_all_tables(self, db_name, **kwargs): """获取table 列表, 返回一个ResultSet""" sql = """SELECT TABLE_NAME FROM {0}.INFORMATION_SCHEMA.TABLES - WHERE TABLE_TYPE = 'BASE TABLE' order by TABLE_NAME;""".format(db_name) + WHERE TABLE_TYPE = 'BASE TABLE' order by TABLE_NAME;""".format( + db_name + ) result = self.query(db_name=db_name, sql=sql) - tb_list = [row[0] for row in result.rows if row[0] not in ['test']] + tb_list = [row[0] for row in result.rows if row[0] not in ["test"]] result.rows = tb_list return result @@ -120,7 +130,7 @@ def get_table_meta_data(self, db_name, tb_name, **kwargs): ON index_size.table_name = space.table_name; """ _meta_data = self.query(db_name, sql) - return {'column_list': _meta_data.column_list, 'rows': _meta_data.rows[0]} + return {"column_list": _meta_data.column_list, "rows": _meta_data.rows[0]} def get_table_desc_data(self, db_name, tb_name, **kwargs): """获取表格字段信息""" @@ -132,7 +142,7 @@ def get_table_desc_data(self, db_name, tb_name, **kwargs): COLUMN_DEFAULT 默认值 from INFORMATION_SCHEMA.columns where TABLE_CATALOG='{db_name}' and TABLE_NAME = '{tb_name}';""" _desc_data = self.query(db_name, sql) - return {'column_list': _desc_data.column_list, 'rows': _desc_data.rows} + return {"column_list": _desc_data.column_list, "rows": _desc_data.rows} def get_table_index_data(self, db_name, tb_name, **kwargs): """获取表格索引信息""" @@ -145,7 +155,7 @@ def get_table_index_data(self, db_name, tb_name, **kwargs): WHERE i.object_id = OBJECT_ID('{tb_name}') group by i.name,i.object_id,i.index_id,is_unique,is_primary_key;""" _index_data = self.query(db_name, sql) - return {'column_list': _index_data.column_list, 'rows': _index_data.rows} + return {"column_list": _index_data.column_list, "rows": _index_data.rows} def get_tables_metas_data(self, db_name, **kwargs): """获取数据库所有表格信息,用作数据字典导出接口""" @@ -165,11 +175,15 @@ def get_tables_metas_data(self, db_name, **kwargs): table_metas = [] for tb in tbs: _meta = dict() - engine_keys = [{"key": "COLUMN_NAME", "value": "字段名"}, {"key": "COLUMN_TYPE", "value": "数据类型"}, - {"key": "COLLATION_NAME", "value": "列字符集"}, {"key": "IS_NULLABLE", "value": "允许非空"}, - {"key": "COLUMN_DEFAULT", "value": "默认值"}] + engine_keys = [ + {"key": "COLUMN_NAME", "value": "字段名"}, + {"key": "COLUMN_TYPE", "value": "数据类型"}, + {"key": "COLLATION_NAME", "value": "列字符集"}, + {"key": "IS_NULLABLE", "value": "允许非空"}, + {"key": "COLUMN_DEFAULT", "value": "默认值"}, + ] _meta["ENGINE_KEYS"] = engine_keys - _meta['TABLE_INFO'] = tb + _meta["TABLE_INFO"] = tb sql_cols = f"""select COLUMN_NAME, case when ISNUMERIC(CHARACTER_MAXIMUM_LENGTH)=1 then DATA_TYPE + '(' + convert(varchar(max), CHARACTER_MAXIMUM_LENGTH) + ')' else DATA_TYPE end COLUMN_TYPE, COLLATION_NAME, @@ -182,7 +196,7 @@ def get_tables_metas_data(self, db_name, **kwargs): # 转换查询结果为dict for row in query_result.rows: columns.append(dict(zip(query_result.column_list, row))) - _meta['COLUMNS'] = tuple(columns) + _meta["COLUMNS"] = tuple(columns) table_metas.append(_meta) return table_metas @@ -211,64 +225,89 @@ def describe_table(self, db_name, tb_name, **kwargs): left join {0}..sysindexkeys i on i.id=o.id and i.colid=c.colid and i.indid=ie.indid WHERE O.name NOT LIKE 'MS%' AND O.name NOT LIKE 'SY%' and O.name='{1}' - order by o.name,c.colid""".format(db_name, tb_name) + order by o.name,c.colid""".format( + db_name, tb_name + ) result = self.query(sql=sql) return result - def query_check(self, db_name=None, sql=''): + def query_check(self, db_name=None, sql=""): # 查询语句的检查、注释去除、切分 - result = {'msg': '', 'bad_query': False, 'filtered_sql': sql, 'has_star': False} - banned_keywords = ["ascii", "char", "charindex", "concat", "concat_ws", "difference", "format", - "len", "nchar", "patindex", "quotename", "replace", "replicate", - "reverse", "right", "soundex", "space", "str", "string_agg", - "string_escape", "string_split", "stuff", "substring", "trim", "unicode"] - keyword_warning = '' + result = {"msg": "", "bad_query": False, "filtered_sql": sql, "has_star": False} + banned_keywords = [ + "ascii", + "char", + "charindex", + "concat", + "concat_ws", + "difference", + "format", + "len", + "nchar", + "patindex", + "quotename", + "replace", + "replicate", + "reverse", + "right", + "soundex", + "space", + "str", + "string_agg", + "string_escape", + "string_split", + "stuff", + "substring", + "trim", + "unicode", + ] + keyword_warning = "" star_patter = r"(^|,|\s)\*(\s|\(|$)" - sql_whitelist = ['select', 'sp_helptext'] + sql_whitelist = ["select", "sp_helptext"] # 根据白名单list拼接pattern语句 whitelist_pattern = "^" + "|^".join(sql_whitelist) # 删除注释语句,进行语法判断,执行第一条有效sql try: sql = sql.format(sql, strip_comments=True) sql = sqlparse.split(sql)[0] - result['filtered_sql'] = sql.strip() + result["filtered_sql"] = sql.strip() sql_lower = sql.lower() except IndexError: - result['bad_query'] = True - result['msg'] = '没有有效的SQL语句' + result["bad_query"] = True + result["msg"] = "没有有效的SQL语句" return result if re.match(whitelist_pattern, sql_lower) is None: - result['bad_query'] = True - result['msg'] = '仅支持{}语法!'.format(','.join(sql_whitelist)) + result["bad_query"] = True + result["msg"] = "仅支持{}语法!".format(",".join(sql_whitelist)) return result if re.search(star_patter, sql_lower) is not None: - keyword_warning += '禁止使用 * 关键词\n' - result['has_star'] = True + keyword_warning += "禁止使用 * 关键词\n" + result["has_star"] = True for keyword in banned_keywords: pattern = r"(^|,| |=){}( |\(|$)".format(keyword) if re.search(pattern, sql_lower) is not None: - keyword_warning += '禁止使用 {} 关键词\n'.format(keyword) - result['bad_query'] = True - if result.get('bad_query') or result.get('has_star'): - result['msg'] = keyword_warning + keyword_warning += "禁止使用 {} 关键词\n".format(keyword) + result["bad_query"] = True + if result.get("bad_query") or result.get("has_star"): + result["msg"] = keyword_warning return result - def filter_sql(self, sql='', limit_num=0): + def filter_sql(self, sql="", limit_num=0): sql_lower = sql.lower() # 对查询sql增加limit限制 if re.match(r"^select", sql_lower): - if sql_lower.find(' top ') == -1: - return sql_lower.replace('select', 'select top {}'.format(limit_num)) + if sql_lower.find(" top ") == -1: + return sql_lower.replace("select", "select top {}".format(limit_num)) return sql.strip() - def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs): - """返回 ResultSet """ + def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs): + """返回 ResultSet""" result_set = ResultSet(full_sql=sql) try: conn = self.get_connection() cursor = conn.cursor() if db_name: - cursor.execute('use [{}];'.format(db_name)) + cursor.execute("use [{}];".format(db_name)) cursor.execute(sql) if int(limit_num) > 0: rows = cursor.fetchmany(int(limit_num)) @@ -287,7 +326,7 @@ def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs): self.close() return result_set - def query_masking(self, db_name=None, sql='', resultset=None): + def query_masking(self, db_name=None, sql="", resultset=None): """传入 sql语句, db名, 结果集, 返回一个脱敏后的结果集""" # 仅对select语句脱敏 @@ -298,11 +337,11 @@ def query_masking(self, db_name=None, sql='', resultset=None): filtered_result = resultset return filtered_result - def execute_check(self, db_name=None, sql=''): + def execute_check(self, db_name=None, sql=""): """上线单执行前的检查, 返回Review set""" check_result = ReviewSet(full_sql=sql) # 切分语句,追加到检测结果中,默认全部检测通过 - split_reg = re.compile('^GO$', re.I | re.M) + split_reg = re.compile("^GO$", re.I | re.M) sql = re.split(split_reg, sql, 0) sql = filter(None, sql) split_sql = [f"""use [{db_name}]"""] @@ -310,14 +349,17 @@ def execute_check(self, db_name=None, sql=''): split_sql = split_sql + [i] rowid = 1 for statement in split_sql: - check_result.rows.append(ReviewResult( - id=rowid, - errlevel=0, - stagestatus='Audit completed', - errormessage='None', - sql=statement, - affected_rows=0, - execute_time=0, )) + check_result.rows.append( + ReviewResult( + id=rowid, + errlevel=0, + stagestatus="Audit completed", + errormessage="None", + sql=statement, + affected_rows=0, + execute_time=0, + ) + ) rowid += 1 return check_result @@ -325,14 +367,16 @@ def execute_workflow(self, workflow): if workflow.is_backup: # TODO mssql 备份未实现 pass - return self.execute(db_name=workflow.db_name, sql=workflow.sqlworkflowcontent.sql_content) + return self.execute( + db_name=workflow.db_name, sql=workflow.sqlworkflowcontent.sql_content + ) - def execute(self, db_name=None, sql='', close_conn=True): + def execute(self, db_name=None, sql="", close_conn=True): """执行sql语句 返回 Review set""" execute_result = ReviewSet(full_sql=sql) conn = self.get_connection(db_name=db_name) cursor = conn.cursor() - split_reg = re.compile('^GO$', re.I | re.M) + split_reg = re.compile("^GO$", re.I | re.M) sql = re.split(split_reg, sql, 0) sql = filter(None, sql) split_sql = [f"""use [{db_name}]"""] @@ -345,44 +389,50 @@ def execute(self, db_name=None, sql='', close_conn=True): except Exception as e: logger.warning(f"Mssql命令执行报错,语句:{sql}, 错误信息:{traceback.format_exc()}") execute_result.error = str(e) - execute_result.rows.append(ReviewResult( - id=rowid, - errlevel=2, - stagestatus='Execute Failed', - errormessage=f'异常信息:{e}', - sql=statement, - affected_rows=0, - execute_time=0, - )) + execute_result.rows.append( + ReviewResult( + id=rowid, + errlevel=2, + stagestatus="Execute Failed", + errormessage=f"异常信息:{e}", + sql=statement, + affected_rows=0, + execute_time=0, + ) + ) break else: - execute_result.rows.append(ReviewResult( - id=rowid, - errlevel=0, - stagestatus='Execute Successfully', - errormessage='None', - sql=statement, - affected_rows=cursor.rowcount, - execute_time=0, - )) + execute_result.rows.append( + ReviewResult( + id=rowid, + errlevel=0, + stagestatus="Execute Successfully", + errormessage="None", + sql=statement, + affected_rows=cursor.rowcount, + execute_time=0, + ) + ) rowid += 1 if execute_result.error: # 如果失败, 将剩下的部分加入结果集, 并将语句回滚 for statement in split_sql[rowid:]: - execute_result.rows.append(ReviewResult( - id=rowid, - errlevel=2, - stagestatus='Execute Failed', - errormessage=f'前序语句失败, 未执行', - sql=statement, - affected_rows=0, - execute_time=0, - )) + execute_result.rows.append( + ReviewResult( + id=rowid, + errlevel=2, + stagestatus="Execute Failed", + errormessage=f"前序语句失败, 未执行", + sql=statement, + affected_rows=0, + execute_time=0, + ) + ) rowid += 1 cursor.rollback() for row in execute_result.rows: - if row.stagestatus == 'Execute Successfully': - row.stagestatus += '\nRollback Successfully' + if row.stagestatus == "Execute Successfully": + row.stagestatus += "\nRollback Successfully" else: cursor.commit() if close_conn: diff --git a/sql/engines/mysql.py b/sql/engines/mysql.py index 00fa4de15c..6845e78798 100644 --- a/sql/engines/mysql.py +++ b/sql/engines/mysql.py @@ -16,7 +16,7 @@ from sql.utils.data_masking import data_masking from common.config import SysConfig -logger = logging.getLogger('default') +logger = logging.getLogger("default") class MysqlEngine(EngineBase): @@ -30,30 +30,41 @@ def __init__(self, instance=None): def get_connection(self, db_name=None): # https://stackoverflow.com/questions/19256155/python-mysqldb-returning-x01-for-bit-values conversions = MySQLdb.converters.conversions - conversions[FIELD_TYPE.BIT] = lambda data: data == b'\x01' + conversions[FIELD_TYPE.BIT] = lambda data: data == b"\x01" if self.conn: self.thread_id = self.conn.thread_id() return self.conn if db_name: - self.conn = MySQLdb.connect(host=self.host, port=self.port, user=self.user, passwd=self.password, - db=db_name, charset=self.instance.charset or 'utf8mb4', - conv=conversions, - connect_timeout=10) + self.conn = MySQLdb.connect( + host=self.host, + port=self.port, + user=self.user, + passwd=self.password, + db=db_name, + charset=self.instance.charset or "utf8mb4", + conv=conversions, + connect_timeout=10, + ) else: - self.conn = MySQLdb.connect(host=self.host, port=self.port, user=self.user, passwd=self.password, - charset=self.instance.charset or 'utf8mb4', - conv=conversions, - connect_timeout=10) + self.conn = MySQLdb.connect( + host=self.host, + port=self.port, + user=self.user, + passwd=self.password, + charset=self.instance.charset or "utf8mb4", + conv=conversions, + connect_timeout=10, + ) self.thread_id = self.conn.thread_id() return self.conn @property def name(self): - return 'MySQL' + return "MySQL" @property def info(self): - return 'MySQL engine' + return "MySQL engine" @property def auto_backup(self): @@ -62,14 +73,21 @@ def auto_backup(self): @property def seconds_behind_master(self): - slave_status = self.query(sql='show slave status', close_conn=False, cursorclass=MySQLdb.cursors.DictCursor) - return slave_status.rows[0].get('Seconds_Behind_Master') if slave_status.rows else None + slave_status = self.query( + sql="show slave status", + close_conn=False, + cursorclass=MySQLdb.cursors.DictCursor, + ) + return ( + slave_status.rows[0].get("Seconds_Behind_Master") + if slave_status.rows + else None + ) @property def server_version(self): def numeric_part(s): - """Returns the leading numeric part of a string. - """ + """Returns the leading numeric part of a string.""" re_numeric_part = re.compile(r"^(\d+)") m = re_numeric_part.match(s) if m: @@ -78,27 +96,32 @@ def numeric_part(s): self.get_connection() version = self.conn.get_server_info() - return tuple([numeric_part(n) for n in version.split('.')[:3]]) + return tuple([numeric_part(n) for n in version.split(".")[:3]]) @property def schema_object(self): """获取实例对象信息""" - url = build_database_url(host=self.host, - username=self.user, - password=self.password, - port=self.port) - return schemaobject.SchemaObject(url, charset=self.instance.charset or 'utf8mb4') + url = build_database_url( + host=self.host, username=self.user, password=self.password, port=self.port + ) + return schemaobject.SchemaObject( + url, charset=self.instance.charset or "utf8mb4" + ) def kill_connection(self, thread_id): """终止数据库连接""" - self.query(sql=f'kill {thread_id}') + self.query(sql=f"kill {thread_id}") def get_all_databases(self): """获取数据库列表, 返回一个ResultSet""" sql = "show databases" result = self.query(sql=sql) - db_list = [row[0] for row in result.rows - if row[0] not in ('information_schema', 'performance_schema', 'mysql', 'test', 'sys')] + db_list = [ + row[0] + for row in result.rows + if row[0] + not in ("information_schema", "performance_schema", "mysql", "test", "sys") + ] result.rows = db_list return result @@ -106,13 +129,13 @@ def get_all_tables(self, db_name, **kwargs): """获取table 列表, 返回一个ResultSet""" sql = "show tables" result = self.query(db_name=db_name, sql=sql) - tb_list = [row[0] for row in result.rows if row[0] not in ['test']] + tb_list = [row[0] for row in result.rows if row[0] not in ["test"]] result.rows = tb_list return result def get_group_tables_by_db(self, db_name): # escape - db_name = MySQLdb.escape_string(db_name).decode('utf-8') + db_name = MySQLdb.escape_string(db_name).decode("utf-8") data = {} sql = f"""SELECT TABLE_NAME, TABLE_COMMENT @@ -131,8 +154,8 @@ def get_group_tables_by_db(self, db_name): def get_table_meta_data(self, db_name, tb_name, **kwargs): """数据字典页面使用:获取表格的元信息,返回一个dict{column_list: [], rows: []}""" # escape - db_name = MySQLdb.escape_string(db_name).decode('utf-8') - tb_name = MySQLdb.escape_string(tb_name).decode('utf-8') + db_name = MySQLdb.escape_string(db_name).decode("utf-8") + tb_name = MySQLdb.escape_string(tb_name).decode("utf-8") sql = f"""SELECT TABLE_NAME as table_name, ENGINE as engine, @@ -156,7 +179,7 @@ def get_table_meta_data(self, db_name, tb_name, **kwargs): TABLE_SCHEMA='{db_name}' AND TABLE_NAME='{tb_name}'""" _meta_data = self.query(db_name, sql) - return {'column_list': _meta_data.column_list, 'rows': _meta_data.rows[0]} + return {"column_list": _meta_data.column_list, "rows": _meta_data.rows[0]} def get_table_desc_data(self, db_name, tb_name, **kwargs): """获取表格字段信息""" @@ -176,7 +199,7 @@ def get_table_desc_data(self, db_name, tb_name, **kwargs): AND TABLE_NAME = '{tb_name}' ORDER BY ORDINAL_POSITION;""" _desc_data = self.query(db_name, sql) - return {'column_list': _desc_data.column_list, 'rows': _desc_data.rows} + return {"column_list": _desc_data.column_list, "rows": _desc_data.rows} def get_table_index_data(self, db_name, tb_name, **kwargs): """获取表格索引信息""" @@ -195,24 +218,35 @@ def get_table_index_data(self, db_name, tb_name, **kwargs): TABLE_SCHEMA = '{db_name}' AND TABLE_NAME = '{tb_name}';""" _index_data = self.query(db_name, sql) - return {'column_list': _index_data.column_list, 'rows': _index_data.rows} + return {"column_list": _index_data.column_list, "rows": _index_data.rows} def get_tables_metas_data(self, db_name, **kwargs): """获取数据库所有表格信息,用作数据字典导出接口""" - sql_tbs = f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA='{db_name}';" - tbs = self.query(sql=sql_tbs, cursorclass=MySQLdb.cursors.DictCursor, close_conn=False).rows + sql_tbs = ( + f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA='{db_name}';" + ) + tbs = self.query( + sql=sql_tbs, cursorclass=MySQLdb.cursors.DictCursor, close_conn=False + ).rows table_metas = [] for tb in tbs: _meta = dict() - engine_keys = [{"key": "COLUMN_NAME", "value": "字段名"}, {"key": "COLUMN_TYPE", "value": "数据类型"}, - {"key": "COLUMN_DEFAULT", "value": "默认值"}, {"key": "IS_NULLABLE", "value": "允许非空"}, - {"key": "EXTRA", "value": "自动递增"}, {"key": "COLUMN_KEY", "value": "是否主键"}, - {"key": "COLUMN_COMMENT", "value": "备注"}] + engine_keys = [ + {"key": "COLUMN_NAME", "value": "字段名"}, + {"key": "COLUMN_TYPE", "value": "数据类型"}, + {"key": "COLUMN_DEFAULT", "value": "默认值"}, + {"key": "IS_NULLABLE", "value": "允许非空"}, + {"key": "EXTRA", "value": "自动递增"}, + {"key": "COLUMN_KEY", "value": "是否主键"}, + {"key": "COLUMN_COMMENT", "value": "备注"}, + ] _meta["ENGINE_KEYS"] = engine_keys - _meta['TABLE_INFO'] = tb + _meta["TABLE_INFO"] = tb sql_cols = f"""SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA='{tb['TABLE_SCHEMA']}' AND TABLE_NAME='{tb['TABLE_NAME']}';""" - _meta['COLUMNS'] = self.query(sql=sql_cols, cursorclass=MySQLdb.cursors.DictCursor, close_conn=False).rows + _meta["COLUMNS"] = self.query( + sql=sql_cols, cursorclass=MySQLdb.cursors.DictCursor, close_conn=False + ).rows table_metas.append(_meta) return table_metas @@ -243,11 +277,11 @@ def describe_table(self, db_name, tb_name, **kwargs): result = self.query(db_name=db_name, sql=sql) return result - def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs): - """返回 ResultSet """ + def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs): + """返回 ResultSet""" result_set = ResultSet(full_sql=sql) - max_execution_time = kwargs.get('max_execution_time', 0) - cursorclass = kwargs.get('cursorclass') or MySQLdb.cursors.Cursor + max_execution_time = kwargs.get("max_execution_time", 0) + cursorclass = kwargs.get("cursorclass") or MySQLdb.cursors.Cursor try: conn = self.get_connection(db_name=db_name) conn.autocommit(True) @@ -274,68 +308,75 @@ def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs): self.close() return result_set - def query_check(self, db_name=None, sql=''): + def query_check(self, db_name=None, sql=""): # 查询语句的检查、注释去除、切分 - result = {'msg': '', 'bad_query': False, 'filtered_sql': sql, 'has_star': False} + result = {"msg": "", "bad_query": False, "filtered_sql": sql, "has_star": False} # 删除注释语句,进行语法判断,执行第一条有效sql try: sql = sqlparse.format(sql, strip_comments=True) sql = sqlparse.split(sql)[0] - result['filtered_sql'] = sql.strip() + result["filtered_sql"] = sql.strip() except IndexError: - result['bad_query'] = True - result['msg'] = '没有有效的SQL语句' + result["bad_query"] = True + result["msg"] = "没有有效的SQL语句" if re.match(r"^select|^show|^explain", sql, re.I) is None: - result['bad_query'] = True - result['msg'] = '不支持的查询语法类型!' - if '*' in sql: - result['has_star'] = True - result['msg'] = 'SQL语句中含有 * ' + result["bad_query"] = True + result["msg"] = "不支持的查询语法类型!" + if "*" in sql: + result["has_star"] = True + result["msg"] = "SQL语句中含有 * " # select语句先使用Explain判断语法是否正确 if re.match(r"^select", sql, re.I): explain_result = self.query(db_name=db_name, sql=f"explain {sql}") if explain_result.error: - result['bad_query'] = True - result['msg'] = explain_result.error + result["bad_query"] = True + result["msg"] = explain_result.error # 不应该查看mysql.user表 - if re.match('.*(\\s)+(mysql|`mysql`)(\\s)*\\.(\\s)*(user|`user`)((\\s)*|;).*', sql.lower().replace('\n', '')) or \ - (db_name == "mysql" and re.match('.*(\\s)+(user|`user`)((\\s)*|;).*', sql.lower().replace('\n', ''))): - result['bad_query'] = True - result['msg'] = '您无权查看该表' + if re.match( + ".*(\\s)+(mysql|`mysql`)(\\s)*\\.(\\s)*(user|`user`)((\\s)*|;).*", + sql.lower().replace("\n", ""), + ) or ( + db_name == "mysql" + and re.match( + ".*(\\s)+(user|`user`)((\\s)*|;).*", sql.lower().replace("\n", "") + ) + ): + result["bad_query"] = True + result["msg"] = "您无权查看该表" return result - def filter_sql(self, sql='', limit_num=0): + def filter_sql(self, sql="", limit_num=0): # 对查询sql增加limit限制,limit n 或 limit n,n 或 limit n offset n统一改写成limit n - sql = sql.rstrip(';').strip() + sql = sql.rstrip(";").strip() if re.match(r"^select", sql, re.I): # LIMIT N - limit_n = re.compile(r'limit\s+(\d+)\s*$', re.I) + limit_n = re.compile(r"limit\s+(\d+)\s*$", re.I) # LIMIT M OFFSET N - limit_offset = re.compile(r'limit\s+(\d+)\s+offset\s+(\d+)\s*$', re.I) + limit_offset = re.compile(r"limit\s+(\d+)\s+offset\s+(\d+)\s*$", re.I) # LIMIT M,N - offset_comma_limit = re.compile(r'limit\s+(\d+)\s*,\s*(\d+)\s*$', re.I) + offset_comma_limit = re.compile(r"limit\s+(\d+)\s*,\s*(\d+)\s*$", re.I) if limit_n.search(sql): sql_limit = limit_n.search(sql).group(1) limit_num = min(int(limit_num), int(sql_limit)) - sql = limit_n.sub(f'limit {limit_num};', sql) + sql = limit_n.sub(f"limit {limit_num};", sql) elif limit_offset.search(sql): sql_limit = limit_offset.search(sql).group(1) sql_offset = limit_offset.search(sql).group(2) limit_num = min(int(limit_num), int(sql_limit)) - sql = limit_offset.sub(f'limit {limit_num} offset {sql_offset};', sql) + sql = limit_offset.sub(f"limit {limit_num} offset {sql_offset};", sql) elif offset_comma_limit.search(sql): sql_offset = offset_comma_limit.search(sql).group(1) sql_limit = offset_comma_limit.search(sql).group(2) limit_num = min(int(limit_num), int(sql_limit)) - sql = offset_comma_limit.sub(f'limit {sql_offset},{limit_num};', sql) + sql = offset_comma_limit.sub(f"limit {sql_offset},{limit_num};", sql) else: - sql = f'{sql} limit {limit_num};' + sql = f"{sql} limit {limit_num};" else: - sql = f'{sql};' + sql = f"{sql};" return sql - def query_masking(self, db_name=None, sql='', resultset=None): + def query_masking(self, db_name=None, sql="", resultset=None): """传入 sql语句, db名, 结果集, 返回一个脱敏后的结果集""" # 仅对select语句脱敏 @@ -345,53 +386,65 @@ def query_masking(self, db_name=None, sql='', resultset=None): mask_result = resultset return mask_result - def execute_check(self, db_name=None, sql=''): + def execute_check(self, db_name=None, sql=""): """上线单执行前的检查, 返回Review set""" # 进行Inception检查,获取检测结果 try: - check_result = self.inc_engine.execute_check(instance=self.instance, db_name=db_name, sql=sql) + check_result = self.inc_engine.execute_check( + instance=self.instance, db_name=db_name, sql=sql + ) except Exception as e: logger.debug(f"{self.inc_engine.name}检测语句报错:错误信息{traceback.format_exc()}") - raise RuntimeError(f"{self.inc_engine.name}检测语句报错,请注意检查系统配置中{self.inc_engine.name}配置,错误信息:\n{e}") + raise RuntimeError( + f"{self.inc_engine.name}检测语句报错,请注意检查系统配置中{self.inc_engine.name}配置,错误信息:\n{e}" + ) # 判断Inception检测结果 if check_result.error: logger.debug(f"{self.inc_engine.name}检测语句报错:错误信息{check_result.error}") - raise RuntimeError(f"{self.inc_engine.name}检测语句报错,错误信息:\n{check_result.error}") + raise RuntimeError( + f"{self.inc_engine.name}检测语句报错,错误信息:\n{check_result.error}" + ) # 禁用/高危语句检查 - critical_ddl_regex = self.config.get('critical_ddl_regex', '') + critical_ddl_regex = self.config.get("critical_ddl_regex", "") p = re.compile(critical_ddl_regex) for row in check_result.rows: statement = row.sql # 去除注释 - statement = remove_comments(statement, db_type='mysql') + statement = remove_comments(statement, db_type="mysql") # 禁用语句 if re.match(r"^select", statement.lower()): check_result.error_count += 1 - row.stagestatus = '驳回不支持语句' + row.stagestatus = "驳回不支持语句" row.errlevel = 2 - row.errormessage = '仅支持DML和DDL语句,查询语句请使用SQL查询功能!' + row.errormessage = "仅支持DML和DDL语句,查询语句请使用SQL查询功能!" # 高危语句 elif critical_ddl_regex and p.match(statement.strip().lower()): check_result.error_count += 1 - row.stagestatus = '驳回高危SQL' + row.stagestatus = "驳回高危SQL" row.errlevel = 2 - row.errormessage = '禁止提交匹配' + critical_ddl_regex + '条件的语句!' + row.errormessage = "禁止提交匹配" + critical_ddl_regex + "条件的语句!" return check_result def execute_workflow(self, workflow): """执行上线单,返回Review set""" # 判断实例是否只读 - read_only = self.query(sql='SELECT @@global.read_only;').rows[0][0] - if read_only in (1, 'ON'): + read_only = self.query(sql="SELECT @@global.read_only;").rows[0][0] + if read_only in (1, "ON"): result = ReviewSet( full_sql=workflow.sqlworkflowcontent.sql_content, - rows=[ReviewResult(id=1, errlevel=2, - stagestatus='Execute Failed', - errormessage='实例read_only=1,禁止执行变更语句!', - sql=workflow.sqlworkflowcontent.sql_content)]) - result.error = '实例read_only=1,禁止执行变更语句!', + rows=[ + ReviewResult( + id=1, + errlevel=2, + stagestatus="Execute Failed", + errormessage="实例read_only=1,禁止执行变更语句!", + sql=workflow.sqlworkflowcontent.sql_content, + ) + ], + ) + result.error = ("实例read_only=1,禁止执行变更语句!",) return result # TODO 原生执行 # if workflow.is_manual == 1: @@ -399,7 +452,7 @@ def execute_workflow(self, workflow): # inception执行 return self.inc_engine.execute(workflow) - def execute(self, db_name=None, sql='', close_conn=True): + def execute(self, db_name=None, sql="", close_conn=True): """原生执行语句""" result = ResultSet(full_sql=sql) conn = self.get_connection(db_name=db_name) @@ -424,8 +477,16 @@ def get_rollback(self, workflow): def get_variables(self, variables=None): """获取实例参数""" if variables: - variables = "','".join(variables) if isinstance(variables, list) else "','".join(list(variables)) - db = 'performance_schema' if self.server_version > (5, 7) else 'information_schema' + variables = ( + "','".join(variables) + if isinstance(variables, list) + else "','".join(list(variables)) + ) + db = ( + "performance_schema" + if self.server_version > (5, 7) + else "information_schema" + ) sql = f"""select * from {db}.global_variables where variable_name in ('{variables}');""" else: sql = "show global variables;" @@ -438,7 +499,7 @@ def set_variable(self, variable_name, variable_value): def osc_control(self, **kwargs): """控制osc执行,获取进度、终止、暂停、恢复等 - get、kill、pause、resume + get、kill、pause、resume """ return self.inc_engine.osc_control(**kwargs) @@ -446,42 +507,44 @@ def processlist(self, command_type): """获取连接信息""" base_sql = "select id, user, host, db, command, time, state, ifnull(info,'') as info from information_schema.processlist" # escape - command_type = MySQLdb.escape_string(command_type).decode('utf-8') + command_type = MySQLdb.escape_string(command_type).decode("utf-8") if not command_type: - command_type = 'Query' - if command_type == 'All': - sql = base_sql + ';' - elif command_type == 'Not Sleep': + command_type = "Query" + if command_type == "All": + sql = base_sql + ";" + elif command_type == "Not Sleep": sql = "{} where command<>'Sleep';".format(base_sql) else: sql = "{} where command= '{}';".format(base_sql, command_type) - - return self.query('information_schema', sql) - + + return self.query("information_schema", sql) + def get_kill_command(self, thread_ids): """由传入的线程列表生成kill命令""" - sql = "select concat('kill ', id, ';') from information_schema.processlist where id in ({});"\ - .format(','.join(str(tid) for tid in thread_ids)) - all_kill_sql = self.query('information_schema', sql) - kill_sql = '' + sql = "select concat('kill ', id, ';') from information_schema.processlist where id in ({});".format( + ",".join(str(tid) for tid in thread_ids) + ) + all_kill_sql = self.query("information_schema", sql) + kill_sql = "" for row in all_kill_sql.rows: - kill_sql = kill_sql + row[0] - + kill_sql = kill_sql + row[0] + return kill_sql - - def kill(self, thread_ids): + + def kill(self, thread_ids): """kill线程""" - sql = "select concat('kill ', id, ';') from information_schema.processlist where id in ({});"\ - .format(','.join(str(tid) for tid in thread_ids)) - all_kill_sql = self.query('information_schema', sql) - kill_sql = '' + sql = "select concat('kill ', id, ';') from information_schema.processlist where id in ({});".format( + ",".join(str(tid) for tid in thread_ids) + ) + all_kill_sql = self.query("information_schema", sql) + kill_sql = "" for row in all_kill_sql.rows: kill_sql = kill_sql + row[0] - return self.execute('information_schema', kill_sql) - - def tablesapce(self, offset=0, row_count=14): + return self.execute("information_schema", kill_sql) + + def tablesapce(self, offset=0, row_count=14): """获取表空间信息""" - sql = ''' + sql = """ SELECT table_schema AS table_schema, table_name AS table_name, @@ -495,22 +558,24 @@ def tablesapce(self, offset=0, row_count=14): FROM information_schema.tables WHERE table_schema NOT IN ('information_schema', 'performance_schema', 'mysql', 'test', 'sys') ORDER BY total_size DESC - LIMIT {},{};'''.format(offset, row_count) - return self.query('information_schema', sql) - - def tablesapce_num(self): + LIMIT {},{};""".format( + offset, row_count + ) + return self.query("information_schema", sql) + + def tablesapce_num(self): """获取表空间数量""" - sql = ''' + sql = """ SELECT count(*) FROM information_schema.tables - WHERE table_schema NOT IN ('information_schema', 'performance_schema', 'mysql', 'test', 'sys')''' - return self.query('information_schema', sql) - + WHERE table_schema NOT IN ('information_schema', 'performance_schema', 'mysql', 'test', 'sys')""" + return self.query("information_schema", sql) + def trxandlocks(self): """获取锁等待信息""" server_version = self.server_version if server_version < (8, 0, 1): - sql = ''' + sql = """ SELECT rtrx.`trx_state` AS "等待的状态", rtrx.`trx_started` AS "等待事务开始时间", @@ -536,10 +601,10 @@ def trxandlocks(self): WHERE rl.`lock_id` = lw.`requested_lock_id` AND l.`lock_id` = lw.`blocking_lock_id` AND lw.requesting_trx_id = rtrx.trx_id - AND lw.blocking_trx_id = trx.trx_id;''' - + AND lw.blocking_trx_id = trx.trx_id;""" + else: - sql = ''' + sql = """ SELECT rtrx.`trx_state` AS "等待的状态", rtrx.`trx_started` AS "等待事务开始时间", @@ -565,13 +630,13 @@ def trxandlocks(self): WHERE rl.`ENGINE_LOCK_ID` = lw.`REQUESTING_ENGINE_LOCK_ID` AND l.`ENGINE_LOCK_ID` = lw.`BLOCKING_ENGINE_LOCK_ID` AND lw.REQUESTING_ENGINE_TRANSACTION_ID = rtrx.trx_id - AND lw.BLOCKING_ENGINE_TRANSACTION_ID = trx.trx_id;''' - - return self.query('information_schema', sql) - + AND lw.BLOCKING_ENGINE_TRANSACTION_ID = trx.trx_id;""" + + return self.query("information_schema", sql) + def get_long_transaction(self, thread_time=3): """获取长事务""" - sql = '''select trx.trx_started, + sql = """select trx.trx_started, trx.trx_state, trx.trx_operation_state, trx.trx_mysql_thread_id, @@ -598,10 +663,12 @@ def get_long_transaction(self, thread_time=3): WHERE trx.trx_state = 'RUNNING' AND p.COMMAND = 'Sleep' AND p.time > {} - ORDER BY trx.trx_started ASC;'''.format(thread_time) - - return self.query('information_schema', sql) - + ORDER BY trx.trx_started ASC;""".format( + thread_time + ) + + return self.query("information_schema", sql) + def close(self): if self.conn: self.conn.close() diff --git a/sql/engines/odps.py b/sql/engines/odps.py index 79b0bb23de..b4e986b5e1 100644 --- a/sql/engines/odps.py +++ b/sql/engines/odps.py @@ -10,7 +10,7 @@ from odps import ODPS -logger = logging.getLogger('default') +logger = logging.getLogger("default") class ODPSEngine(EngineBase): @@ -31,16 +31,16 @@ def get_connection(self, db_name=None): @property def name(self): - return 'ODPS' + return "ODPS" @property def info(self): - return 'ODPS engine' + return "ODPS engine" def get_all_databases(self): """获取数据库列表, 返回一个ResultSet - ODPS只有project概念, 直接返回project名称 - TODO: 目前ODPS获取所有项目接口比较慢, 暂时支持返回一个project,后续再优化 + ODPS只有project概念, 直接返回project名称 + TODO: 目前ODPS获取所有项目接口比较慢, 暂时支持返回一个project,后续再优化 """ result = ResultSet() @@ -80,7 +80,7 @@ def get_all_tables(self, db_name, **kwargs): def get_all_columns_by_tb(self, db_name, tb_name, **kwargs): """获取所有字段, 返回一个ResultSet""" - column_list = ['COLUMN_NAME', 'COLUMN_TYPE', 'COLUMN_COMMENT'] + column_list = ["COLUMN_NAME", "COLUMN_TYPE", "COLUMN_COMMENT"] conn = self.get_connection(db_name) @@ -105,27 +105,27 @@ def describe_table(self, db_name, tb_name, **kwargs): return result - def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs): - """返回 ResultSet """ + def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs): + """返回 ResultSet""" result_set = ResultSet(full_sql=sql) if not re.match(r"^select", sql, re.I): result_set.error = str("仅支持ODPS查询语句") # 存在limit,替换limit; 不存在,添加limit - if re.search('limit', sql): - sql = re.sub('limit.+(\d+)', 'limit ' + str(limit_num), sql) + if re.search("limit", sql): + sql = re.sub("limit.+(\d+)", "limit " + str(limit_num), sql) else: - if sql.strip()[-1] == ';': + if sql.strip()[-1] == ";": sql = sql[:-1] - sql = sql + ' limit ' + str(limit_num) + ';' + sql = sql + " limit " + str(limit_num) + ";" try: conn = self.get_connection(db_name) effect_row = conn.execute_sql(sql) reader = effect_row.open_reader() rows = [row.values for row in reader] - column_list = getattr(reader, '_schema').names + column_list = getattr(reader, "_schema").names result_set.column_list = column_list result_set.rows = rows @@ -136,27 +136,27 @@ def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs): result_set.error = str(e) return result_set - def query_check(self, db_name=None, sql=''): + def query_check(self, db_name=None, sql=""): # 查询语句的检查、注释去除、切分 - result = {'msg': '', 'bad_query': False, 'filtered_sql': sql, 'has_star': False} - keyword_warning = '' - sql_whitelist = ['select'] + result = {"msg": "", "bad_query": False, "filtered_sql": sql, "has_star": False} + keyword_warning = "" + sql_whitelist = ["select"] # 根据白名单list拼接pattern语句 whitelist_pattern = re.compile("^" + "|^".join(sql_whitelist), re.IGNORECASE) # 删除注释语句,进行语法判断,执行第一条有效sql try: sql = sqlparse.format(sql, strip_comments=True) sql = sqlparse.split(sql)[0] - result['filtered_sql'] = sql.strip() + result["filtered_sql"] = sql.strip() # sql_lower = sql.lower() except IndexError: - result['bad_query'] = True - result['msg'] = '没有有效的SQL语句' + result["bad_query"] = True + result["msg"] = "没有有效的SQL语句" return result if whitelist_pattern.match(sql) is None: - result['bad_query'] = True - result['msg'] = '仅支持{}语法!'.format(','.join(sql_whitelist)) + result["bad_query"] = True + result["msg"] = "仅支持{}语法!".format(",".join(sql_whitelist)) return result - if result.get('bad_query'): - result['msg'] = keyword_warning + if result.get("bad_query"): + result["msg"] = keyword_warning return result diff --git a/sql/engines/oracle.py b/sql/engines/oracle.py index 7849d69f2e..4d6195c8b2 100644 --- a/sql/engines/oracle.py +++ b/sql/engines/oracle.py @@ -10,13 +10,17 @@ import pandas as pd from common.config import SysConfig from common.utils.timer import FuncTimer -from sql.utils.sql_utils import get_syntax_type, get_full_sqlitem_list, get_exec_sqlitem_list +from sql.utils.sql_utils import ( + get_syntax_type, + get_full_sqlitem_list, + get_exec_sqlitem_list, +) from . import EngineBase import cx_Oracle from .models import ResultSet, ReviewSet, ReviewResult from sql.utils.data_masking import simple_column_mask -logger = logging.getLogger('default') +logger = logging.getLogger("default") class OracleEngine(EngineBase): @@ -32,21 +36,27 @@ def get_connection(self, db_name=None): return self.conn if self.sid: dsn = cx_Oracle.makedsn(self.host, self.port, self.sid) - self.conn = cx_Oracle.connect(self.user, self.password, dsn=dsn, encoding="UTF-8", nencoding="UTF-8") + self.conn = cx_Oracle.connect( + self.user, self.password, dsn=dsn, encoding="UTF-8", nencoding="UTF-8" + ) elif self.service_name: - dsn = cx_Oracle.makedsn(self.host, self.port, service_name=self.service_name) - self.conn = cx_Oracle.connect(self.user, self.password, dsn=dsn, encoding="UTF-8", nencoding="UTF-8") + dsn = cx_Oracle.makedsn( + self.host, self.port, service_name=self.service_name + ) + self.conn = cx_Oracle.connect( + self.user, self.password, dsn=dsn, encoding="UTF-8", nencoding="UTF-8" + ) else: - raise ValueError('sid 和 dsn 均未填写, 请联系管理页补充该实例配置.') + raise ValueError("sid 和 dsn 均未填写, 请联系管理页补充该实例配置.") return self.conn @property def name(self): - return 'Oracle' + return "Oracle" @property def info(self): - return 'Oracle engine' + return "Oracle engine" @property def auto_backup(self): @@ -57,23 +67,24 @@ def auto_backup(self): def get_backup_connection(): """备份库连接""" archer_config = SysConfig() - backup_host = archer_config.get('inception_remote_backup_host') - backup_port = int(archer_config.get('inception_remote_backup_port', 3306)) - backup_user = archer_config.get('inception_remote_backup_user') - backup_password = archer_config.get('inception_remote_backup_password') - return MySQLdb.connect(host=backup_host, - port=backup_port, - user=backup_user, - passwd=backup_password, - charset='utf8mb4', - autocommit=True - ) + backup_host = archer_config.get("inception_remote_backup_host") + backup_port = int(archer_config.get("inception_remote_backup_port", 3306)) + backup_user = archer_config.get("inception_remote_backup_user") + backup_password = archer_config.get("inception_remote_backup_password") + return MySQLdb.connect( + host=backup_host, + port=backup_port, + user=backup_user, + passwd=backup_password, + charset="utf8mb4", + autocommit=True, + ) @property def server_version(self): conn = self.get_connection() version = conn.version - return tuple([n for n in version.split('.')[:3]]) + return tuple([n for n in version.split(".")[:3]]) def get_all_databases(self): """获取数据库列表, 返回resultSet 供上层调用, 底层实际上是获取oracle的schema列表""" @@ -102,11 +113,47 @@ def _get_all_schemas(self): """ result = self.query(sql="SELECT username FROM all_users order by username") sysschema = ( - 'AUD_SYS', 'ANONYMOUS', 'APEX_030200', 'APEX_PUBLIC_USER', 'APPQOSSYS', 'BI USERS', 'CTXSYS', 'DBSNMP', - 'DIP USERS', 'EXFSYS', 'FLOWS_FILES', 'HR USERS', 'IX USERS', 'MDDATA', 'MDSYS', 'MGMT_VIEW', 'OE USERS', - 'OLAPSYS', 'ORACLE_OCM', 'ORDDATA', 'ORDPLUGINS', 'ORDSYS', 'OUTLN', 'OWBSYS', 'OWBSYS_AUDIT', 'PM USERS', - 'SCOTT', 'SH USERS', 'SI_INFORMTN_SCHEMA', 'SPATIAL_CSW_ADMIN_USR', 'SPATIAL_WFS_ADMIN_USR', 'SYS', - 'SYSMAN', 'SYSTEM', 'WMSYS', 'XDB', 'XS$NULL', 'DIP', 'OJVMSYS', 'LBACSYS') + "AUD_SYS", + "ANONYMOUS", + "APEX_030200", + "APEX_PUBLIC_USER", + "APPQOSSYS", + "BI USERS", + "CTXSYS", + "DBSNMP", + "DIP USERS", + "EXFSYS", + "FLOWS_FILES", + "HR USERS", + "IX USERS", + "MDDATA", + "MDSYS", + "MGMT_VIEW", + "OE USERS", + "OLAPSYS", + "ORACLE_OCM", + "ORDDATA", + "ORDPLUGINS", + "ORDSYS", + "OUTLN", + "OWBSYS", + "OWBSYS_AUDIT", + "PM USERS", + "SCOTT", + "SH USERS", + "SI_INFORMTN_SCHEMA", + "SPATIAL_CSW_ADMIN_USR", + "SPATIAL_WFS_ADMIN_USR", + "SYS", + "SYSMAN", + "SYSTEM", + "WMSYS", + "XDB", + "XS$NULL", + "DIP", + "OJVMSYS", + "LBACSYS", + ) schema_list = [row[0] for row in result.rows if row[0] not in sysschema] result.rows = schema_list return result @@ -116,7 +163,7 @@ def get_all_tables(self, db_name, **kwargs): sql = f"""SELECT table_name FROM all_tables WHERE nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM', 'SYSAUX') AND OWNER = '{db_name}' AND IOT_NAME IS NULL AND DURATION IS NULL order by table_name """ result = self.query(db_name=db_name, sql=sql) - tb_list = [row[0] for row in result.rows if row[0] not in ['test']] + tb_list = [row[0] for row in result.rows if row[0] not in ["test"]] result.rows = tb_list return result @@ -161,7 +208,7 @@ def get_table_meta_data(self, db_name, tb_name, **kwargs): tcs.OWNER='{db_name}' AND tcs.TABLE_NAME='{tb_name}'""" _meta_data = self.query(db_name=db_name, sql=meta_data_sql) - return {'column_list': _meta_data.column_list, 'rows': _meta_data.rows[0]} + return {"column_list": _meta_data.column_list, "rows": _meta_data.rows[0]} def get_table_desc_data(self, db_name, tb_name, **kwargs): """获取表格字段信息""" @@ -206,7 +253,7 @@ def get_table_desc_data(self, db_name, tb_name, **kwargs): AND bcs.TABLE_NAME='{tb_name}' ORDER BY bcs.COLUMN_NAME""" _desc_data = self.query(db_name=db_name, sql=desc_sql) - return {'column_list': _desc_data.column_list, 'rows': _desc_data.rows} + return {"column_list": _desc_data.column_list, "rows": _desc_data.rows} def get_table_index_data(self, db_name, tb_name, **kwargs): """获取表格索引信息""" @@ -228,7 +275,7 @@ def get_table_index_data(self, db_name, tb_name, **kwargs): ais.owner = '{db_name}' AND ais.table_name = '{tb_name}'""" _index_data = self.query(db_name, index_sql) - return {'column_list': _index_data.column_list, 'rows': _index_data.rows} + return {"column_list": _index_data.column_list, "rows": _index_data.rows} def get_tables_metas_data(self, db_name, **kwargs): """获取数据库所有表格信息,用作数据字典导出接口""" @@ -282,22 +329,42 @@ def get_tables_metas_data(self, db_name, **kwargs): cols_req = self.query(sql=sql_cols, close_conn=False).rows # 给查询结果定义列名,query_engine.query的游标是0 1 2 - cols_df = pd.DataFrame(cols_req, - columns=['TABLE_NAME', 'TABLE_COMMENTS', 'COLUMN_NAME', 'COLUMN_TYPE', 'COLUMN_DEFAULT', - 'IS_NULLABLE', 'COLUMN_KEY', 'COLUMN_COMMENT']) + cols_df = pd.DataFrame( + cols_req, + columns=[ + "TABLE_NAME", + "TABLE_COMMENTS", + "COLUMN_NAME", + "COLUMN_TYPE", + "COLUMN_DEFAULT", + "IS_NULLABLE", + "COLUMN_KEY", + "COLUMN_COMMENT", + ], + ) # 获得表名称去重 - col_list = cols_df.drop_duplicates('TABLE_NAME').to_dict('records') + col_list = cols_df.drop_duplicates("TABLE_NAME").to_dict("records") for cl in col_list: _meta = dict() - engine_keys = [{"key": "COLUMN_NAME", "value": "字段名"}, {"key": "COLUMN_TYPE", "value": "数据类型"}, - {"key": "COLUMN_DEFAULT", "value": "默认值"}, {"key": "IS_NULLABLE", "value": "允许非空"}, - {"key": "COLUMN_KEY", "value": "是否主键"}, {"key": "COLUMN_COMMENT", "value": "备注"}] + engine_keys = [ + {"key": "COLUMN_NAME", "value": "字段名"}, + {"key": "COLUMN_TYPE", "value": "数据类型"}, + {"key": "COLUMN_DEFAULT", "value": "默认值"}, + {"key": "IS_NULLABLE", "value": "允许非空"}, + {"key": "COLUMN_KEY", "value": "是否主键"}, + {"key": "COLUMN_COMMENT", "value": "备注"}, + ] _meta["ENGINE_KEYS"] = engine_keys - _meta['TABLE_INFO'] = {'TABLE_NAME': cl['TABLE_NAME'], 'TABLE_COMMENTS': cl['TABLE_COMMENTS']} - table_name = cl['TABLE_NAME'] + _meta["TABLE_INFO"] = { + "TABLE_NAME": cl["TABLE_NAME"], + "TABLE_COMMENTS": cl["TABLE_COMMENTS"], + } + table_name = cl["TABLE_NAME"] # 查询DataFrame中满足表名的记录,并转为list - _meta['COLUMNS'] = cols_df.query("TABLE_NAME == @table_name").to_dict('records') + _meta["COLUMNS"] = cols_df.query("TABLE_NAME == @table_name").to_dict( + "records" + ) table_metas.append(_meta) return table_metas @@ -306,7 +373,7 @@ def get_all_objects(self, db_name, **kwargs): """获取object_name 列表, 返回一个ResultSet""" sql = f"""SELECT object_name FROM all_objects WHERE OWNER = '{db_name}' """ result = self.query(db_name=db_name, sql=sql) - tb_list = [row[0] for row in result.rows if row[0] not in ['test']] + tb_list = [row[0] for row in result.rows if row[0] not in ["test"]] result.rows = tb_list return result @@ -336,27 +403,27 @@ def describe_table(self, db_name, tb_name, **kwargs): result = self.query(db_name=db_name, sql=sql) return result - def object_name_check(self, db_name=None, object_name=''): + def object_name_check(self, db_name=None, object_name=""): """获取table 列表, 返回一个ResultSet""" - if '.' in object_name: - schema_name = object_name.split('.')[0] - object_name = object_name.split('.')[1] + if "." in object_name: + schema_name = object_name.split(".")[0] + object_name = object_name.split(".")[1] if '"' in schema_name: - schema_name = schema_name.replace('"', '') + schema_name = schema_name.replace('"', "") if '"' in object_name: - object_name = object_name.replace('"', '') + object_name = object_name.replace('"', "") else: object_name = object_name.upper() else: schema_name = schema_name.upper() if '"' in object_name: - object_name = object_name.replace('"', '') + object_name = object_name.replace('"', "") else: object_name = object_name.upper() else: schema_name = db_name if '"' in object_name: - object_name = object_name.replace('"', '') + object_name = object_name.replace('"', "") else: object_name = object_name.upper() sql = f""" SELECT object_name FROM all_objects WHERE OWNER = '{schema_name}' and OBJECT_NAME = '{object_name}' """ @@ -367,47 +434,71 @@ def object_name_check(self, db_name=None, object_name=''): return False @staticmethod - def get_sql_first_object_name(sql=''): + def get_sql_first_object_name(sql=""): """获取sql文本中的object_name""" - object_name = '' + object_name = "" if re.match(r"^create\s+table\s", sql, re.M | re.IGNORECASE): - object_name = re.match(r"^create\s+table\s(.+?)(\s|\()", sql, re.M | re.IGNORECASE).group(1) + object_name = re.match( + r"^create\s+table\s(.+?)(\s|\()", sql, re.M | re.IGNORECASE + ).group(1) elif re.match(r"^create\s+index\s", sql, re.M | re.IGNORECASE): - object_name = re.match(r"^create\s+index\s(.+?)\s", sql, re.M | re.IGNORECASE).group(1) + object_name = re.match( + r"^create\s+index\s(.+?)\s", sql, re.M | re.IGNORECASE + ).group(1) elif re.match(r"^create\s+unique\s+index\s", sql, re.M | re.IGNORECASE): - object_name = re.match(r"^create\s+unique\s+index\s(.+?)\s", sql, re.M | re.IGNORECASE).group(1) + object_name = re.match( + r"^create\s+unique\s+index\s(.+?)\s", sql, re.M | re.IGNORECASE + ).group(1) elif re.match(r"^create\s+sequence\s", sql, re.M | re.IGNORECASE): - object_name = re.match(r"^create\s+sequence\s(.+?)(\s|$)", sql, re.M | re.IGNORECASE).group(1) + object_name = re.match( + r"^create\s+sequence\s(.+?)(\s|$)", sql, re.M | re.IGNORECASE + ).group(1) elif re.match(r"^alter\s+table\s", sql, re.M | re.IGNORECASE): - object_name = re.match(r"^alter\s+table\s(.+?)\s", sql, re.M | re.IGNORECASE).group(1) + object_name = re.match( + r"^alter\s+table\s(.+?)\s", sql, re.M | re.IGNORECASE + ).group(1) elif re.match(r"^create\s+function\s", sql, re.M | re.IGNORECASE): - object_name = re.match(r"^create\s+function\s(.+?)(\s|\()", sql, re.M | re.IGNORECASE).group(1) + object_name = re.match( + r"^create\s+function\s(.+?)(\s|\()", sql, re.M | re.IGNORECASE + ).group(1) elif re.match(r"^create\s+view\s", sql, re.M | re.IGNORECASE): - object_name = re.match(r"^create\s+view\s(.+?)\s", sql, re.M | re.IGNORECASE).group(1) + object_name = re.match( + r"^create\s+view\s(.+?)\s", sql, re.M | re.IGNORECASE + ).group(1) elif re.match(r"^create\s+procedure\s", sql, re.M | re.IGNORECASE): - object_name = re.match(r"^create\s+procedure\s(.+?)\s", sql, re.M | re.IGNORECASE).group(1) + object_name = re.match( + r"^create\s+procedure\s(.+?)\s", sql, re.M | re.IGNORECASE + ).group(1) elif re.match(r"^create\s+package\s+body", sql, re.M | re.IGNORECASE): - object_name = re.match(r"^create\s+package\s+body\s(.+?)\s", sql, re.M | re.IGNORECASE).group(1) + object_name = re.match( + r"^create\s+package\s+body\s(.+?)\s", sql, re.M | re.IGNORECASE + ).group(1) elif re.match(r"^create\s+package\s", sql, re.M | re.IGNORECASE): - object_name = re.match(r"^create\s+package\s(.+?)\s", sql, re.M | re.IGNORECASE).group(1) + object_name = re.match( + r"^create\s+package\s(.+?)\s", sql, re.M | re.IGNORECASE + ).group(1) else: return object_name.strip() return object_name.strip() @staticmethod - def check_create_index_table(sql='', object_name_list=None, db_name=''): + def check_create_index_table(sql="", object_name_list=None, db_name=""): object_name_list = object_name_list or set() if re.match(r"^create\s+index\s", sql): - table_name = re.match(r"^create\s+index\s+.+\s+on\s(.+?)(\(|\s\()", sql, re.M).group(1) - if '.' not in table_name: + table_name = re.match( + r"^create\s+index\s+.+\s+on\s(.+?)(\(|\s\()", sql, re.M + ).group(1) + if "." not in table_name: table_name = f"{db_name}.{table_name}" if table_name in object_name_list: return True else: return False elif re.match(r"^create\s+unique\s+index\s", sql): - table_name = re.match(r"^create\s+unique\s+index\s+.+\s+on\s(.+?)(\(|\s\()", sql, re.M).group(1) - if '.' not in table_name: + table_name = re.match( + r"^create\s+unique\s+index\s+.+\s+on\s(.+?)(\(|\s\()", sql, re.M + ).group(1) + if "." not in table_name: table_name = f"{db_name}.{table_name}" if table_name in object_name_list: return True @@ -417,11 +508,11 @@ def check_create_index_table(sql='', object_name_list=None, db_name=''): return False @staticmethod - def get_dml_table(sql='', object_name_list=None, db_name=''): + def get_dml_table(sql="", object_name_list=None, db_name=""): object_name_list = object_name_list or set() if re.match(r"^update", sql): table_name = re.match(r"^update\s(.+?)\s", sql, re.M).group(1) - if '.' not in table_name: + if "." not in table_name: table_name = f"{db_name}.{table_name}" if table_name in object_name_list: return True @@ -429,16 +520,19 @@ def get_dml_table(sql='', object_name_list=None, db_name=''): return False elif re.match(r"^delete", sql): table_name = re.match(r"^delete\s+from\s+([\w-]+)\s*", sql, re.M).group(1) - if '.' not in table_name: + if "." not in table_name: table_name = f"{db_name}.{table_name}" if table_name in object_name_list: return True else: return False elif re.match(r"^insert\s", sql): - table_name = re.match(r"^insert\s+((into)|(all\s+into)|(all\s+when\s(.+?)into))\s+(.+?)(\(|\s)", sql, - re.M).group(6) - if '.' not in table_name: + table_name = re.match( + r"^insert\s+((into)|(all\s+into)|(all\s+when\s(.+?)into))\s+(.+?)(\(|\s)", + sql, + re.M, + ).group(6) + if "." not in table_name: table_name = f"{db_name}.{table_name}" if table_name in object_name_list: return True @@ -448,93 +542,103 @@ def get_dml_table(sql='', object_name_list=None, db_name=''): return False @staticmethod - def where_check(sql=''): + def where_check(sql=""): if re.match(r"^update((?!where).)*$|^delete((?!where).)*$", sql): return True else: parsed = sqlparse.parse(sql)[0] flattened = list(parsed.flatten()) n_skip = 0 - flattened = flattened[:len(flattened) - n_skip] - logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN', 'ORDER BY', 'GROUP BY', 'HAVING') + flattened = flattened[: len(flattened) - n_skip] + logical_operators = ( + "AND", + "OR", + "NOT", + "BETWEEN", + "ORDER BY", + "GROUP BY", + "HAVING", + ) for t in reversed(flattened): if t.is_keyword: return True return False - def explain_check(self, db_name=None, sql='', close_conn=False): + def explain_check(self, db_name=None, sql="", close_conn=False): # 使用explain进行支持的SQL语法审核,连接需不中断,防止数据库不断fork进程的大批量消耗 - result = {'msg': '', 'rows': 0} + result = {"msg": "", "rows": 0} try: conn = self.get_connection() cursor = conn.cursor() if db_name: - cursor.execute(f" ALTER SESSION SET CURRENT_SCHEMA = \"{db_name}\" ") + cursor.execute(f' ALTER SESSION SET CURRENT_SCHEMA = "{db_name}" ') if re.match(r"^explain", sql, re.I): sql = sql else: sql = f"explain plan for {sql}" - sql = sql.rstrip(';') + sql = sql.rstrip(";") cursor.execute(sql) # 获取影响行数 cursor.execute(f"select CARDINALITY from SYS.PLAN_TABLE$ where id = 0") rows = cursor.fetchone() conn.rollback() if not rows: - result['rows'] = 0 + result["rows"] = 0 else: - result['rows'] = rows[0] + result["rows"] = rows[0] except Exception as e: logger.warning(f"Oracle 语句执行报错,语句:{sql},错误信息{traceback.format_exc()}") - result['msg'] = str(e) + result["msg"] = str(e) finally: if close_conn: self.close() return result - def query_check(self, db_name=None, sql=''): + def query_check(self, db_name=None, sql=""): # 查询语句的检查、注释去除、切分 - result = {'msg': '', 'bad_query': False, 'filtered_sql': sql, 'has_star': False} - keyword_warning = '' + result = {"msg": "", "bad_query": False, "filtered_sql": sql, "has_star": False} + keyword_warning = "" star_patter = r"(^|,|\s)\*(\s|\(|$)" # 删除注释语句,进行语法判断,执行第一条有效sql try: sql = sqlparse.format(sql, strip_comments=True) sql = sqlparse.split(sql)[0] - result['filtered_sql'] = re.sub(r';$', '', sql.strip()) + result["filtered_sql"] = re.sub(r";$", "", sql.strip()) sql_lower = sql.lower() except IndexError: - result['bad_query'] = True - result['msg'] = '没有有效的SQL语句' + result["bad_query"] = True + result["msg"] = "没有有效的SQL语句" return result if re.match(r"^select|^with|^explain", sql_lower) is None: - result['bad_query'] = True - result['msg'] = '不支持语法!' + result["bad_query"] = True + result["msg"] = "不支持语法!" return result if re.search(star_patter, sql_lower) is not None: - keyword_warning += '禁止使用 * 关键词\n' - result['has_star'] = True - if result.get('bad_query') or result.get('has_star'): - result['msg'] = keyword_warning + keyword_warning += "禁止使用 * 关键词\n" + result["has_star"] = True + if result.get("bad_query") or result.get("has_star"): + result["msg"] = keyword_warning return result - def filter_sql(self, sql='', limit_num=0): + def filter_sql(self, sql="", limit_num=0): sql_lower = sql.lower() # 对查询sql增加limit限制 if re.match(r"^select|^with", sql_lower) and not ( - re.match(r"^select\s+sql_audit.", sql_lower) and sql_lower.find(" sql_audit where rownum <= ") != -1): + re.match(r"^select\s+sql_audit.", sql_lower) + and sql_lower.find(" sql_audit where rownum <= ") != -1 + ): sql = f"select sql_audit.* from ({sql.rstrip(';')}) sql_audit where rownum <= {limit_num}" return sql.strip() - def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs): - """返回 ResultSet """ + def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs): + """返回 ResultSet""" result_set = ResultSet(full_sql=sql) try: conn = self.get_connection() cursor = conn.cursor() if db_name: - cursor.execute(f" ALTER SESSION SET CURRENT_SCHEMA = \"{db_name}\" ") - sql = sql.rstrip(';') + cursor.execute(f' ALTER SESSION SET CURRENT_SCHEMA = "{db_name}" ') + sql = sql.rstrip(";") # 支持oralce查询SQL执行计划语句 if re.match(r"^explain", sql, re.I): cursor.execute(sql) @@ -543,9 +647,12 @@ def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs): cursor.execute(sql) fields = cursor.description if any(x[1] == cx_Oracle.CLOB for x in fields): - rows = [tuple([(c.read() if type(c) == cx_Oracle.LOB else c) for c in r]) for r in cursor] + rows = [ + tuple([(c.read() if type(c) == cx_Oracle.LOB else c) for c in r]) + for r in cursor + ] if int(limit_num) > 0: - rows = rows[0:int(limit_num)] + rows = rows[0 : int(limit_num)] else: if int(limit_num) > 0: rows = cursor.fetchmany(int(limit_num)) @@ -562,7 +669,7 @@ def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs): self.close() return result_set - def query_masking(self, db_name=None, sql='', resultset=None): + def query_masking(self, db_name=None, sql="", resultset=None): """简单字段脱敏规则, 仅对select有效""" if re.match(r"^select", sql, re.I): filtered_result = simple_column_mask(self.instance, resultset) @@ -571,7 +678,7 @@ def query_masking(self, db_name=None, sql='', resultset=None): filtered_result = resultset return filtered_result - def execute_check(self, db_name=None, sql='', close_conn=True): + def execute_check(self, db_name=None, sql="", close_conn=True): """ 上线单执行前的检查, 返回Review set update by Jan.song 20200302 @@ -585,81 +692,118 @@ def execute_check(self, db_name=None, sql='', close_conn=True): line = 1 # 保存SQL中的新建对象 object_name_list = set() - critical_ddl_regex = config.get('critical_ddl_regex', '') + critical_ddl_regex = config.get("critical_ddl_regex", "") p = re.compile(critical_ddl_regex) check_result.syntax_type = 2 # TODO 工单类型 0、其他 1、DDL,2、DML try: sqlitemList = get_full_sqlitem_list(sql, db_name) for sqlitem in sqlitemList: - sql_lower = sqlitem.statement.lower().rstrip(';') - sql_nolower = sqlitem.statement.rstrip(';') + sql_lower = sqlitem.statement.lower().rstrip(";") + sql_nolower = sqlitem.statement.rstrip(";") # 禁用语句 if re.match(r"^select|^with|^explain", sql_lower): - result = ReviewResult(id=line, errlevel=2, - stagestatus='驳回不支持语句', - errormessage='仅支持DML和DDL语句,查询语句请使用SQL查询功能!', - sql=sqlitem.statement) + result = ReviewResult( + id=line, + errlevel=2, + stagestatus="驳回不支持语句", + errormessage="仅支持DML和DDL语句,查询语句请使用SQL查询功能!", + sql=sqlitem.statement, + ) # 高危语句 elif critical_ddl_regex and p.match(sql_lower.strip()): - result = ReviewResult(id=line, errlevel=2, - stagestatus='驳回高危SQL', - errormessage='禁止提交匹配' + critical_ddl_regex + '条件的语句!', - sql=sqlitem.statement) + result = ReviewResult( + id=line, + errlevel=2, + stagestatus="驳回高危SQL", + errormessage="禁止提交匹配" + critical_ddl_regex + "条件的语句!", + sql=sqlitem.statement, + ) # 驳回未带where数据修改语句,如确实需做全部删除或更新,显示的带上where 1=1 - elif re.match(r"^update((?!where).)*$|^delete((?!where).)*$", sql_lower): - result = ReviewResult(id=line, errlevel=2, - stagestatus='驳回未带where数据修改', - errormessage='数据修改需带where条件!', - sql=sqlitem.statement) + elif re.match( + r"^update((?!where).)*$|^delete((?!where).)*$", sql_lower + ): + result = ReviewResult( + id=line, + errlevel=2, + stagestatus="驳回未带where数据修改", + errormessage="数据修改需带where条件!", + sql=sqlitem.statement, + ) # 驳回事务控制,会话控制SQL elif re.match(r"^set|^rollback|^exit", sql_lower): - result = ReviewResult(id=line, errlevel=2, - stagestatus='SQL中不能包含^set|^rollback|^exit', - errormessage='SQL中不能包含^set|^rollback|^exit', - sql=sqlitem.statement) + result = ReviewResult( + id=line, + errlevel=2, + stagestatus="SQL中不能包含^set|^rollback|^exit", + errormessage="SQL中不能包含^set|^rollback|^exit", + sql=sqlitem.statement, + ) # 通过explain对SQL做语法语义检查 - elif re.match(explain_re, sql_lower) and sqlitem.stmt_type == 'SQL': - if self.check_create_index_table(db_name=db_name, sql=sql_lower, object_name_list=object_name_list): + elif re.match(explain_re, sql_lower) and sqlitem.stmt_type == "SQL": + if self.check_create_index_table( + db_name=db_name, + sql=sql_lower, + object_name_list=object_name_list, + ): object_name = self.get_sql_first_object_name(sql=sql_lower) - if '.' in object_name: + if "." in object_name: object_name = object_name else: object_name = f"""{db_name}.{object_name}""" object_name_list.add(object_name) - result = ReviewResult(id=line, errlevel=1, - stagestatus='WARNING:新建表的新建索引语句暂无法检测!', - errormessage='WARNING:新建表的新建索引语句暂无法检测!', - stmt_type=sqlitem.stmt_type, - object_owner=sqlitem.object_owner, - object_type=sqlitem.object_type, - object_name=sqlitem.object_name, - sql=sqlitem.statement) - elif len(object_name_list) > 0 and self.get_dml_table(db_name=db_name, sql=sql_lower, - object_name_list=object_name_list): - result = ReviewResult(id=line, errlevel=1, - stagestatus='WARNING:新建表的数据修改暂无法检测!', - errormessage='WARNING:新建表的数据修改暂无法检测!', - stmt_type=sqlitem.stmt_type, - object_owner=sqlitem.object_owner, - object_type=sqlitem.object_type, - object_name=sqlitem.object_name, - sql=sqlitem.statement) + result = ReviewResult( + id=line, + errlevel=1, + stagestatus="WARNING:新建表的新建索引语句暂无法检测!", + errormessage="WARNING:新建表的新建索引语句暂无法检测!", + stmt_type=sqlitem.stmt_type, + object_owner=sqlitem.object_owner, + object_type=sqlitem.object_type, + object_name=sqlitem.object_name, + sql=sqlitem.statement, + ) + elif len(object_name_list) > 0 and self.get_dml_table( + db_name=db_name, + sql=sql_lower, + object_name_list=object_name_list, + ): + result = ReviewResult( + id=line, + errlevel=1, + stagestatus="WARNING:新建表的数据修改暂无法检测!", + errormessage="WARNING:新建表的数据修改暂无法检测!", + stmt_type=sqlitem.stmt_type, + object_owner=sqlitem.object_owner, + object_type=sqlitem.object_type, + object_name=sqlitem.object_name, + sql=sqlitem.statement, + ) else: - result_set = self.explain_check(db_name=db_name, sql=sqlitem.statement, close_conn=False) - if result_set['msg']: - result = ReviewResult(id=line, errlevel=2, - stagestatus='explain语法检查未通过!', - errormessage=result_set['msg'], - sql=sqlitem.statement) + result_set = self.explain_check( + db_name=db_name, sql=sqlitem.statement, close_conn=False + ) + if result_set["msg"]: + result = ReviewResult( + id=line, + errlevel=2, + stagestatus="explain语法检查未通过!", + errormessage=result_set["msg"], + sql=sqlitem.statement, + ) else: # 对create table\create index\create unique index语法做对象存在性检测 - if re.match(r"^create\s+table|^create\s+index|^create\s+unique\s+index", sql_lower): - object_name = self.get_sql_first_object_name(sql=sql_nolower) + if re.match( + r"^create\s+table|^create\s+index|^create\s+unique\s+index", + sql_lower, + ): + object_name = self.get_sql_first_object_name( + sql=sql_nolower + ) # 保存create对象对后续SQL做存在性判断 - if '.' in object_name: - schema_name = object_name.split('.')[0] - object_name = object_name.split('.')[1] + if "." in object_name: + schema_name = object_name.split(".")[0] + object_name = object_name.split(".")[1] if '"' in schema_name: schema_name = schema_name if '"' not in object_name: @@ -669,72 +813,91 @@ def execute_check(self, db_name=None, sql='', close_conn=True): if '"' not in object_name: object_name = object_name.upper() else: - schema_name = ('"' + db_name + '"') + schema_name = '"' + db_name + '"' if '"' not in object_name: object_name = object_name.upper() object_name = f"""{schema_name}.{object_name}""" - if self.object_name_check(db_name=db_name, - object_name=object_name) or object_name in object_name_list: - result = ReviewResult(id=line, errlevel=2, - stagestatus=f"""{object_name}对象已经存在!""", - errormessage=f"""{object_name}对象已经存在!""", - sql=sqlitem.statement) + if ( + self.object_name_check( + db_name=db_name, object_name=object_name + ) + or object_name in object_name_list + ): + result = ReviewResult( + id=line, + errlevel=2, + stagestatus=f"""{object_name}对象已经存在!""", + errormessage=f"""{object_name}对象已经存在!""", + sql=sqlitem.statement, + ) else: object_name_list.add(object_name) - if result_set['rows'] > 1000: - result = ReviewResult(id=line, errlevel=1, - stagestatus='影响行数大于1000,请关注', - errormessage='影响行数大于1000,请关注', - sql=sqlitem.statement, - stmt_type=sqlitem.stmt_type, - object_owner=sqlitem.object_owner, - object_type=sqlitem.object_type, - object_name=sqlitem.object_name, - affected_rows=result_set['rows'], - execute_time=0, ) + if result_set["rows"] > 1000: + result = ReviewResult( + id=line, + errlevel=1, + stagestatus="影响行数大于1000,请关注", + errormessage="影响行数大于1000,请关注", + sql=sqlitem.statement, + stmt_type=sqlitem.stmt_type, + object_owner=sqlitem.object_owner, + object_type=sqlitem.object_type, + object_name=sqlitem.object_name, + affected_rows=result_set["rows"], + execute_time=0, + ) else: - result = ReviewResult(id=line, errlevel=0, - stagestatus='Audit completed', - errormessage='None', - sql=sqlitem.statement, - stmt_type=sqlitem.stmt_type, - object_owner=sqlitem.object_owner, - object_type=sqlitem.object_type, - object_name=sqlitem.object_name, - affected_rows=result_set['rows'], - execute_time=0, ) + result = ReviewResult( + id=line, + errlevel=0, + stagestatus="Audit completed", + errormessage="None", + sql=sqlitem.statement, + stmt_type=sqlitem.stmt_type, + object_owner=sqlitem.object_owner, + object_type=sqlitem.object_type, + object_name=sqlitem.object_name, + affected_rows=result_set["rows"], + execute_time=0, + ) else: - if result_set['rows'] > 1000: - result = ReviewResult(id=line, errlevel=1, - stagestatus='影响行数大于1000,请关注', - errormessage='影响行数大于1000,请关注', - sql=sqlitem.statement, - stmt_type=sqlitem.stmt_type, - object_owner=sqlitem.object_owner, - object_type=sqlitem.object_type, - object_name=sqlitem.object_name, - affected_rows=result_set['rows'], - execute_time=0, ) + if result_set["rows"] > 1000: + result = ReviewResult( + id=line, + errlevel=1, + stagestatus="影响行数大于1000,请关注", + errormessage="影响行数大于1000,请关注", + sql=sqlitem.statement, + stmt_type=sqlitem.stmt_type, + object_owner=sqlitem.object_owner, + object_type=sqlitem.object_type, + object_name=sqlitem.object_name, + affected_rows=result_set["rows"], + execute_time=0, + ) else: - result = ReviewResult(id=line, errlevel=0, - stagestatus='Audit completed', - errormessage='None', - sql=sqlitem.statement, - stmt_type=sqlitem.stmt_type, - object_owner=sqlitem.object_owner, - object_type=sqlitem.object_type, - object_name=sqlitem.object_name, - affected_rows=result_set['rows'], - execute_time=0, ) + result = ReviewResult( + id=line, + errlevel=0, + stagestatus="Audit completed", + errormessage="None", + sql=sqlitem.statement, + stmt_type=sqlitem.stmt_type, + object_owner=sqlitem.object_owner, + object_type=sqlitem.object_type, + object_name=sqlitem.object_name, + affected_rows=result_set["rows"], + execute_time=0, + ) # 其它无法用explain判断的语句 else: # 对alter table做对象存在性检查 if re.match(r"^alter\s+table\s", sql_lower): object_name = self.get_sql_first_object_name(sql=sql_nolower) - if '.' in object_name: - schema_name = object_name.split('.')[0] - object_name = object_name.split('.')[1] + if "." in object_name: + schema_name = object_name.split(".")[0] + object_name = object_name.split(".")[1] if '"' in schema_name: schema_name = schema_name if '"' not in object_name: @@ -744,34 +907,44 @@ def execute_check(self, db_name=None, sql='', close_conn=True): if '"' not in object_name: object_name = object_name.upper() else: - schema_name = ('"' + db_name + '"') + schema_name = '"' + db_name + '"' if '"' not in object_name: object_name = object_name.upper() object_name = f"""{schema_name}.{object_name}""" - if not self.object_name_check(db_name=db_name, - object_name=object_name) and object_name not in object_name_list: - result = ReviewResult(id=line, errlevel=2, - stagestatus=f"""{object_name}对象不存在!""", - errormessage=f"""{object_name}对象不存在!""", - sql=sqlitem.statement) + if ( + not self.object_name_check( + db_name=db_name, object_name=object_name + ) + and object_name not in object_name_list + ): + result = ReviewResult( + id=line, + errlevel=2, + stagestatus=f"""{object_name}对象不存在!""", + errormessage=f"""{object_name}对象不存在!""", + sql=sqlitem.statement, + ) else: - result = ReviewResult(id=line, errlevel=1, - stagestatus='当前平台,此语法不支持审核!', - errormessage='当前平台,此语法不支持审核!', - sql=sqlitem.statement, - stmt_type=sqlitem.stmt_type, - object_owner=sqlitem.object_owner, - object_type=sqlitem.object_type, - object_name=sqlitem.object_name, - affected_rows=0, - execute_time=0, ) + result = ReviewResult( + id=line, + errlevel=1, + stagestatus="当前平台,此语法不支持审核!", + errormessage="当前平台,此语法不支持审核!", + sql=sqlitem.statement, + stmt_type=sqlitem.stmt_type, + object_owner=sqlitem.object_owner, + object_type=sqlitem.object_type, + object_name=sqlitem.object_name, + affected_rows=0, + execute_time=0, + ) # 对create做对象存在性检查 elif re.match(r"^create", sql_lower): object_name = self.get_sql_first_object_name(sql=sql_nolower) - if '.' in object_name: - schema_name = object_name.split('.')[0] - object_name = object_name.split('.')[1] + if "." in object_name: + schema_name = object_name.split(".")[0] + object_name = object_name.split(".")[1] if '"' in schema_name: schema_name = schema_name if '"' not in object_name: @@ -781,47 +954,62 @@ def execute_check(self, db_name=None, sql='', close_conn=True): if '"' not in object_name: object_name = object_name.upper() else: - schema_name = ('"' + db_name + '"') + schema_name = '"' + db_name + '"' if '"' not in object_name: object_name = object_name.upper() object_name = f"""{schema_name}.{object_name}""" - if self.object_name_check(db_name=db_name, - object_name=object_name) or object_name in object_name_list: - result = ReviewResult(id=line, errlevel=2, - stagestatus=f"""{object_name}对象已经存在!""", - errormessage=f"""{object_name}对象已经存在!""", - sql=sqlitem.statement) + if ( + self.object_name_check( + db_name=db_name, object_name=object_name + ) + or object_name in object_name_list + ): + result = ReviewResult( + id=line, + errlevel=2, + stagestatus=f"""{object_name}对象已经存在!""", + errormessage=f"""{object_name}对象已经存在!""", + sql=sqlitem.statement, + ) else: object_name_list.add(object_name) - result = ReviewResult(id=line, errlevel=1, - stagestatus='当前平台,此语法不支持审核!', - errormessage='当前平台,此语法不支持审核!', - sql=sqlitem.statement, - stmt_type=sqlitem.stmt_type, - object_owner=sqlitem.object_owner, - object_type=sqlitem.object_type, - object_name=sqlitem.object_name, - affected_rows=0, - execute_time=0, ) + result = ReviewResult( + id=line, + errlevel=1, + stagestatus="当前平台,此语法不支持审核!", + errormessage="当前平台,此语法不支持审核!", + sql=sqlitem.statement, + stmt_type=sqlitem.stmt_type, + object_owner=sqlitem.object_owner, + object_type=sqlitem.object_type, + object_name=sqlitem.object_name, + affected_rows=0, + execute_time=0, + ) else: - result = ReviewResult(id=line, errlevel=1, - stagestatus='当前平台,此语法不支持审核!', - errormessage='当前平台,此语法不支持审核!', - sql=sqlitem.statement, - stmt_type=sqlitem.stmt_type, - object_owner=sqlitem.object_owner, - object_type=sqlitem.object_type, - object_name=sqlitem.object_name, - affected_rows=0, - execute_time=0, ) + result = ReviewResult( + id=line, + errlevel=1, + stagestatus="当前平台,此语法不支持审核!", + errormessage="当前平台,此语法不支持审核!", + sql=sqlitem.statement, + stmt_type=sqlitem.stmt_type, + object_owner=sqlitem.object_owner, + object_type=sqlitem.object_type, + object_name=sqlitem.object_name, + affected_rows=0, + execute_time=0, + ) # 判断工单类型 - if get_syntax_type(sql=sqlitem.statement, db_type='oracle') == 'DDL': + if get_syntax_type(sql=sqlitem.statement, db_type="oracle") == "DDL": check_result.syntax_type = 1 check_result.rows += [result] line += 1 except Exception as e: - logger.warning(f"Oracle 语句执行报错,第{line}个SQL:{sqlitem.statement},错误信息{traceback.format_exc()}") + logger.warning( + f"Oracle 语句执行报错,第{line}个SQL:{sqlitem.statement},错误信息{traceback.format_exc()}" + ) check_result.error = str(e) finally: if close_conn: @@ -836,9 +1024,9 @@ def execute_check(self, db_name=None, sql='', close_conn=True): def execute_workflow(self, workflow, close_conn=True): """执行上线单,返回Review set - 原来的逻辑是根据 sql_content简单来分割SQL,进而再执行这些SQL - 新的逻辑变更为根据审核结果中记录的sql来执行, - 如果是PLSQL存储过程等对象定义操作,还需检查确认新建对象是否编译通过! + 原来的逻辑是根据 sql_content简单来分割SQL,进而再执行这些SQL + 新的逻辑变更为根据审核结果中记录的sql来执行, + 如果是PLSQL存储过程等对象定义操作,还需检查确认新建对象是否编译通过! """ review_content = workflow.sqlworkflowcontent.review_content review_result = json.loads(review_content) @@ -861,7 +1049,7 @@ def execute_workflow(self, workflow, close_conn=True): for sqlitem in sqlitemList: statement = sqlitem.statement if sqlitem.stmt_type == "SQL": - statement = statement.rstrip(';') + statement = statement.rstrip(";") # 如果是DDL的工单,获取对象的原定义,并保存到sql_rollback.undo_sql # 需要授权 grant execute on dbms_metadata to xxxxx if workflow.syntax_type == 1: @@ -874,13 +1062,18 @@ def execute_workflow(self, workflow, close_conn=True): metdata_back_flag = self.metdata_backup(workflow, cursor, statement) with FuncTimer() as t: - if statement != '': + if statement != "": cursor.execute(statement) conn.commit() rowcount = cursor.rowcount stagestatus = "Execute Successfully" - if sqlitem.stmt_type == "PLSQL" and sqlitem.object_name and sqlitem.object_name != 'ANONYMOUS' and sqlitem.object_name != '': + if ( + sqlitem.stmt_type == "PLSQL" + and sqlitem.object_name + and sqlitem.object_name != "ANONYMOUS" + and sqlitem.object_name != "" + ): query_obj_sql = f"""SELECT OBJECT_NAME, STATUS, TO_CHAR(LAST_DDL_TIME, 'YYYY-MM-DD HH24:MI:SS') FROM ALL_OBJECTS WHERE OWNER = '{sqlitem.object_owner}' AND OBJECT_NAME = '{sqlitem.object_name}' @@ -890,49 +1083,69 @@ def execute_workflow(self, workflow, close_conn=True): if row: status = row[1] if status and status == "INVALID": - stagestatus = "Compile Failed. Object " + sqlitem.object_owner + "." + sqlitem.object_name + " is invalid." + stagestatus = ( + "Compile Failed. Object " + + sqlitem.object_owner + + "." + + sqlitem.object_name + + " is invalid." + ) else: - stagestatus = "Compile Failed. Object " + sqlitem.object_owner + "." + sqlitem.object_name + " doesn't exist." + stagestatus = ( + "Compile Failed. Object " + + sqlitem.object_owner + + "." + + sqlitem.object_name + + " doesn't exist." + ) if stagestatus != "Execute Successfully": raise Exception(stagestatus) - execute_result.rows.append(ReviewResult( - id=line, - errlevel=0, - stagestatus=stagestatus, - errormessage='None', - sql=statement, - affected_rows=cursor.rowcount, - execute_time=t.cost, - )) + execute_result.rows.append( + ReviewResult( + id=line, + errlevel=0, + stagestatus=stagestatus, + errormessage="None", + sql=statement, + affected_rows=cursor.rowcount, + execute_time=t.cost, + ) + ) line += 1 except Exception as e: - logger.warning(f"Oracle命令执行报错,工单id:{workflow.id},语句:{statement or sql}, 错误信息:{traceback.format_exc()}") + logger.warning( + f"Oracle命令执行报错,工单id:{workflow.id},语句:{statement or sql}, 错误信息:{traceback.format_exc()}" + ) execute_result.error = str(e) # conn.rollback() # 追加当前报错语句信息到执行结果中 - execute_result.rows.append(ReviewResult( - id=line, - errlevel=2, - stagestatus='Execute Failed', - errormessage=f'异常信息:{e}', - sql=statement or sql, - affected_rows=0, - execute_time=0, - )) - line += 1 - # 报错语句后面的语句标记为审核通过、未执行,追加到执行结果中 - for sqlitem in sqlitemList[line - 1:]: - execute_result.rows.append(ReviewResult( + execute_result.rows.append( + ReviewResult( id=line, - errlevel=0, - stagestatus='Audit completed', - errormessage=f'前序语句失败, 未执行', - sql=sqlitem.statement, + errlevel=2, + stagestatus="Execute Failed", + errormessage=f"异常信息:{e}", + sql=statement or sql, affected_rows=0, execute_time=0, - )) + ) + ) + line += 1 + # 报错语句后面的语句标记为审核通过、未执行,追加到执行结果中 + for sqlitem in sqlitemList[line - 1 :]: + execute_result.rows.append( + ReviewResult( + id=line, + errlevel=0, + stagestatus="Audit completed", + errormessage=f"前序语句失败, 未执行", + sql=sqlitem.statement, + affected_rows=0, + execute_time=0, + ) + ) line += 1 finally: # 备份 @@ -941,9 +1154,16 @@ def execute_workflow(self, workflow, close_conn=True): cursor.execute(f"select sysdate from dual") rows = cursor.fetchone() end_time = rows[0] - self.backup(workflow, cursor=cursor, begin_time=begin_time, end_time=end_time) + self.backup( + workflow, + cursor=cursor, + begin_time=begin_time, + end_time=end_time, + ) except Exception as e: - logger.error(f"Oracle工单备份异常,工单id:{workflow.id}, 错误信息:{traceback.format_exc()}") + logger.error( + f"Oracle工单备份异常,工单id:{workflow.id}, 错误信息:{traceback.format_exc()}" + ) if close_conn: self.close() return execute_result @@ -967,32 +1187,34 @@ def backup(self, workflow, cursor, begin_time, end_time): backup_cursor = conn.cursor() backup_cursor.execute(f"""create database if not exists ora_backup;""") backup_cursor.execute(f"use ora_backup;") - backup_cursor.execute(f"""CREATE TABLE if not exists `sql_rollback` ( + backup_cursor.execute( + f"""CREATE TABLE if not exists `sql_rollback` ( `id` bigint(20) NOT NULL AUTO_INCREMENT, `redo_sql` mediumtext, `undo_sql` mediumtext, `workflow_id` bigint(20) NOT NULL, PRIMARY KEY (`id`), key `idx_sql_rollback_01` (`workflow_id`) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;""") + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;""" + ) # 使用logminer抓取回滚SQL - logmnr_start_sql = f'''begin + logmnr_start_sql = f"""begin dbms_logmnr.start_logmnr( starttime=>to_date('{begin_time}','yyyy-mm-dd hh24:mi:ss'), endtime=>to_date('{end_time}','yyyy/mm/dd hh24:mi:ss'), options=>dbms_logmnr.dict_from_online_catalog + dbms_logmnr.continuous_mine); - end;''' - undo_sql = f'''select + end;""" + undo_sql = f"""select xmlagg(xmlparse(content sql_redo wellformed) order by scn,rs_id,ssn,rownum).getclobval() , xmlagg(xmlparse(content sql_undo wellformed) order by scn,rs_id,ssn,rownum).getclobval() from v$logmnr_contents where SEG_OWNER not in ('SYS') and session# = (select sid from v$mystat where rownum = 1) and serial# = (select serial# from v$session s where s.sid = (select sid from v$mystat where rownum = 1 )) - group by scn,rs_id,ssn order by scn desc''' - logmnr_end_sql = f'''begin + group by scn,rs_id,ssn order by scn desc""" + logmnr_end_sql = f"""begin dbms_logmnr.end_logmnr; - end;''' + end;""" cursor.execute(logmnr_start_sql) cursor.execute(undo_sql) rows = cursor.fetchall() @@ -1002,7 +1224,7 @@ def backup(self, workflow, cursor, begin_time, end_time): redo_sql = f"{row[0]}" redo_sql = redo_sql.replace("'", "\\'") if row[1] is None: - undo_sql = f' ' + undo_sql = f" " else: undo_sql = f"{row[1]}" undo_sql = undo_sql.replace("'", "\\'") @@ -1032,19 +1254,21 @@ def metdata_backup(self, workflow, cursor, redo_sql): backup_cursor = conn.cursor() backup_cursor.execute(f"""create database if not exists ora_backup;""") backup_cursor.execute(f"use ora_backup;") - backup_cursor.execute(f"""CREATE TABLE if not exists `sql_rollback` ( + backup_cursor.execute( + f"""CREATE TABLE if not exists `sql_rollback` ( `id` bigint(20) NOT NULL AUTO_INCREMENT, `redo_sql` mediumtext, `undo_sql` mediumtext, `workflow_id` bigint(20) NOT NULL, PRIMARY KEY (`id`), key `idx_sql_rollback_01` (`workflow_id`) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;""") + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;""" + ) rows = cursor.fetchall() if len(rows) > 0: for row in rows: if row[0] is None: - undo_sql = f' ' + undo_sql = f" " else: undo_sql = f"{row[0]}" undo_sql = undo_sql.replace("'", "\\'") @@ -1090,7 +1314,7 @@ def get_rollback(self, workflow): conn.close() return list_backup_sql - def sqltuningadvisor(self, db_name=None, sql='', close_conn=True, **kwargs): + def sqltuningadvisor(self, db_name=None, sql="", close_conn=True, **kwargs): """ add by Jan.song 20200421 使用DBMS_SQLTUNE包做sql tuning支持 @@ -1098,14 +1322,14 @@ def sqltuningadvisor(self, db_name=None, sql='', close_conn=True, **kwargs): 返回 ResultSet """ result_set = ResultSet(full_sql=sql) - task_name = 'sqlaudit' + f'''{threading.currentThread().ident}''' + task_name = "sqlaudit" + f"""{threading.currentThread().ident}""" task_begin = 0 try: conn = self.get_connection() cursor = conn.cursor() - sql = sql.rstrip(';') + sql = sql.rstrip(";") # 创建分析任务 - create_task_sql = f'''DECLARE + create_task_sql = f"""DECLARE my_task_name VARCHAR2(30); my_sqltext CLOB; BEGIN @@ -1118,15 +1342,20 @@ def sqltuningadvisor(self, db_name=None, sql='', close_conn=True, **kwargs): task_name => '{task_name}', description => 'tuning'); DBMS_SQLTUNE.EXECUTE_TUNING_TASK( task_name => '{task_name}'); - END;''' + END;""" task_begin = 1 cursor.execute(create_task_sql) # 获取分析报告 - get_task_sql = f'''select DBMS_SQLTUNE.REPORT_TUNING_TASK( '{task_name}') from dual''' + get_task_sql = ( + f"""select DBMS_SQLTUNE.REPORT_TUNING_TASK( '{task_name}') from dual""" + ) cursor.execute(get_task_sql) fields = cursor.description if any(x[1] == cx_Oracle.CLOB for x in fields): - rows = [tuple([(c.read() if type(c) == cx_Oracle.LOB else c) for c in r]) for r in cursor] + rows = [ + tuple([(c.read() if type(c) == cx_Oracle.LOB else c) for c in r]) + for r in cursor + ] else: rows = cursor.fetchall() result_set.column_list = [i[0] for i in fields] if fields else [] @@ -1138,10 +1367,10 @@ def sqltuningadvisor(self, db_name=None, sql='', close_conn=True, **kwargs): finally: # 结束分析任务 if task_begin == 1: - end_sql = f'''DECLARE + end_sql = f"""DECLARE begin dbms_sqltune.drop_tuning_task('{task_name}'); - end;''' + end;""" cursor.execute(end_sql) if close_conn: self.close() diff --git a/sql/engines/pgsql.py b/sql/engines/pgsql.py index 9b4de94bf7..0146469743 100644 --- a/sql/engines/pgsql.py +++ b/sql/engines/pgsql.py @@ -18,29 +18,35 @@ from .models import ResultSet, ReviewSet, ReviewResult from sql.utils.data_masking import simple_column_mask -__author__ = 'hhyo、yyukai' +__author__ = "hhyo、yyukai" -logger = logging.getLogger('default') +logger = logging.getLogger("default") class PgSQLEngine(EngineBase): test_query = "SELECT 1" def get_connection(self, db_name=None): - db_name = db_name or self.db_name or 'postgres' + db_name = db_name or self.db_name or "postgres" if self.conn: return self.conn - self.conn = psycopg2.connect(host=self.host, port=self.port, user=self.user, - password=self.password, dbname=db_name, connect_timeout=10) + self.conn = psycopg2.connect( + host=self.host, + port=self.port, + user=self.user, + password=self.password, + dbname=db_name, + connect_timeout=10, + ) return self.conn @property def name(self): - return 'PgSQL' + return "PgSQL" @property def info(self): - return 'PgSQL engine' + return "PgSQL engine" def get_all_databases(self): """ @@ -48,7 +54,11 @@ def get_all_databases(self): :return: """ result = self.query(sql=f"SELECT datname FROM pg_database;") - db_list = [row[0] for row in result.rows if row[0] not in ['postgres', 'template0', 'template1']] + db_list = [ + row[0] + for row in result.rows + if row[0] not in ["postgres", "template0", "template1"] + ] result.rows = db_list return result @@ -57,10 +67,21 @@ def get_all_schemas(self, db_name, **kwargs): 获取模式列表 :return: """ - result = self.query(db_name=db_name, sql=f"select schema_name from information_schema.schemata;") - schema_list = [row[0] for row in result.rows if row[0] not in ['information_schema', - 'pg_catalog', 'pg_toast_temp_1', - 'pg_temp_1', 'pg_toast']] + result = self.query( + db_name=db_name, sql=f"select schema_name from information_schema.schemata;" + ) + schema_list = [ + row[0] + for row in result.rows + if row[0] + not in [ + "information_schema", + "pg_catalog", + "pg_toast_temp_1", + "pg_temp_1", + "pg_toast", + ] + ] result.rows = schema_list return result @@ -71,12 +92,12 @@ def get_all_tables(self, db_name, **kwargs): :param schema_name: :return: """ - schema_name = kwargs.get('schema_name') + schema_name = kwargs.get("schema_name") sql = f"""SELECT table_name FROM information_schema.tables where table_schema ='{schema_name}';""" result = self.query(db_name=db_name, sql=sql) - tb_list = [row[0] for row in result.rows if row[0] not in ['test']] + tb_list = [row[0] for row in result.rows if row[0] not in ["test"]] result.rows = tb_list return result @@ -88,7 +109,7 @@ def get_all_columns_by_tb(self, db_name, tb_name, **kwargs): :param schema_name: :return: """ - schema_name = kwargs.get('schema_name') + schema_name = kwargs.get("schema_name") sql = f"""SELECT column_name FROM information_schema.columns where table_name='{tb_name}' @@ -106,7 +127,7 @@ def describe_table(self, db_name, tb_name, **kwargs): :param schema_name: :return: """ - schema_name = kwargs.get('schema_name') + schema_name = kwargs.get("schema_name") sql = f"""select col.column_name, col.data_type, @@ -126,32 +147,32 @@ def describe_table(self, db_name, tb_name, **kwargs): result = self.query(db_name=db_name, schema_name=schema_name, sql=sql) return result - def query_check(self, db_name=None, sql=''): + def query_check(self, db_name=None, sql=""): # 查询语句的检查、注释去除、切分 - result = {'msg': '', 'bad_query': False, 'filtered_sql': sql, 'has_star': False} + result = {"msg": "", "bad_query": False, "filtered_sql": sql, "has_star": False} # 删除注释语句,进行语法判断,执行第一条有效sql try: sql = sqlparse.format(sql, strip_comments=True) sql = sqlparse.split(sql)[0] - result['filtered_sql'] = sql.strip() + result["filtered_sql"] = sql.strip() except IndexError: - result['bad_query'] = True - result['msg'] = '没有有效的SQL语句' + result["bad_query"] = True + result["msg"] = "没有有效的SQL语句" if re.match(r"^select", sql, re.I) is None: - result['bad_query'] = True - result['msg'] = '不支持的查询语法类型!' - if '*' in sql: - result['has_star'] = True - result['msg'] = 'SQL语句中含有 * ' + result["bad_query"] = True + result["msg"] = "不支持的查询语法类型!" + if "*" in sql: + result["has_star"] = True + result["msg"] = "SQL语句中含有 * " return result - def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs): - """返回 ResultSet """ - schema_name = kwargs.get('schema_name') + def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs): + """返回 ResultSet""" + schema_name = kwargs.get("schema_name") result_set = ResultSet(full_sql=sql) try: conn = self.get_connection(db_name=db_name) - max_execution_time = kwargs.get('max_execution_time', 0) + max_execution_time = kwargs.get("max_execution_time", 0) cursor = conn.cursor() try: cursor.execute(f"SET statement_timeout TO {max_execution_time};") @@ -178,16 +199,16 @@ def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs): self.close() return result_set - def filter_sql(self, sql='', limit_num=0): + def filter_sql(self, sql="", limit_num=0): # 对查询sql增加limit限制,# TODO limit改写待优化 - sql_lower = sql.lower().rstrip(';').strip() + sql_lower = sql.lower().rstrip(";").strip() if re.match(r"^select", sql_lower): if re.search(r"limit\s+(\d+)$", sql_lower) is None: if re.search(r"limit\s+\d+\s*,\s*(\d+)$", sql_lower) is None: return f"{sql.rstrip(';')} limit {limit_num};" return f"{sql.rstrip(';')};" - def query_masking(self, db_name=None, sql='', resultset=None): + def query_masking(self, db_name=None, sql="", resultset=None): """简单字段脱敏规则, 仅对select有效""" if re.match(r"^select", sql, re.I): filtered_result = simple_column_mask(self.instance, resultset) @@ -196,40 +217,49 @@ def query_masking(self, db_name=None, sql='', resultset=None): filtered_result = resultset return filtered_result - def execute_check(self, db_name=None, sql=''): + def execute_check(self, db_name=None, sql=""): """上线单执行前的检查, 返回Review set""" config = SysConfig() check_result = ReviewSet(full_sql=sql) # 禁用/高危语句检查 line = 1 - critical_ddl_regex = config.get('critical_ddl_regex', '') + critical_ddl_regex = config.get("critical_ddl_regex", "") p = re.compile(critical_ddl_regex) check_result.syntax_type = 2 # TODO 工单类型 0、其他 1、DDL,2、DML for statement in sqlparse.split(sql): statement = sqlparse.format(statement, strip_comments=True) # 禁用语句 if re.match(r"^select", statement.lower()): - result = ReviewResult(id=line, errlevel=2, - stagestatus='驳回不支持语句', - errormessage='仅支持DML和DDL语句,查询语句请使用SQL查询功能!', - sql=statement) + result = ReviewResult( + id=line, + errlevel=2, + stagestatus="驳回不支持语句", + errormessage="仅支持DML和DDL语句,查询语句请使用SQL查询功能!", + sql=statement, + ) # 高危语句 elif critical_ddl_regex and p.match(statement.strip().lower()): - result = ReviewResult(id=line, errlevel=2, - stagestatus='驳回高危SQL', - errormessage='禁止提交匹配' + critical_ddl_regex + '条件的语句!', - sql=statement) + result = ReviewResult( + id=line, + errlevel=2, + stagestatus="驳回高危SQL", + errormessage="禁止提交匹配" + critical_ddl_regex + "条件的语句!", + sql=statement, + ) # 正常语句 else: - result = ReviewResult(id=line, errlevel=0, - stagestatus='Audit completed', - errormessage='None', - sql=statement, - affected_rows=0, - execute_time=0, ) + result = ReviewResult( + id=line, + errlevel=0, + stagestatus="Audit completed", + errormessage="None", + sql=statement, + affected_rows=0, + execute_time=0, + ) # 判断工单类型 - if get_syntax_type(statement) == 'DDL': + if get_syntax_type(statement) == "DDL": check_result.syntax_type = 1 check_result.rows += [result] line += 1 @@ -256,45 +286,53 @@ def execute_workflow(self, workflow, close_conn=True): cursor = conn.cursor() # 逐条执行切分语句,追加到执行结果中 for statement in split_sql: - statement = statement.rstrip(';') + statement = statement.rstrip(";") with FuncTimer() as t: cursor.execute(statement) conn.commit() - execute_result.rows.append(ReviewResult( - id=line, - errlevel=0, - stagestatus='Execute Successfully', - errormessage='None', - sql=statement, - affected_rows=cursor.rowcount, - execute_time=t.cost, - )) + execute_result.rows.append( + ReviewResult( + id=line, + errlevel=0, + stagestatus="Execute Successfully", + errormessage="None", + sql=statement, + affected_rows=cursor.rowcount, + execute_time=t.cost, + ) + ) line += 1 except Exception as e: - logger.warning(f"PGSQL命令执行报错,语句:{statement or sql}, 错误信息:{traceback.format_exc()}") + logger.warning( + f"PGSQL命令执行报错,语句:{statement or sql}, 错误信息:{traceback.format_exc()}" + ) execute_result.error = str(e) # 追加当前报错语句信息到执行结果中 - execute_result.rows.append(ReviewResult( - id=line, - errlevel=2, - stagestatus='Execute Failed', - errormessage=f'异常信息:{e}', - sql=statement or sql, - affected_rows=0, - execute_time=0, - )) - line += 1 - # 报错语句后面的语句标记为审核通过、未执行,追加到执行结果中 - for statement in split_sql[line - 1:]: - execute_result.rows.append(ReviewResult( + execute_result.rows.append( + ReviewResult( id=line, - errlevel=0, - stagestatus='Audit completed', - errormessage=f'前序语句失败, 未执行', - sql=statement, + errlevel=2, + stagestatus="Execute Failed", + errormessage=f"异常信息:{e}", + sql=statement or sql, affected_rows=0, execute_time=0, - )) + ) + ) + line += 1 + # 报错语句后面的语句标记为审核通过、未执行,追加到执行结果中 + for statement in split_sql[line - 1 :]: + execute_result.rows.append( + ReviewResult( + id=line, + errlevel=0, + stagestatus="Audit completed", + errormessage=f"前序语句失败, 未执行", + sql=statement, + affected_rows=0, + execute_time=0, + ) + ) line += 1 finally: if close_conn: diff --git a/sql/engines/phoenix.py b/sql/engines/phoenix.py index 8da6e7f025..c68c911d19 100644 --- a/sql/engines/phoenix.py +++ b/sql/engines/phoenix.py @@ -8,7 +8,7 @@ from . import EngineBase from .models import ResultSet, ReviewSet, ReviewResult -logger = logging.getLogger('default') +logger = logging.getLogger("default") class PhoenixEngine(EngineBase): @@ -18,7 +18,7 @@ def get_connection(self, db_name=None): if self.conn: return self.conn - database_url = f'http://{self.host}:{self.port}/' + database_url = f"http://{self.host}:{self.port}/" self.conn = phoenixdb.connect(database_url, autocommit=True) return self.conn @@ -50,43 +50,45 @@ def describe_table(self, db_name, tb_name, **kwargs): result = self.query(sql=sql) return result - def query_check(self, db_name=None, sql=''): + def query_check(self, db_name=None, sql=""): # 查询语句的检查、注释去除、切分 - result = {'msg': '', 'bad_query': False, 'filtered_sql': sql, 'has_star': False} - keyword_warning = '' - sql_whitelist = ['select', 'explain'] + result = {"msg": "", "bad_query": False, "filtered_sql": sql, "has_star": False} + keyword_warning = "" + sql_whitelist = ["select", "explain"] # 根据白名单list拼接pattern语句 whitelist_pattern = "^" + "|^".join(sql_whitelist) # 删除注释语句,进行语法判断,执行第一条有效sql try: sql = sql.format(sql, strip_comments=True) sql = sqlparse.split(sql)[0] - result['filtered_sql'] = sql.strip() + result["filtered_sql"] = sql.strip() # sql_lower = sql.lower() except IndexError: - result['bad_query'] = True - result['msg'] = '没有有效的SQL语句' + result["bad_query"] = True + result["msg"] = "没有有效的SQL语句" return result if re.match(whitelist_pattern, sql) is None: - result['bad_query'] = True - result['msg'] = '仅支持{}语法!'.format(','.join(sql_whitelist)) + result["bad_query"] = True + result["msg"] = "仅支持{}语法!".format(",".join(sql_whitelist)) return result - if result.get('bad_query'): - result['msg'] = keyword_warning + if result.get("bad_query"): + result["msg"] = keyword_warning return result - def filter_sql(self, sql='', limit_num=0): + def filter_sql(self, sql="", limit_num=0): """检查是SELECT语句否添加了limit限制关键词""" - sql = sql.rstrip(';').strip() + sql = sql.rstrip(";").strip() if re.match(r"^select", sql, re.I): - if not re.compile(r'limit\s+(\d+)\s*((,|offset)\s*\d+)?\s*$', re.I).search(sql): - sql = f'{sql} limit {limit_num}' + if not re.compile(r"limit\s+(\d+)\s*((,|offset)\s*\d+)?\s*$", re.I).search( + sql + ): + sql = f"{sql} limit {limit_num}" else: - sql = f'{sql};' + sql = f"{sql};" return sql.strip() - def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs): - """返回 ResultSet """ + def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs): + """返回 ResultSet""" result_set = ResultSet(full_sql=sql) try: conn = self.get_connection() @@ -109,34 +111,38 @@ def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs): self.close() return result_set - def query_masking(self, db_name=None, sql='', resultset=None): + def query_masking(self, db_name=None, sql="", resultset=None): """传入 sql语句, db名, 结果集, 返回一个脱敏后的结果集""" return resultset - def execute_check(self, db_name=None, sql=''): + def execute_check(self, db_name=None, sql=""): """上线单执行前的检查, 返回Review set""" check_result = ReviewSet(full_sql=sql) # 切分语句,追加到检测结果中,默认全部检测通过 rowid = 1 split_sql = sqlparse.split(sql) for statement in split_sql: - check_result.rows.append(ReviewResult( - id=rowid, - errlevel=0, - stagestatus='Audit completed', - errormessage='None', - sql=statement, - affected_rows=0, - execute_time=0, ) + check_result.rows.append( + ReviewResult( + id=rowid, + errlevel=0, + stagestatus="Audit completed", + errormessage="None", + sql=statement, + affected_rows=0, + execute_time=0, + ) ) rowid += 1 return check_result def execute_workflow(self, workflow): """PhoenixDB无需备份""" - return self.execute(db_name=workflow.db_name, sql=workflow.sqlworkflowcontent.sql_content) + return self.execute( + db_name=workflow.db_name, sql=workflow.sqlworkflowcontent.sql_content + ) - def execute(self, db_name=None, sql='', close_conn=True): + def execute(self, db_name=None, sql="", close_conn=True): """原生执行语句""" execute_result = ReviewSet(full_sql=sql) conn = self.get_connection(db_name=db_name) @@ -149,39 +155,45 @@ def execute(self, db_name=None, sql='', close_conn=True): except Exception as e: logger.error(f"Phoenix命令执行报错,语句:{sql}, 错误信息:{traceback.format_exc()}") execute_result.error = str(e) - execute_result.rows.append(ReviewResult( - id=rowid, - errlevel=2, - stagestatus='Execute Failed', - errormessage=f'异常信息:{e}', - sql=statement, - affected_rows=0, - execute_time=0, - )) + execute_result.rows.append( + ReviewResult( + id=rowid, + errlevel=2, + stagestatus="Execute Failed", + errormessage=f"异常信息:{e}", + sql=statement, + affected_rows=0, + execute_time=0, + ) + ) break else: - execute_result.rows.append(ReviewResult( - id=rowid, - errlevel=0, - stagestatus='Execute Successfully', - errormessage='None', - sql=statement, - affected_rows=cursor.rowcount, - execute_time=0, - )) + execute_result.rows.append( + ReviewResult( + id=rowid, + errlevel=0, + stagestatus="Execute Successfully", + errormessage="None", + sql=statement, + affected_rows=cursor.rowcount, + execute_time=0, + ) + ) rowid += 1 if execute_result.error: # 如果失败, 将剩下的部分加入结果集返回 for statement in split_sql[rowid:]: - execute_result.rows.append(ReviewResult( - id=rowid, - errlevel=2, - stagestatus='Execute Failed', - errormessage=f'前序语句失败, 未执行', - sql=statement, - affected_rows=0, - execute_time=0, - )) + execute_result.rows.append( + ReviewResult( + id=rowid, + errlevel=2, + stagestatus="Execute Failed", + errormessage=f"前序语句失败, 未执行", + sql=statement, + affected_rows=0, + execute_time=0, + ) + ) rowid += 1 if close_conn: diff --git a/sql/engines/redis.py b/sql/engines/redis.py index af84d76779..eda4779755 100644 --- a/sql/engines/redis.py +++ b/sql/engines/redis.py @@ -17,29 +17,41 @@ from . import EngineBase from .models import ResultSet, ReviewSet, ReviewResult -__author__ = 'hhyo' +__author__ = "hhyo" -logger = logging.getLogger('default') +logger = logging.getLogger("default") class RedisEngine(EngineBase): def get_connection(self, db_name=None): db_name = db_name or self.db_name - if self.mode == 'cluster': - return redis.cluster.RedisCluster(host=self.host, port=self.port, password=self.password, - encoding_errors='ignore', decode_responses=True, - socket_connect_timeout=10) + if self.mode == "cluster": + return redis.cluster.RedisCluster( + host=self.host, + port=self.port, + password=self.password, + encoding_errors="ignore", + decode_responses=True, + socket_connect_timeout=10, + ) else: - return redis.Redis(host=self.host, port=self.port, db=db_name, password=self.password, - encoding_errors='ignore', decode_responses=True, socket_connect_timeout=10) + return redis.Redis( + host=self.host, + port=self.port, + db=db_name, + password=self.password, + encoding_errors="ignore", + decode_responses=True, + socket_connect_timeout=10, + ) @property def name(self): - return 'Redis' + return "Redis" @property def info(self): - return 'Redis engine' + return "Redis engine" def test_connection(self): return self.get_all_databases() @@ -49,44 +61,74 @@ def get_all_databases(self, **kwargs): 获取数据库列表 :return: """ - result = ResultSet(full_sql='CONFIG GET databases') + result = ResultSet(full_sql="CONFIG GET databases") conn = self.get_connection() try: - rows = conn.config_get('databases')['databases'] + rows = conn.config_get("databases")["databases"] except Exception as e: logger.warning(f"Redis CONFIG GET databases 执行报错,异常信息:{e}") - dbs = [int(i.split('db')[1]) for i in conn.info('Keyspace').keys() if len(i.split('db')) == 2] + dbs = [ + int(i.split("db")[1]) + for i in conn.info("Keyspace").keys() + if len(i.split("db")) == 2 + ] rows = max(dbs, [16]) db_list = [str(x) for x in range(int(rows))] result.rows = db_list return result - def query_check(self, db_name=None, sql='', limit_num=0): + def query_check(self, db_name=None, sql="", limit_num=0): """提交查询前的检查""" - result = {'msg': '', 'bad_query': True, 'filtered_sql': sql, 'has_star': False} - safe_cmd = ["scan", "exists", "ttl", "pttl", "type", "get", "mget", "strlen", - "hgetall", "hexists", "hget", "hmget", "hkeys", "hvals", - "smembers", "scard", "sdiff", "sunion", "sismember", "llen", "lrange", "lindex", - "zrange", "zrangebyscore", "zscore", "zcard", "zcount", "zrank"] + result = {"msg": "", "bad_query": True, "filtered_sql": sql, "has_star": False} + safe_cmd = [ + "scan", + "exists", + "ttl", + "pttl", + "type", + "get", + "mget", + "strlen", + "hgetall", + "hexists", + "hget", + "hmget", + "hkeys", + "hvals", + "smembers", + "scard", + "sdiff", + "sunion", + "sismember", + "llen", + "lrange", + "lindex", + "zrange", + "zrangebyscore", + "zscore", + "zcard", + "zcount", + "zrank", + ] # 命令校验,仅可以执行safe_cmd内的命令 for cmd in safe_cmd: - if re.match(fr'^{cmd}', sql.strip(), re.I): - result['bad_query'] = False + if re.match(rf"^{cmd}", sql.strip(), re.I): + result["bad_query"] = False break - if result['bad_query']: - result['msg'] = "禁止执行该命令!" + if result["bad_query"]: + result["msg"] = "禁止执行该命令!" return result - def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs): - """返回 ResultSet """ + def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs): + """返回 ResultSet""" result_set = ResultSet(full_sql=sql) try: conn = self.get_connection(db_name=db_name) rows = conn.execute_command(*shlex.split(sql)) - result_set.column_list = ['Result'] + result_set.column_list = ["Result"] if isinstance(rows, list) or isinstance(rows, tuple): - if re.match(fr'^scan', sql.strip(), re.I): + if re.match(rf"^scan", sql.strip(), re.I): keys = [[row] for row in rows[1]] keys.insert(0, [rows[0]]) result_set.rows = tuple(keys) @@ -108,26 +150,28 @@ def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs): result_set.error = str(e) return result_set - def filter_sql(self, sql='', limit_num=0): + def filter_sql(self, sql="", limit_num=0): return sql.strip() - def query_masking(self, db_name=None, sql='', resultset=None): + def query_masking(self, db_name=None, sql="", resultset=None): """不做脱敏""" return resultset - def execute_check(self, db_name=None, sql=''): + def execute_check(self, db_name=None, sql=""): """上线单执行前的检查, 返回Review set""" check_result = ReviewSet(full_sql=sql) - split_sql = [cmd.strip() for cmd in sql.split('\n') if cmd.strip()] + split_sql = [cmd.strip() for cmd in sql.split("\n") if cmd.strip()] line = 1 for cmd in split_sql: - result = ReviewResult(id=line, - errlevel=0, - stagestatus='Audit completed', - errormessage='None', - sql=cmd, - affected_rows=0, - execute_time=0, ) + result = ReviewResult( + id=line, + errlevel=0, + stagestatus="Audit completed", + errormessage="None", + sql=cmd, + affected_rows=0, + execute_time=0, + ) check_result.rows += [result] line += 1 return check_result @@ -135,7 +179,7 @@ def execute_check(self, db_name=None, sql=''): def execute_workflow(self, workflow): """执行上线单,返回Review set""" sql = workflow.sqlworkflowcontent.sql_content - split_sql = [cmd.strip() for cmd in sql.split('\n') if cmd.strip()] + split_sql = [cmd.strip() for cmd in sql.split("\n") if cmd.strip()] execute_result = ReviewSet(full_sql=sql) line = 1 cmd = None @@ -144,40 +188,48 @@ def execute_workflow(self, workflow): for cmd in split_sql: with FuncTimer() as t: conn.execute_command(*shlex.split(cmd)) - execute_result.rows.append(ReviewResult( - id=line, - errlevel=0, - stagestatus='Execute Successfully', - errormessage='None', - sql=cmd, - affected_rows=0, - execute_time=t.cost, - )) + execute_result.rows.append( + ReviewResult( + id=line, + errlevel=0, + stagestatus="Execute Successfully", + errormessage="None", + sql=cmd, + affected_rows=0, + execute_time=t.cost, + ) + ) line += 1 except Exception as e: - logger.warning(f"Redis命令执行报错,语句:{cmd or sql}, 错误信息:{traceback.format_exc()}") + logger.warning( + f"Redis命令执行报错,语句:{cmd or sql}, 错误信息:{traceback.format_exc()}" + ) # 追加当前报错语句信息到执行结果中 execute_result.error = str(e) - execute_result.rows.append(ReviewResult( - id=line, - errlevel=2, - stagestatus='Execute Failed', - errormessage=f'异常信息:{e}', - sql=cmd, - affected_rows=0, - execute_time=0, - )) - line += 1 - # 报错语句后面的语句标记为审核通过、未执行,追加到执行结果中 - for statement in split_sql[line - 1:]: - execute_result.rows.append(ReviewResult( + execute_result.rows.append( + ReviewResult( id=line, - errlevel=0, - stagestatus='Audit completed', - errormessage=f'前序语句失败, 未执行', - sql=statement, + errlevel=2, + stagestatus="Execute Failed", + errormessage=f"异常信息:{e}", + sql=cmd, affected_rows=0, execute_time=0, - )) + ) + ) + line += 1 + # 报错语句后面的语句标记为审核通过、未执行,追加到执行结果中 + for statement in split_sql[line - 1 :]: + execute_result.rows.append( + ReviewResult( + id=line, + errlevel=0, + stagestatus="Audit completed", + errormessage=f"前序语句失败, 未执行", + sql=statement, + affected_rows=0, + execute_time=0, + ) + ) line += 1 return execute_result diff --git a/sql/engines/tests.py b/sql/engines/tests.py index 14da252e3a..e284ff2d62 100644 --- a/sql/engines/tests.py +++ b/sql/engines/tests.py @@ -27,39 +27,44 @@ class TestReviewSet(TestCase): def test_review_set(self): new_review_set = ReviewSet() - new_review_set.rows = [{'id': '1679123'}] - self.assertIn('1679123', new_review_set.json()) + new_review_set.rows = [{"id": "1679123"}] + self.assertIn("1679123", new_review_set.json()) class TestEngineBase(TestCase): @classmethod def setUpClass(cls): - cls.u1 = User(username='some_user', display='用户1') + cls.u1 = User(username="some_user", display="用户1") cls.u1.save() - cls.ins1 = Instance(instance_name='some_ins', type='master', db_type='mssql', host='some_host', - port=1366, user='ins_user', password='some_str') + cls.ins1 = Instance( + instance_name="some_ins", + type="master", + db_type="mssql", + host="some_host", + port=1366, + user="ins_user", + password="some_str", + ) cls.ins1.save() cls.wf1 = SqlWorkflow.objects.create( - workflow_name='some_name', + workflow_name="some_name", group_id=1, - group_name='g1', + group_name="g1", engineer=cls.u1.username, engineer_display=cls.u1.display, - audit_auth_groups='some_group', + audit_auth_groups="some_group", create_time=datetime.now() - timedelta(days=1), - status='workflow_finish', + status="workflow_finish", is_backup=True, instance=cls.ins1, - db_name='some_db', - syntax_type=1 + db_name="some_db", + syntax_type=1, ) cls.wfc1 = SqlWorkflowContent.objects.create( workflow=cls.wf1, - sql_content='some_sql', - execute_result=json.dumps([{ - 'id': 1, - 'sql': 'some_content' - }])) + sql_content="some_sql", + execute_result=json.dumps([{"id": 1, "sql": "some_content"}]), + ) @classmethod def tearDownClass(cls): @@ -75,27 +80,35 @@ def test_init_with_ins(self): class TestMssql(TestCase): - @classmethod def setUpClass(cls): - cls.ins1 = Instance(instance_name='some_ins', type='slave', db_type='mssql', host='some_host', - port=1366, user='ins_user', password='some_str') + cls.ins1 = Instance( + instance_name="some_ins", + type="slave", + db_type="mssql", + host="some_host", + port=1366, + user="ins_user", + password="some_str", + ) cls.ins1.save() cls.engine = MssqlEngine(instance=cls.ins1) cls.wf = SqlWorkflow.objects.create( - workflow_name='some_name', + workflow_name="some_name", group_id=1, - group_name='g1', - engineer_display='', - audit_auth_groups='some_group', + group_name="g1", + engineer_display="", + audit_auth_groups="some_group", create_time=datetime.now() - timedelta(days=1), - status='workflow_finish', + status="workflow_finish", is_backup=True, instance=cls.ins1, - db_name='some_db', - syntax_type=1 + db_name="some_db", + syntax_type=1, + ) + SqlWorkflowContent.objects.create( + workflow=cls.wf, sql_content="insert into some_tb values (1)" ) - SqlWorkflowContent.objects.create(workflow=cls.wf, sql_content='insert into some_tb values (1)') @classmethod def tearDownClass(cls): @@ -103,91 +116,96 @@ def tearDownClass(cls): cls.wf.delete() SqlWorkflowContent.objects.all().delete() - @patch('sql.engines.mssql.pyodbc.connect') + @patch("sql.engines.mssql.pyodbc.connect") def testGetConnection(self, connect): new_engine = MssqlEngine(instance=self.ins1) new_engine.get_connection() connect.assert_called_once() - @patch('sql.engines.mssql.pyodbc.connect') + @patch("sql.engines.mssql.pyodbc.connect") def testQuery(self, connect): cur = Mock() connect.return_value.cursor = cur cur.return_value.execute = Mock() - cur.return_value.fetchmany.return_value = (('v1', 'v2'),) - cur.return_value.description = (('k1', 'some_other_des'), ('k2', 'some_other_des')) + cur.return_value.fetchmany.return_value = (("v1", "v2"),) + cur.return_value.description = ( + ("k1", "some_other_des"), + ("k2", "some_other_des"), + ) new_engine = MssqlEngine(instance=self.ins1) - query_result = new_engine.query(sql='some_str', limit_num=100) + query_result = new_engine.query(sql="some_str", limit_num=100) cur.return_value.execute.assert_called() cur.return_value.fetchmany.assert_called_once_with(100) connect.return_value.close.assert_called_once() self.assertIsInstance(query_result, ResultSet) - @patch.object(MssqlEngine, 'query') + @patch.object(MssqlEngine, "query") def testAllDb(self, mock_query): db_result = ResultSet() - db_result.rows = [('db_1',), ('db_2',)] + db_result.rows = [("db_1",), ("db_2",)] mock_query.return_value = db_result new_engine = MssqlEngine(instance=self.ins1) dbs = new_engine.get_all_databases() - self.assertEqual(dbs.rows, ['db_1', 'db_2']) + self.assertEqual(dbs.rows, ["db_1", "db_2"]) - @patch.object(MssqlEngine, 'query') + @patch.object(MssqlEngine, "query") def testAllTables(self, mock_query): table_result = ResultSet() - table_result.rows = [('tb_1', 'some_des'), ('tb_2', 'some_des')] + table_result.rows = [("tb_1", "some_des"), ("tb_2", "some_des")] mock_query.return_value = table_result new_engine = MssqlEngine(instance=self.ins1) - tables = new_engine.get_all_tables('some_db') - mock_query.assert_called_once_with(db_name='some_db', sql=ANY) - self.assertEqual(tables.rows, ['tb_1', 'tb_2']) + tables = new_engine.get_all_tables("some_db") + mock_query.assert_called_once_with(db_name="some_db", sql=ANY) + self.assertEqual(tables.rows, ["tb_1", "tb_2"]) - @patch.object(MssqlEngine, 'query') + @patch.object(MssqlEngine, "query") def testAllColumns(self, mock_query): db_result = ResultSet() - db_result.rows = [('col_1', 'type'), ('col_2', 'type2')] + db_result.rows = [("col_1", "type"), ("col_2", "type2")] mock_query.return_value = db_result new_engine = MssqlEngine(instance=self.ins1) - dbs = new_engine.get_all_columns_by_tb('some_db', 'some_tb') - self.assertEqual(dbs.rows, ['col_1', 'col_2']) + dbs = new_engine.get_all_columns_by_tb("some_db", "some_tb") + self.assertEqual(dbs.rows, ["col_1", "col_2"]) - @patch.object(MssqlEngine, 'query') + @patch.object(MssqlEngine, "query") def testDescribe(self, mock_query): new_engine = MssqlEngine(instance=self.ins1) - new_engine.describe_table('some_db', 'some_db') + new_engine.describe_table("some_db", "some_db") mock_query.assert_called_once() def testQueryCheck(self): new_engine = MssqlEngine(instance=self.ins1) # 只抽查一个函数 - banned_sql = 'select concat(phone,1) from user_table' - check_result = new_engine.query_check(db_name='some_db', sql=banned_sql) - self.assertTrue(check_result.get('bad_query')) - banned_sql = 'select phone from user_table where phone=concat(phone,1)' - check_result = new_engine.query_check(db_name='some_db', sql=banned_sql) - self.assertTrue(check_result.get('bad_query')) + banned_sql = "select concat(phone,1) from user_table" + check_result = new_engine.query_check(db_name="some_db", sql=banned_sql) + self.assertTrue(check_result.get("bad_query")) + banned_sql = "select phone from user_table where phone=concat(phone,1)" + check_result = new_engine.query_check(db_name="some_db", sql=banned_sql) + self.assertTrue(check_result.get("bad_query")) sp_sql = "sp_helptext '[SomeName].[SomeAction]'" - check_result = new_engine.query_check(db_name='some_db', sql=sp_sql) - self.assertFalse(check_result.get('bad_query')) - self.assertEqual(check_result.get('filtered_sql'), sp_sql) + check_result = new_engine.query_check(db_name="some_db", sql=sp_sql) + self.assertFalse(check_result.get("bad_query")) + self.assertEqual(check_result.get("filtered_sql"), sp_sql) def test_filter_sql(self): new_engine = MssqlEngine(instance=self.ins1) # 只抽查一个函数 - banned_sql = 'select user from user_table' + banned_sql = "select user from user_table" check_result = new_engine.filter_sql(sql=banned_sql, limit_num=10) self.assertEqual(check_result, "select top 10 user from user_table") def test_execute_check(self): new_engine = MssqlEngine(instance=self.ins1) - test_sql = 'use database\ngo\nsome sql1\nGO\nsome sql2\n\r\nGo\nsome sql3\n\r\ngO\n' + test_sql = ( + "use database\ngo\nsome sql1\nGO\nsome sql2\n\r\nGo\nsome sql3\n\r\ngO\n" + ) check_result = new_engine.execute_check(db_name=None, sql=test_sql) self.assertIsInstance(check_result, ReviewSet) - self.assertEqual(check_result.rows[1].__dict__['sql'], "use database\n") - self.assertEqual(check_result.rows[2].__dict__['sql'], "\nsome sql1\n") - self.assertEqual(check_result.rows[4].__dict__['sql'], "\nsome sql3\n\r\n") + self.assertEqual(check_result.rows[1].__dict__["sql"], "use database\n") + self.assertEqual(check_result.rows[2].__dict__["sql"], "\nsome sql1\n") + self.assertEqual(check_result.rows[4].__dict__["sql"], "\nsome sql3\n\r\n") - @patch('sql.engines.mssql.MssqlEngine.execute') + @patch("sql.engines.mssql.MssqlEngine.execute") def test_execute_workflow(self, mock_execute): mock_execute.return_value.error = None new_engine = MssqlEngine(instance=self.ins1) @@ -196,48 +214,56 @@ def test_execute_workflow(self, mock_execute): mock_execute.assert_called() self.assertEqual(1, mock_execute.call_count) - @patch('sql.engines.mssql.MssqlEngine.get_connection') + @patch("sql.engines.mssql.MssqlEngine.get_connection") def test_execute(self, mock_connect): mock_cursor = Mock() mock_connect.return_value.cursor = mock_cursor new_engine = MssqlEngine(instance=self.ins1) - execute_result = new_engine.execute('some_db', 'some_sql') + execute_result = new_engine.execute("some_db", "some_sql") # 验证结果, 无异常 self.assertIsNone(execute_result.error) - self.assertEqual('some_sql', execute_result.full_sql) + self.assertEqual("some_sql", execute_result.full_sql) self.assertEqual(2, len(execute_result.rows)) mock_cursor.return_value.execute.assert_called() mock_cursor.return_value.commit.assert_called() mock_cursor.reset_mock() # 验证异常 - mock_cursor.return_value.execute.side_effect = Exception('Boom! some exception!') - execute_result = new_engine.execute('some_db', 'some_sql') - self.assertIn('Boom! some exception!', execute_result.error) - self.assertEqual('some_sql', execute_result.full_sql) + mock_cursor.return_value.execute.side_effect = Exception( + "Boom! some exception!" + ) + execute_result = new_engine.execute("some_db", "some_sql") + self.assertIn("Boom! some exception!", execute_result.error) + self.assertEqual("some_sql", execute_result.full_sql) self.assertEqual(2, len(execute_result.rows)) mock_cursor.return_value.commit.assert_not_called() mock_cursor.return_value.rollback.assert_called() class TestMysql(TestCase): - def setUp(self): - self.ins1 = Instance(instance_name='some_ins', type='slave', db_type='mysql', host='some_host', - port=1366, user='ins_user', password='some_str') + self.ins1 = Instance( + instance_name="some_ins", + type="slave", + db_type="mysql", + host="some_host", + port=1366, + user="ins_user", + password="some_str", + ) self.ins1.save() self.sys_config = SysConfig() self.wf = SqlWorkflow.objects.create( - workflow_name='some_name', + workflow_name="some_name", group_id=1, - group_name='g1', - engineer_display='', - audit_auth_groups='some_group', + group_name="g1", + engineer_display="", + audit_auth_groups="some_group", create_time=datetime.now() - timedelta(days=1), - status='workflow_finish', + status="workflow_finish", is_backup=True, instance=self.ins1, - db_name='some_db', - syntax_type=1 + db_name="some_db", + syntax_type=1, ) SqlWorkflowContent.objects.create(workflow=self.wf) @@ -247,359 +273,408 @@ def tearDown(self): SqlWorkflow.objects.all().delete() SqlWorkflowContent.objects.all().delete() - @patch('MySQLdb.connect') + @patch("MySQLdb.connect") def test_engine_base_info(self, _conn): new_engine = MysqlEngine(instance=self.ins1) - self.assertEqual(new_engine.name, 'MySQL') - self.assertEqual(new_engine.info, 'MySQL engine') + self.assertEqual(new_engine.name, "MySQL") + self.assertEqual(new_engine.info, "MySQL engine") - @patch('MySQLdb.connect') + @patch("MySQLdb.connect") def testGetConnection(self, connect): new_engine = MysqlEngine(instance=self.ins1) new_engine.get_connection() connect.assert_called_once() - @patch('MySQLdb.connect') + @patch("MySQLdb.connect") def testQuery(self, connect): cur = Mock() connect.return_value.cursor = cur cur.return_value.execute = Mock() - cur.return_value.fetchmany.return_value = (('v1', 'v2'),) - cur.return_value.description = (('k1', 'some_other_des'), ('k2', 'some_other_des')) + cur.return_value.fetchmany.return_value = (("v1", "v2"),) + cur.return_value.description = ( + ("k1", "some_other_des"), + ("k2", "some_other_des"), + ) new_engine = MysqlEngine(instance=self.ins1) - query_result = new_engine.query(sql='some_str', limit_num=100) + query_result = new_engine.query(sql="some_str", limit_num=100) cur.return_value.execute.assert_called() cur.return_value.fetchmany.assert_called_once_with(size=100) connect.return_value.close.assert_called_once() self.assertIsInstance(query_result, ResultSet) - @patch.object(MysqlEngine, 'query') + @patch.object(MysqlEngine, "query") def testAllDb(self, mock_query): db_result = ResultSet() - db_result.rows = [('db_1',), ('db_2',)] + db_result.rows = [("db_1",), ("db_2",)] mock_query.return_value = db_result new_engine = MysqlEngine(instance=self.ins1) dbs = new_engine.get_all_databases() - self.assertEqual(dbs.rows, ['db_1', 'db_2']) + self.assertEqual(dbs.rows, ["db_1", "db_2"]) - @patch.object(MysqlEngine, 'query') + @patch.object(MysqlEngine, "query") def testAllTables(self, mock_query): table_result = ResultSet() - table_result.rows = [('tb_1', 'some_des'), ('tb_2', 'some_des')] + table_result.rows = [("tb_1", "some_des"), ("tb_2", "some_des")] mock_query.return_value = table_result new_engine = MysqlEngine(instance=self.ins1) - tables = new_engine.get_all_tables('some_db') - mock_query.assert_called_once_with(db_name='some_db', sql=ANY) - self.assertEqual(tables.rows, ['tb_1', 'tb_2']) + tables = new_engine.get_all_tables("some_db") + mock_query.assert_called_once_with(db_name="some_db", sql=ANY) + self.assertEqual(tables.rows, ["tb_1", "tb_2"]) - @patch.object(MysqlEngine, 'query') + @patch.object(MysqlEngine, "query") def testAllColumns(self, mock_query): db_result = ResultSet() - db_result.rows = [('col_1', 'type'), ('col_2', 'type2')] + db_result.rows = [("col_1", "type"), ("col_2", "type2")] mock_query.return_value = db_result new_engine = MysqlEngine(instance=self.ins1) - dbs = new_engine.get_all_columns_by_tb('some_db', 'some_tb') - self.assertEqual(dbs.rows, ['col_1', 'col_2']) + dbs = new_engine.get_all_columns_by_tb("some_db", "some_tb") + self.assertEqual(dbs.rows, ["col_1", "col_2"]) - @patch.object(MysqlEngine, 'query') + @patch.object(MysqlEngine, "query") def testDescribe(self, mock_query): new_engine = MysqlEngine(instance=self.ins1) - new_engine.describe_table('some_db', 'some_db') + new_engine.describe_table("some_db", "some_db") mock_query.assert_called_once() def testQueryCheck(self): new_engine = MysqlEngine(instance=self.ins1) - sql_without_limit = '-- 测试\n select user from usertable' - check_result = new_engine.query_check(db_name='some_db', sql=sql_without_limit) - self.assertEqual(check_result['filtered_sql'], 'select user from usertable') + sql_without_limit = "-- 测试\n select user from usertable" + check_result = new_engine.query_check(db_name="some_db", sql=sql_without_limit) + self.assertEqual(check_result["filtered_sql"], "select user from usertable") def test_query_check_wrong_sql(self): new_engine = MysqlEngine(instance=self.ins1) - wrong_sql = '-- 测试' - check_result = new_engine.query_check(db_name='some_db', sql=wrong_sql) - self.assertDictEqual(check_result, - {'msg': '不支持的查询语法类型!', 'bad_query': True, 'filtered_sql': '-- 测试', 'has_star': False}) + wrong_sql = "-- 测试" + check_result = new_engine.query_check(db_name="some_db", sql=wrong_sql) + self.assertDictEqual( + check_result, + { + "msg": "不支持的查询语法类型!", + "bad_query": True, + "filtered_sql": "-- 测试", + "has_star": False, + }, + ) def test_query_check_update_sql(self): new_engine = MysqlEngine(instance=self.ins1) - update_sql = 'update user set id=0' - check_result = new_engine.query_check(db_name='some_db', sql=update_sql) - self.assertDictEqual(check_result, - {'msg': '不支持的查询语法类型!', 'bad_query': True, 'filtered_sql': 'update user set id=0', - 'has_star': False}) + update_sql = "update user set id=0" + check_result = new_engine.query_check(db_name="some_db", sql=update_sql) + self.assertDictEqual( + check_result, + { + "msg": "不支持的查询语法类型!", + "bad_query": True, + "filtered_sql": "update user set id=0", + "has_star": False, + }, + ) def test_filter_sql_with_delimiter(self): new_engine = MysqlEngine(instance=self.ins1) - sql_without_limit = 'select user from usertable;' + sql_without_limit = "select user from usertable;" check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=100) - self.assertEqual(check_result, 'select user from usertable limit 100;') + self.assertEqual(check_result, "select user from usertable limit 100;") def test_filter_sql_without_delimiter(self): new_engine = MysqlEngine(instance=self.ins1) - sql_without_limit = 'select user from usertable' + sql_without_limit = "select user from usertable" check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=100) - self.assertEqual(check_result, 'select user from usertable limit 100;') + self.assertEqual(check_result, "select user from usertable limit 100;") def test_filter_sql_with_limit(self): new_engine = MysqlEngine(instance=self.ins1) - sql_without_limit = 'select user from usertable limit 10' + sql_without_limit = "select user from usertable limit 10" check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=1) - self.assertEqual(check_result, 'select user from usertable limit 1;') + self.assertEqual(check_result, "select user from usertable limit 1;") def test_filter_sql_with_limit_min(self): new_engine = MysqlEngine(instance=self.ins1) - sql_without_limit = 'select user from usertable limit 10' + sql_without_limit = "select user from usertable limit 10" check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=100) - self.assertEqual(check_result, 'select user from usertable limit 10;') + self.assertEqual(check_result, "select user from usertable limit 10;") def test_filter_sql_with_limit_offset(self): new_engine = MysqlEngine(instance=self.ins1) - sql_without_limit = 'select user from usertable limit 10 offset 100' + sql_without_limit = "select user from usertable limit 10 offset 100" check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=1) - self.assertEqual(check_result, 'select user from usertable limit 1 offset 100;') + self.assertEqual(check_result, "select user from usertable limit 1 offset 100;") def test_filter_sql_with_limit_nn(self): new_engine = MysqlEngine(instance=self.ins1) - sql_without_limit = 'select user from usertable limit 10, 100' + sql_without_limit = "select user from usertable limit 10, 100" check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=1) - self.assertEqual(check_result, 'select user from usertable limit 10,1;') + self.assertEqual(check_result, "select user from usertable limit 10,1;") def test_filter_sql_upper(self): new_engine = MysqlEngine(instance=self.ins1) - sql_without_limit = 'SELECT USER FROM usertable LIMIT 10, 100' + sql_without_limit = "SELECT USER FROM usertable LIMIT 10, 100" check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=1) - self.assertEqual(check_result, 'SELECT USER FROM usertable limit 10,1;') + self.assertEqual(check_result, "SELECT USER FROM usertable limit 10,1;") def test_filter_sql_not_select(self): new_engine = MysqlEngine(instance=self.ins1) - sql_without_limit = 'show create table usertable;' + sql_without_limit = "show create table usertable;" check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=1) - self.assertEqual(check_result, 'show create table usertable;') + self.assertEqual(check_result, "show create table usertable;") - @patch('sql.engines.mysql.data_masking', return_value=ResultSet()) + @patch("sql.engines.mysql.data_masking", return_value=ResultSet()) def test_query_masking(self, _data_masking): query_result = ResultSet() new_engine = MysqlEngine(instance=self.ins1) - masking_result = new_engine.query_masking(db_name='archery', sql='select 1', resultset=query_result) + masking_result = new_engine.query_masking( + db_name="archery", sql="select 1", resultset=query_result + ) self.assertIsInstance(masking_result, ResultSet) - @patch('sql.engines.mysql.data_masking', return_value=ResultSet()) + @patch("sql.engines.mysql.data_masking", return_value=ResultSet()) def test_query_masking_not_select(self, _data_masking): query_result = ResultSet() new_engine = MysqlEngine(instance=self.ins1) - masking_result = new_engine.query_masking(db_name='archery', sql='explain select 1', resultset=query_result) + masking_result = new_engine.query_masking( + db_name="archery", sql="explain select 1", resultset=query_result + ) self.assertEqual(masking_result, query_result) - @patch('sql.engines.mysql.GoInceptionEngine') + @patch("sql.engines.mysql.GoInceptionEngine") def test_execute_check_select_sql(self, _inception_engine): - self.sys_config.set('goinception', 'true') - sql = 'select * from user' - inc_row = ReviewResult(id=1, - errlevel=0, - stagestatus='Audit completed', - errormessage='None', - sql=sql, - affected_rows=0, - execute_time='', ) - row = ReviewResult(id=1, errlevel=2, - stagestatus='驳回不支持语句', - errormessage='仅支持DML和DDL语句,查询语句请使用SQL查询功能!', - sql=sql) - _inception_engine.return_value.execute_check.return_value = ReviewSet(full_sql=sql, rows=[inc_row]) + self.sys_config.set("goinception", "true") + sql = "select * from user" + inc_row = ReviewResult( + id=1, + errlevel=0, + stagestatus="Audit completed", + errormessage="None", + sql=sql, + affected_rows=0, + execute_time="", + ) + row = ReviewResult( + id=1, + errlevel=2, + stagestatus="驳回不支持语句", + errormessage="仅支持DML和DDL语句,查询语句请使用SQL查询功能!", + sql=sql, + ) + _inception_engine.return_value.execute_check.return_value = ReviewSet( + full_sql=sql, rows=[inc_row] + ) new_engine = MysqlEngine(instance=self.ins1) - check_result = new_engine.execute_check(db_name='archery', sql=sql) + check_result = new_engine.execute_check(db_name="archery", sql=sql) self.assertIsInstance(check_result, ReviewSet) self.assertEqual(check_result.rows[0].__dict__, row.__dict__) - @patch('sql.engines.mysql.GoInceptionEngine') + @patch("sql.engines.mysql.GoInceptionEngine") def test_execute_check_critical_sql(self, _inception_engine): - self.sys_config.set('goinception', 'true') - self.sys_config.set('critical_ddl_regex', '^|update') + self.sys_config.set("goinception", "true") + self.sys_config.set("critical_ddl_regex", "^|update") self.sys_config.get_all_config() - sql = 'update user set id=1' - inc_row = ReviewResult(id=1, - errlevel=0, - stagestatus='Audit completed', - errormessage='None', - sql=sql, - affected_rows=0, - execute_time='', ) - row = ReviewResult(id=1, errlevel=2, - stagestatus='驳回高危SQL', - errormessage='禁止提交匹配' + '^|update' + '条件的语句!', - sql=sql) - _inception_engine.return_value.execute_check.return_value = ReviewSet(full_sql=sql, rows=[inc_row]) + sql = "update user set id=1" + inc_row = ReviewResult( + id=1, + errlevel=0, + stagestatus="Audit completed", + errormessage="None", + sql=sql, + affected_rows=0, + execute_time="", + ) + row = ReviewResult( + id=1, + errlevel=2, + stagestatus="驳回高危SQL", + errormessage="禁止提交匹配" + "^|update" + "条件的语句!", + sql=sql, + ) + _inception_engine.return_value.execute_check.return_value = ReviewSet( + full_sql=sql, rows=[inc_row] + ) new_engine = MysqlEngine(instance=self.ins1) - check_result = new_engine.execute_check(db_name='archery', sql=sql) + check_result = new_engine.execute_check(db_name="archery", sql=sql) self.assertIsInstance(check_result, ReviewSet) self.assertEqual(check_result.rows[0].__dict__, row.__dict__) - @patch('sql.engines.mysql.GoInceptionEngine') + @patch("sql.engines.mysql.GoInceptionEngine") def test_execute_check_normal_sql(self, _inception_engine): - self.sys_config.set('goinception', 'true') - sql = 'update user set id=1' - row = ReviewResult(id=1, - errlevel=0, - stagestatus='Audit completed', - errormessage='None', - sql=sql, - affected_rows=0, - execute_time=0, ) - _inception_engine.return_value.execute_check.return_value = ReviewSet(full_sql=sql, rows=[row]) + self.sys_config.set("goinception", "true") + sql = "update user set id=1" + row = ReviewResult( + id=1, + errlevel=0, + stagestatus="Audit completed", + errormessage="None", + sql=sql, + affected_rows=0, + execute_time=0, + ) + _inception_engine.return_value.execute_check.return_value = ReviewSet( + full_sql=sql, rows=[row] + ) new_engine = MysqlEngine(instance=self.ins1) - check_result = new_engine.execute_check(db_name='archery', sql=sql) + check_result = new_engine.execute_check(db_name="archery", sql=sql) self.assertIsInstance(check_result, ReviewSet) self.assertEqual(check_result.rows[0].__dict__, row.__dict__) - @patch('sql.engines.mysql.GoInceptionEngine') + @patch("sql.engines.mysql.GoInceptionEngine") def test_execute_check_normal_sql_with_Exception(self, _inception_engine): - sql = 'update user set id=1' + sql = "update user set id=1" _inception_engine.return_value.execute_check.side_effect = RuntimeError() new_engine = MysqlEngine(instance=self.ins1) with self.assertRaises(RuntimeError): new_engine.execute_check(db_name=0, sql=sql) - @patch.object(MysqlEngine, 'query') - @patch('sql.engines.mysql.GoInceptionEngine') + @patch.object(MysqlEngine, "query") + @patch("sql.engines.mysql.GoInceptionEngine") def test_execute_workflow(self, _inception_engine, _query): - self.sys_config.set('goinception', 'true') - sql = 'update user set id=1' + self.sys_config.set("goinception", "true") + sql = "update user set id=1" _inception_engine.return_value.execute.return_value = ReviewSet(full_sql=sql) - _query.return_value.rows = (('0',),) + _query.return_value.rows = (("0",),) new_engine = MysqlEngine(instance=self.ins1) execute_result = new_engine.execute_workflow(self.wf) self.assertIsInstance(execute_result, ReviewSet) - @patch('MySQLdb.connect.cursor.execute') - @patch('MySQLdb.connect.cursor') - @patch('MySQLdb.connect') + @patch("MySQLdb.connect.cursor.execute") + @patch("MySQLdb.connect.cursor") + @patch("MySQLdb.connect") def test_execute(self, _connect, _cursor, _execute): new_engine = MysqlEngine(instance=self.ins1) execute_result = new_engine.execute(self.wf) self.assertIsInstance(execute_result, ResultSet) - @patch('MySQLdb.connect') + @patch("MySQLdb.connect") def test_server_version(self, _connect): - _connect.return_value.get_server_info.return_value = '5.7.20-16log' + _connect.return_value.get_server_info.return_value = "5.7.20-16log" new_engine = MysqlEngine(instance=self.ins1) server_version = new_engine.server_version self.assertTupleEqual(server_version, (5, 7, 20)) - @patch.object(MysqlEngine, 'query') + @patch.object(MysqlEngine, "query") def test_get_variables_not_filter(self, _query): new_engine = MysqlEngine(instance=self.ins1) new_engine.get_variables() _query.assert_called_once() - @patch('MySQLdb.connect') - @patch.object(MysqlEngine, 'query') + @patch("MySQLdb.connect") + @patch.object(MysqlEngine, "query") def test_get_variables_filter(self, _query, _connect): - _connect.return_value.get_server_info.return_value = '5.7.20-16log' + _connect.return_value.get_server_info.return_value = "5.7.20-16log" new_engine = MysqlEngine(instance=self.ins1) - new_engine.get_variables(variables=['binlog_format']) + new_engine.get_variables(variables=["binlog_format"]) _query.assert_called() - @patch.object(MysqlEngine, 'query') + @patch.object(MysqlEngine, "query") def test_set_variable(self, _query): new_engine = MysqlEngine(instance=self.ins1) - new_engine.set_variable('binlog_format', 'ROW') + new_engine.set_variable("binlog_format", "ROW") _query.assert_called_once_with(sql="set global binlog_format=ROW;") - @patch('sql.engines.mysql.GoInceptionEngine') + @patch("sql.engines.mysql.GoInceptionEngine") def test_osc_go_inception(self, _inception_engine): - self.sys_config.set('goinception', 'false') + self.sys_config.set("goinception", "false") _inception_engine.return_value.osc_control.return_value = ReviewSet() - command = 'get' - sqlsha1 = 'xxxxx' + command = "get" + sqlsha1 = "xxxxx" new_engine = MysqlEngine(instance=self.ins1) new_engine.osc_control(sqlsha1=sqlsha1, command=command) - @patch('sql.engines.mysql.GoInceptionEngine') + @patch("sql.engines.mysql.GoInceptionEngine") def test_osc_inception(self, _inception_engine): - self.sys_config.set('goinception', 'true') + self.sys_config.set("goinception", "true") _inception_engine.return_value.osc_control.return_value = ReviewSet() - command = 'get' - sqlsha1 = 'xxxxx' + command = "get" + sqlsha1 = "xxxxx" new_engine = MysqlEngine(instance=self.ins1) new_engine.osc_control(sqlsha1=sqlsha1, command=command) - @patch.object(MysqlEngine, 'query') + @patch.object(MysqlEngine, "query") def test_kill_connection(self, _query): new_engine = MysqlEngine(instance=self.ins1) new_engine.kill_connection(100) _query.assert_called_once_with(sql="kill 100") - @patch.object(MysqlEngine, 'query') + @patch.object(MysqlEngine, "query") def test_seconds_behind_master(self, _query): new_engine = MysqlEngine(instance=self.ins1) new_engine.seconds_behind_master - _query.assert_called_once_with(sql="show slave status", close_conn=False, - cursorclass=MySQLdb.cursors.DictCursor) - - @patch.object(MysqlEngine, 'query') + _query.assert_called_once_with( + sql="show slave status", + close_conn=False, + cursorclass=MySQLdb.cursors.DictCursor, + ) + + @patch.object(MysqlEngine, "query") def test_processlist(self, _query): new_engine = MysqlEngine(instance=self.ins1) _query.return_value = ResultSet() - for command_type in [ 'Query', 'All', 'Not Sleep']: + for command_type in ["Query", "All", "Not Sleep"]: r = new_engine.processlist(command_type) self.assertIsInstance(r, ResultSet) - - @patch.object(MysqlEngine, 'query') + + @patch.object(MysqlEngine, "query") def test_get_kill_command(self, _query): new_engine = MysqlEngine(instance=self.ins1) - _query.return_value.rows = (('kill 1;',),('kill 2;',)) - r = new_engine.get_kill_command([1,2]) - self.assertEqual(r, 'kill 1;kill 2;') - - @patch('MySQLdb.connect.cursor.execute') - @patch('MySQLdb.connect.cursor') - @patch('MySQLdb.connect') - @patch.object(MysqlEngine, 'query') + _query.return_value.rows = (("kill 1;",), ("kill 2;",)) + r = new_engine.get_kill_command([1, 2]) + self.assertEqual(r, "kill 1;kill 2;") + + @patch("MySQLdb.connect.cursor.execute") + @patch("MySQLdb.connect.cursor") + @patch("MySQLdb.connect") + @patch.object(MysqlEngine, "query") def test_kill(self, _query, _connect, _cursor, _execute): new_engine = MysqlEngine(instance=self.ins1) - _query.return_value.rows = (('kill 1;',),('kill 2;',)) - _execute.return_value = ResultSet() - r = new_engine.kill([1,2]) + _query.return_value.rows = (("kill 1;",), ("kill 2;",)) + _execute.return_value = ResultSet() + r = new_engine.kill([1, 2]) self.assertIsInstance(r, ResultSet) - - @patch.object(MysqlEngine, 'query') + + @patch.object(MysqlEngine, "query") def test_tablesapce(self, _query): new_engine = MysqlEngine(instance=self.ins1) _query.return_value = ResultSet() r = new_engine.tablesapce() self.assertIsInstance(r, ResultSet) - - @patch.object(MysqlEngine, 'query') + + @patch.object(MysqlEngine, "query") def test_tablesapce_num(self, _query): new_engine = MysqlEngine(instance=self.ins1) _query.return_value = ResultSet() r = new_engine.tablesapce_num() self.assertIsInstance(r, ResultSet) - - @patch.object(MysqlEngine, 'query') - @patch('MySQLdb.connect') + + @patch.object(MysqlEngine, "query") + @patch("MySQLdb.connect") def test_trxandlocks(self, _connect, _query): new_engine = MysqlEngine(instance=self.ins1) _connect.return_value = Mock() - for v in ['5.7.0','8.0.1']: + for v in ["5.7.0", "8.0.1"]: _connect.return_value.get_server_info.return_value = v _query.return_value = ResultSet() r = new_engine.trxandlocks() self.assertIsInstance(r, ResultSet) - @patch.object(MysqlEngine, 'query') + @patch.object(MysqlEngine, "query") def test_get_long_transaction(self, _query): new_engine = MysqlEngine(instance=self.ins1) _query.return_value = ResultSet() r = new_engine.get_long_transaction() self.assertIsInstance(r, ResultSet) - + class TestRedis(TestCase): @classmethod def setUpClass(cls): - cls.ins = Instance(instance_name='some_ins', type='slave', db_type='redis', mode='standalone', - host='some_host', port=1366, user='ins_user', password='some_str') + cls.ins = Instance( + instance_name="some_ins", + type="slave", + db_type="redis", + mode="standalone", + host="some_host", + port=1366, + user="ins_user", + password="some_str", + ) cls.ins.save() @classmethod @@ -608,107 +683,127 @@ def tearDownClass(cls): SqlWorkflow.objects.all().delete() SqlWorkflowContent.objects.all().delete() - @patch('redis.Redis') + @patch("redis.Redis") def test_engine_base_info(self, _conn): new_engine = RedisEngine(instance=self.ins) - self.assertEqual(new_engine.name, 'Redis') - self.assertEqual(new_engine.info, 'Redis engine') + self.assertEqual(new_engine.name, "Redis") + self.assertEqual(new_engine.info, "Redis engine") - @patch('redis.Redis') + @patch("redis.Redis") def test_get_connection(self, _conn): new_engine = RedisEngine(instance=self.ins) new_engine.get_connection() _conn.assert_called_once() - @patch('redis.Redis.execute_command', return_value=[1, 2, 3]) + @patch("redis.Redis.execute_command", return_value=[1, 2, 3]) def test_query_return_list(self, _execute_command): new_engine = RedisEngine(instance=self.ins) - query_result = new_engine.query(db_name=0, sql='keys *', limit_num=100) + query_result = new_engine.query(db_name=0, sql="keys *", limit_num=100) self.assertIsInstance(query_result, ResultSet) self.assertTupleEqual(query_result.rows, ([1], [2], [3])) - @patch('redis.Redis.execute_command', return_value='text') + @patch("redis.Redis.execute_command", return_value="text") def test_query_return_str(self, _execute_command): new_engine = RedisEngine(instance=self.ins) - query_result = new_engine.query(db_name=0, sql='keys *', limit_num=100) + query_result = new_engine.query(db_name=0, sql="keys *", limit_num=100) self.assertIsInstance(query_result, ResultSet) - self.assertTupleEqual(query_result.rows, (['text'],)) + self.assertTupleEqual(query_result.rows, (["text"],)) - @patch('redis.Redis.execute_command', return_value='text') + @patch("redis.Redis.execute_command", return_value="text") def test_query_execute(self, _execute_command): new_engine = RedisEngine(instance=self.ins) - query_result = new_engine.query(db_name=0, sql='keys *', limit_num=100) + query_result = new_engine.query(db_name=0, sql="keys *", limit_num=100) self.assertIsInstance(query_result, ResultSet) - self.assertTupleEqual(query_result.rows, (['text'],)) + self.assertTupleEqual(query_result.rows, (["text"],)) - @patch('redis.Redis.config_get', return_value={"databases": 4}) + @patch("redis.Redis.config_get", return_value={"databases": 4}) def test_get_all_databases(self, _config_get): new_engine = RedisEngine(instance=self.ins) dbs = new_engine.get_all_databases() - self.assertListEqual(dbs.rows, ['0', '1', '2', '3']) + self.assertListEqual(dbs.rows, ["0", "1", "2", "3"]) def test_query_check_safe_cmd(self): safe_cmd = "keys 1*" new_engine = RedisEngine(instance=self.ins) check_result = new_engine.query_check(db_name=0, sql=safe_cmd) - self.assertDictEqual(check_result, - {'msg': '禁止执行该命令!', 'bad_query': True, 'filtered_sql': safe_cmd, 'has_star': False}) + self.assertDictEqual( + check_result, + { + "msg": "禁止执行该命令!", + "bad_query": True, + "filtered_sql": safe_cmd, + "has_star": False, + }, + ) def test_query_check_danger_cmd(self): safe_cmd = "keys *" new_engine = RedisEngine(instance=self.ins) check_result = new_engine.query_check(db_name=0, sql=safe_cmd) - self.assertDictEqual(check_result, - {'msg': '禁止执行该命令!', 'bad_query': True, 'filtered_sql': safe_cmd, 'has_star': False}) + self.assertDictEqual( + check_result, + { + "msg": "禁止执行该命令!", + "bad_query": True, + "filtered_sql": safe_cmd, + "has_star": False, + }, + ) def test_filter_sql(self): safe_cmd = "keys 1*" new_engine = RedisEngine(instance=self.ins) check_result = new_engine.filter_sql(sql=safe_cmd, limit_num=100) - self.assertEqual(check_result, 'keys 1*') + self.assertEqual(check_result, "keys 1*") def test_query_masking(self): query_result = ResultSet() new_engine = RedisEngine(instance=self.ins) - masking_result = new_engine.query_masking(db_name=0, sql='', resultset=query_result) + masking_result = new_engine.query_masking( + db_name=0, sql="", resultset=query_result + ) self.assertEqual(masking_result, query_result) def test_execute_check(self): - sql = 'set 1 1' - row = ReviewResult(id=1, - errlevel=0, - stagestatus='Audit completed', - errormessage='None', - sql=sql, - affected_rows=0, - execute_time=0) + sql = "set 1 1" + row = ReviewResult( + id=1, + errlevel=0, + stagestatus="Audit completed", + errormessage="None", + sql=sql, + affected_rows=0, + execute_time=0, + ) new_engine = RedisEngine(instance=self.ins) check_result = new_engine.execute_check(db_name=0, sql=sql) self.assertIsInstance(check_result, ReviewSet) self.assertEqual(check_result.rows[0].__dict__, row.__dict__) - @patch('redis.Redis.execute_command', return_value='text') + @patch("redis.Redis.execute_command", return_value="text") def test_execute_workflow_success(self, _execute_command): - sql = 'set 1 1' - row = ReviewResult(id=1, - errlevel=0, - stagestatus='Execute Successfully', - errormessage='None', - sql=sql, - affected_rows=0, - execute_time=0) + sql = "set 1 1" + row = ReviewResult( + id=1, + errlevel=0, + stagestatus="Execute Successfully", + errormessage="None", + sql=sql, + affected_rows=0, + execute_time=0, + ) wf = SqlWorkflow.objects.create( - workflow_name='some_name', + workflow_name="some_name", group_id=1, - group_name='g1', - engineer_display='', - audit_auth_groups='some_group', + group_name="g1", + engineer_display="", + audit_auth_groups="some_group", create_time=datetime.now() - timedelta(days=1), - status='workflow_finish', + status="workflow_finish", is_backup=True, instance=self.ins, - db_name='some_db', - syntax_type=1 + db_name="some_db", + syntax_type=1, ) SqlWorkflowContent.objects.create(workflow=wf, sql_content=sql) new_engine = RedisEngine(instance=self.ins) @@ -720,8 +815,15 @@ def test_execute_workflow_success(self, _execute_command): class TestPgSQL(TestCase): @classmethod def setUpClass(cls): - cls.ins = Instance(instance_name='some_ins', type='slave', db_type='pgsql', host='some_host', - port=1366, user='ins_user', password='some_str') + cls.ins = Instance( + instance_name="some_ins", + type="slave", + db_type="pgsql", + host="some_host", + port=1366, + user="ins_user", + password="some_str", + ) cls.ins.save() cls.sys_config = SysConfig() @@ -730,85 +832,130 @@ def tearDownClass(cls): cls.ins.delete() cls.sys_config.purge() - @patch('psycopg2.connect') + @patch("psycopg2.connect") def test_engine_base_info(self, _conn): new_engine = PgSQLEngine(instance=self.ins) - self.assertEqual(new_engine.name, 'PgSQL') - self.assertEqual(new_engine.info, 'PgSQL engine') + self.assertEqual(new_engine.name, "PgSQL") + self.assertEqual(new_engine.info, "PgSQL engine") - @patch('psycopg2.connect') + @patch("psycopg2.connect") def test_get_connection(self, _conn): new_engine = PgSQLEngine(instance=self.ins) new_engine.get_connection("some_dbname") _conn.assert_called_once() - @patch('psycopg2.connect.cursor.execute') - @patch('psycopg2.connect.cursor') - @patch('psycopg2.connect') + @patch("psycopg2.connect.cursor.execute") + @patch("psycopg2.connect.cursor") + @patch("psycopg2.connect") def test_query(self, _conn, _cursor, _execute): _conn.return_value.cursor.return_value.fetchmany.return_value = [(1,)] new_engine = PgSQLEngine(instance=self.ins) - query_result = new_engine.query(db_name="some_dbname", sql='select 1', limit_num=100, schema_name="some_schema") + query_result = new_engine.query( + db_name="some_dbname", + sql="select 1", + limit_num=100, + schema_name="some_schema", + ) self.assertIsInstance(query_result, ResultSet) self.assertListEqual(query_result.rows, [(1,)]) - @patch('psycopg2.connect.cursor.execute') - @patch('psycopg2.connect.cursor') - @patch('psycopg2.connect') + @patch("psycopg2.connect.cursor.execute") + @patch("psycopg2.connect.cursor") + @patch("psycopg2.connect") def test_query_not_limit(self, _conn, _cursor, _execute): _conn.return_value.cursor.return_value.fetchall.return_value = [(1,)] new_engine = PgSQLEngine(instance=self.ins) - query_result = new_engine.query(db_name="some_dbname", sql='select 1', limit_num=0, schema_name="some_schema") + query_result = new_engine.query( + db_name="some_dbname", + sql="select 1", + limit_num=0, + schema_name="some_schema", + ) self.assertIsInstance(query_result, ResultSet) self.assertListEqual(query_result.rows, [(1,)]) - @patch('sql.engines.pgsql.PgSQLEngine.query', - return_value=ResultSet(rows=[('postgres',), ('archery',), ('template1',), ('template0',)])) + @patch( + "sql.engines.pgsql.PgSQLEngine.query", + return_value=ResultSet( + rows=[("postgres",), ("archery",), ("template1",), ("template0",)] + ), + ) def test_get_all_databases(self, query): new_engine = PgSQLEngine(instance=self.ins) dbs = new_engine.get_all_databases() - self.assertListEqual(dbs.rows, ['archery']) - - @patch('sql.engines.pgsql.PgSQLEngine.query', - return_value=ResultSet(rows=[('information_schema',), ('archery',), ('pg_catalog',)])) + self.assertListEqual(dbs.rows, ["archery"]) + + @patch( + "sql.engines.pgsql.PgSQLEngine.query", + return_value=ResultSet( + rows=[("information_schema",), ("archery",), ("pg_catalog",)] + ), + ) def test_get_all_schemas(self, _query): new_engine = PgSQLEngine(instance=self.ins) - schemas = new_engine.get_all_schemas(db_name='archery') - self.assertListEqual(schemas.rows, ['archery']) + schemas = new_engine.get_all_schemas(db_name="archery") + self.assertListEqual(schemas.rows, ["archery"]) - @patch('sql.engines.pgsql.PgSQLEngine.query', return_value=ResultSet(rows=[('test',), ('test2',)])) + @patch( + "sql.engines.pgsql.PgSQLEngine.query", + return_value=ResultSet(rows=[("test",), ("test2",)]), + ) def test_get_all_tables(self, _query): new_engine = PgSQLEngine(instance=self.ins) - tables = new_engine.get_all_tables(db_name='archery', schema_name='archery') - self.assertListEqual(tables.rows, ['test2']) + tables = new_engine.get_all_tables(db_name="archery", schema_name="archery") + self.assertListEqual(tables.rows, ["test2"]) - @patch('sql.engines.pgsql.PgSQLEngine.query', - return_value=ResultSet(rows=[('id',), ('name',)])) + @patch( + "sql.engines.pgsql.PgSQLEngine.query", + return_value=ResultSet(rows=[("id",), ("name",)]), + ) def test_get_all_columns_by_tb(self, _query): new_engine = PgSQLEngine(instance=self.ins) - columns = new_engine.get_all_columns_by_tb(db_name='archery', tb_name='test2', schema_name='archery') - self.assertListEqual(columns.rows, ['id', 'name']) - - @patch('sql.engines.pgsql.PgSQLEngine.query', - return_value=ResultSet(rows=[('postgres',), ('archery',), ('template1',), ('template0',)])) + columns = new_engine.get_all_columns_by_tb( + db_name="archery", tb_name="test2", schema_name="archery" + ) + self.assertListEqual(columns.rows, ["id", "name"]) + + @patch( + "sql.engines.pgsql.PgSQLEngine.query", + return_value=ResultSet( + rows=[("postgres",), ("archery",), ("template1",), ("template0",)] + ), + ) def test_describe_table(self, _query): new_engine = PgSQLEngine(instance=self.ins) - describe = new_engine.describe_table(db_name='archery', schema_name='archery', tb_name='text') + describe = new_engine.describe_table( + db_name="archery", schema_name="archery", tb_name="text" + ) self.assertIsInstance(describe, ResultSet) def test_query_check_disable_sql(self): sql = "update xxx set a=1 " new_engine = PgSQLEngine(instance=self.ins) - check_result = new_engine.query_check(db_name='archery', sql=sql) - self.assertDictEqual(check_result, - {'msg': '不支持的查询语法类型!', 'bad_query': True, 'filtered_sql': sql.strip(), 'has_star': False}) + check_result = new_engine.query_check(db_name="archery", sql=sql) + self.assertDictEqual( + check_result, + { + "msg": "不支持的查询语法类型!", + "bad_query": True, + "filtered_sql": sql.strip(), + "has_star": False, + }, + ) def test_query_check_star_sql(self): sql = "select * from xx " new_engine = PgSQLEngine(instance=self.ins) - check_result = new_engine.query_check(db_name='archery', sql=sql) - self.assertDictEqual(check_result, - {'msg': 'SQL语句中含有 * ', 'bad_query': False, 'filtered_sql': sql.strip(), 'has_star': True}) + check_result = new_engine.query_check(db_name="archery", sql=sql) + self.assertDictEqual( + check_result, + { + "msg": "SQL语句中含有 * ", + "bad_query": False, + "filtered_sql": sql.strip(), + "has_star": True, + }, + ) def test_filter_sql_with_delimiter(self): sql = "select * from xx;" @@ -831,72 +978,84 @@ def test_filter_sql_with_limit(self): def test_query_masking(self): query_result = ResultSet() new_engine = PgSQLEngine(instance=self.ins) - masking_result = new_engine.query_masking(db_name=0, sql='', resultset=query_result) + masking_result = new_engine.query_masking( + db_name=0, sql="", resultset=query_result + ) self.assertEqual(masking_result, query_result) def test_execute_check_select_sql(self): - sql = 'select * from user;' - row = ReviewResult(id=1, errlevel=2, - stagestatus='驳回不支持语句', - errormessage='仅支持DML和DDL语句,查询语句请使用SQL查询功能!', - sql=sql) + sql = "select * from user;" + row = ReviewResult( + id=1, + errlevel=2, + stagestatus="驳回不支持语句", + errormessage="仅支持DML和DDL语句,查询语句请使用SQL查询功能!", + sql=sql, + ) new_engine = PgSQLEngine(instance=self.ins) - check_result = new_engine.execute_check(db_name='archery', sql=sql) + check_result = new_engine.execute_check(db_name="archery", sql=sql) self.assertIsInstance(check_result, ReviewSet) self.assertEqual(check_result.rows[0].__dict__, row.__dict__) def test_execute_check_critical_sql(self): - self.sys_config.set('critical_ddl_regex', '^|update') + self.sys_config.set("critical_ddl_regex", "^|update") self.sys_config.get_all_config() - sql = 'update user set id=1' - row = ReviewResult(id=1, errlevel=2, - stagestatus='驳回高危SQL', - errormessage='禁止提交匹配' + '^|update' + '条件的语句!', - sql=sql) + sql = "update user set id=1" + row = ReviewResult( + id=1, + errlevel=2, + stagestatus="驳回高危SQL", + errormessage="禁止提交匹配" + "^|update" + "条件的语句!", + sql=sql, + ) new_engine = PgSQLEngine(instance=self.ins) - check_result = new_engine.execute_check(db_name='archery', sql=sql) + check_result = new_engine.execute_check(db_name="archery", sql=sql) self.assertIsInstance(check_result, ReviewSet) self.assertEqual(check_result.rows[0].__dict__, row.__dict__) def test_execute_check_normal_sql(self): self.sys_config.purge() - sql = 'alter table tb set id=1' - row = ReviewResult(id=1, - errlevel=0, - stagestatus='Audit completed', - errormessage='None', - sql=sql, - affected_rows=0, - execute_time=0, ) + sql = "alter table tb set id=1" + row = ReviewResult( + id=1, + errlevel=0, + stagestatus="Audit completed", + errormessage="None", + sql=sql, + affected_rows=0, + execute_time=0, + ) new_engine = PgSQLEngine(instance=self.ins) - check_result = new_engine.execute_check(db_name='archery', sql=sql) + check_result = new_engine.execute_check(db_name="archery", sql=sql) self.assertIsInstance(check_result, ReviewSet) self.assertEqual(check_result.rows[0].__dict__, row.__dict__) - @patch('psycopg2.connect.cursor.execute') - @patch('psycopg2.connect.cursor') - @patch('psycopg2.connect') + @patch("psycopg2.connect.cursor.execute") + @patch("psycopg2.connect.cursor") + @patch("psycopg2.connect") def test_execute_workflow_success(self, _conn, _cursor, _execute): - sql = 'update user set id=1' - row = ReviewResult(id=1, - errlevel=0, - stagestatus='Execute Successfully', - errormessage='None', - sql=sql, - affected_rows=0, - execute_time=0) + sql = "update user set id=1" + row = ReviewResult( + id=1, + errlevel=0, + stagestatus="Execute Successfully", + errormessage="None", + sql=sql, + affected_rows=0, + execute_time=0, + ) wf = SqlWorkflow.objects.create( - workflow_name='some_name', + workflow_name="some_name", group_id=1, - group_name='g1', - engineer_display='', - audit_auth_groups='some_group', + group_name="g1", + engineer_display="", + audit_auth_groups="some_group", create_time=datetime.now() - timedelta(days=1), - status='workflow_finish', + status="workflow_finish", is_backup=True, instance=self.ins, - db_name='some_db', - syntax_type=1 + db_name="some_db", + syntax_type=1, ) SqlWorkflowContent.objects.create(workflow=wf, sql_content=sql) new_engine = PgSQLEngine(instance=self.ins) @@ -904,41 +1063,44 @@ def test_execute_workflow_success(self, _conn, _cursor, _execute): self.assertIsInstance(execute_result, ReviewSet) self.assertEqual(execute_result.rows[0].__dict__.keys(), row.__dict__.keys()) - @patch('psycopg2.connect.cursor.execute') - @patch('psycopg2.connect.cursor') - @patch('psycopg2.connect', return_value=RuntimeError) + @patch("psycopg2.connect.cursor.execute") + @patch("psycopg2.connect.cursor") + @patch("psycopg2.connect", return_value=RuntimeError) def test_execute_workflow_exception(self, _conn, _cursor, _execute): - sql = 'update user set id=1' - row = ReviewResult(id=1, - errlevel=2, - stagestatus='Execute Failed', - errormessage=f'异常信息:{f"Oracle命令执行报错,语句:{sql}"}', - sql=sql, - affected_rows=0, - execute_time=0, ) + sql = "update user set id=1" + row = ReviewResult( + id=1, + errlevel=2, + stagestatus="Execute Failed", + errormessage=f'异常信息:{f"Oracle命令执行报错,语句:{sql}"}', + sql=sql, + affected_rows=0, + execute_time=0, + ) wf = SqlWorkflow.objects.create( - workflow_name='some_name', + workflow_name="some_name", group_id=1, - group_name='g1', - engineer_display='', - audit_auth_groups='some_group', + group_name="g1", + engineer_display="", + audit_auth_groups="some_group", create_time=datetime.now() - timedelta(days=1), - status='workflow_finish', + status="workflow_finish", is_backup=True, instance=self.ins, - db_name='some_db', - syntax_type=1 + db_name="some_db", + syntax_type=1, ) SqlWorkflowContent.objects.create(workflow=wf, sql_content=sql) with self.assertRaises(AttributeError): new_engine = PgSQLEngine(instance=self.ins) execute_result = new_engine.execute_workflow(workflow=wf) self.assertIsInstance(execute_result, ReviewSet) - self.assertEqual(execute_result.rows[0].__dict__.keys(), row.__dict__.keys()) + self.assertEqual( + execute_result.rows[0].__dict__.keys(), row.__dict__.keys() + ) class TestModel(TestCase): - def setUp(self): pass @@ -963,23 +1125,34 @@ def test_result_set_rows_shadow(self): class TestGoInception(TestCase): def setUp(self): - self.ins = Instance.objects.create(instance_name='some_ins', type='slave', db_type='mysql', - host='some_host', - port=3306, user='ins_user', password='some_str') - self.ins_inc = Instance.objects.create(instance_name='some_ins_inc', type='slave', db_type='goinception', - host='some_host', port=4000) + self.ins = Instance.objects.create( + instance_name="some_ins", + type="slave", + db_type="mysql", + host="some_host", + port=3306, + user="ins_user", + password="some_str", + ) + self.ins_inc = Instance.objects.create( + instance_name="some_ins_inc", + type="slave", + db_type="goinception", + host="some_host", + port=4000, + ) self.wf = SqlWorkflow.objects.create( - workflow_name='some_name', + workflow_name="some_name", group_id=1, - group_name='g1', - engineer_display='', - audit_auth_groups='some_group', + group_name="g1", + engineer_display="", + audit_auth_groups="some_group", create_time=datetime.now() - timedelta(days=1), - status='workflow_finish', + status="workflow_finish", is_backup=True, instance=self.ins, - db_name='some_db', - syntax_type=1 + db_name="some_db", + syntax_type=1, ) SqlWorkflowContent.objects.create(workflow=self.wf) @@ -989,119 +1162,186 @@ def tearDown(self): SqlWorkflow.objects.all().delete() SqlWorkflowContent.objects.all().delete() - @patch('MySQLdb.connect') + @patch("MySQLdb.connect") def test_get_connection(self, _connect): new_engine = GoInceptionEngine() new_engine.get_connection() _connect.assert_called_once() - @patch('sql.engines.goinception.GoInceptionEngine.query') + @patch("sql.engines.goinception.GoInceptionEngine.query") def test_execute_check_normal_sql(self, _query): - sql = 'update user set id=100' - row = [1, 'CHECKED', 0, 'Audit completed', 'None', 'use archery', 0, "'0_0_0'", 'None', '0', '', ''] + sql = "update user set id=100" + row = [ + 1, + "CHECKED", + 0, + "Audit completed", + "None", + "use archery", + 0, + "'0_0_0'", + "None", + "0", + "", + "", + ] _query.return_value = ResultSet(full_sql=sql, rows=[row]) new_engine = GoInceptionEngine() check_result = new_engine.execute_check(instance=self.ins, db_name=0, sql=sql) self.assertIsInstance(check_result, ReviewSet) - @patch('sql.engines.goinception.GoInceptionEngine.query') + @patch("sql.engines.goinception.GoInceptionEngine.query") def test_execute_exception(self, _query): - sql = 'update user set id=100' - row = [1, 'CHECKED', 1, 'Execute failed', 'None', 'use archery', 0, "'0_0_0'", 'None', '0', '', ''] - column_list = ['order_id', 'stage', 'error_level', 'stage_status', 'error_message', 'sql', - 'affected_rows', 'sequence', 'backup_dbname', 'execute_time', 'sqlsha1', 'backup_time'] - _query.return_value = ResultSet(full_sql=sql, rows=[row], column_list=column_list) + sql = "update user set id=100" + row = [ + 1, + "CHECKED", + 1, + "Execute failed", + "None", + "use archery", + 0, + "'0_0_0'", + "None", + "0", + "", + "", + ] + column_list = [ + "order_id", + "stage", + "error_level", + "stage_status", + "error_message", + "sql", + "affected_rows", + "sequence", + "backup_dbname", + "execute_time", + "sqlsha1", + "backup_time", + ] + _query.return_value = ResultSet( + full_sql=sql, rows=[row], column_list=column_list + ) new_engine = GoInceptionEngine() execute_result = new_engine.execute(workflow=self.wf) self.assertIsInstance(execute_result, ReviewSet) - @patch('sql.engines.goinception.GoInceptionEngine.query') + @patch("sql.engines.goinception.GoInceptionEngine.query") def test_execute_finish(self, _query): - sql = 'update user set id=100' - row = [1, 'CHECKED', 0, 'Execute Successfully', 'None', 'use archery', 0, "'0_0_0'", 'None', '0', '', ''] - column_list = ['order_id', 'stage', 'error_level', 'stage_status', 'error_message', 'sql', - 'affected_rows', 'sequence', 'backup_dbname', 'execute_time', 'sqlsha1', 'backup_time'] - _query.return_value = ResultSet(full_sql=sql, rows=[row], column_list=column_list) + sql = "update user set id=100" + row = [ + 1, + "CHECKED", + 0, + "Execute Successfully", + "None", + "use archery", + 0, + "'0_0_0'", + "None", + "0", + "", + "", + ] + column_list = [ + "order_id", + "stage", + "error_level", + "stage_status", + "error_message", + "sql", + "affected_rows", + "sequence", + "backup_dbname", + "execute_time", + "sqlsha1", + "backup_time", + ] + _query.return_value = ResultSet( + full_sql=sql, rows=[row], column_list=column_list + ) new_engine = GoInceptionEngine() execute_result = new_engine.execute(workflow=self.wf) self.assertIsInstance(execute_result, ReviewSet) - @patch('MySQLdb.connect.cursor.execute') - @patch('MySQLdb.connect.cursor') - @patch('MySQLdb.connect') + @patch("MySQLdb.connect.cursor.execute") + @patch("MySQLdb.connect.cursor") + @patch("MySQLdb.connect") def test_query(self, _conn, _cursor, _execute): _conn.return_value.cursor.return_value.fetchall.return_value = [(1,)] new_engine = GoInceptionEngine() - query_result = new_engine.query(db_name=0, sql='select 1', limit_num=100) + query_result = new_engine.query(db_name=0, sql="select 1", limit_num=100) self.assertIsInstance(query_result, ResultSet) - @patch('MySQLdb.connect.cursor.execute') - @patch('MySQLdb.connect.cursor') - @patch('MySQLdb.connect') + @patch("MySQLdb.connect.cursor.execute") + @patch("MySQLdb.connect.cursor") + @patch("MySQLdb.connect") def test_query_not_limit(self, _conn, _cursor, _execute): _conn.return_value.cursor.return_value.fetchall.return_value = [(1,)] new_engine = GoInceptionEngine(instance=self.ins) - query_result = new_engine.query(db_name=0, sql='select 1', limit_num=0) + query_result = new_engine.query(db_name=0, sql="select 1", limit_num=0) self.assertIsInstance(query_result, ResultSet) - @patch('sql.engines.goinception.GoInceptionEngine.query') + @patch("sql.engines.goinception.GoInceptionEngine.query") def test_osc_get(self, _query): new_engine = GoInceptionEngine() - command = 'get' - sqlsha1 = 'xxxxx' + command = "get" + sqlsha1 = "xxxxx" sql = f"inception get osc_percent '{sqlsha1}';" _query.return_value = ResultSet(full_sql=sql, rows=[], column_list=[]) new_engine.osc_control(sqlsha1=sqlsha1, command=command) _query.assert_called_once_with(sql=sql) - @patch('sql.engines.goinception.GoInceptionEngine.query') + @patch("sql.engines.goinception.GoInceptionEngine.query") def test_osc_pause(self, _query): new_engine = GoInceptionEngine() - command = 'pause' - sqlsha1 = 'xxxxx' + command = "pause" + sqlsha1 = "xxxxx" sql = f"inception {command} osc '{sqlsha1}';" _query.return_value = ResultSet(full_sql=sql, rows=[], column_list=[]) new_engine.osc_control(sqlsha1=sqlsha1, command=command) _query.assert_called_once_with(sql=sql) - @patch('sql.engines.goinception.GoInceptionEngine.query') + @patch("sql.engines.goinception.GoInceptionEngine.query") def test_osc_resume(self, _query): new_engine = GoInceptionEngine() - command = 'resume' - sqlsha1 = 'xxxxx' + command = "resume" + sqlsha1 = "xxxxx" sql = f"inception {command} osc '{sqlsha1}';" _query.return_value = ResultSet(full_sql=sql, rows=[], column_list=[]) new_engine.osc_control(sqlsha1=sqlsha1, command=command) _query.assert_called_once_with(sql=sql) - @patch('sql.engines.goinception.GoInceptionEngine.query') + @patch("sql.engines.goinception.GoInceptionEngine.query") def test_osc_kill(self, _query): new_engine = GoInceptionEngine() - command = 'kill' - sqlsha1 = 'xxxxx' + command = "kill" + sqlsha1 = "xxxxx" sql = f"inception kill osc '{sqlsha1}';" _query.return_value = ResultSet(full_sql=sql, rows=[], column_list=[]) new_engine.osc_control(sqlsha1=sqlsha1, command=command) _query.assert_called_once_with(sql=sql) - @patch('sql.engines.goinception.GoInceptionEngine.query') + @patch("sql.engines.goinception.GoInceptionEngine.query") def test_get_variables(self, _query): new_engine = GoInceptionEngine(instance=self.ins_inc) new_engine.get_variables() sql = f"inception get variables;" _query.assert_called_once_with(sql=sql) - @patch('sql.engines.goinception.GoInceptionEngine.query') + @patch("sql.engines.goinception.GoInceptionEngine.query") def test_get_variables_filter(self, _query): new_engine = GoInceptionEngine(instance=self.ins_inc) - new_engine.get_variables(variables=['inception_osc_on']) + new_engine.get_variables(variables=["inception_osc_on"]) sql = f"inception get variables like 'inception_osc_on';" _query.assert_called_once_with(sql=sql) - @patch('sql.engines.goinception.GoInceptionEngine.query') + @patch("sql.engines.goinception.GoInceptionEngine.query") def test_set_variable(self, _query): new_engine = GoInceptionEngine(instance=self.ins) - new_engine.set_variable('inception_osc_on', 'on') + new_engine.set_variable("inception_osc_on", "on") _query.assert_called_once_with(sql="inception set inception_osc_on=on;") @@ -1109,21 +1349,28 @@ class TestOracle(TestCase): """Oracle 测试""" def setUp(self): - self.ins = Instance.objects.create(instance_name='some_ins', type='slave', db_type='oracle', - host='some_host', port=3306, user='ins_user', password='some_str', - sid='some_id') + self.ins = Instance.objects.create( + instance_name="some_ins", + type="slave", + db_type="oracle", + host="some_host", + port=3306, + user="ins_user", + password="some_str", + sid="some_id", + ) self.wf = SqlWorkflow.objects.create( - workflow_name='some_name', + workflow_name="some_name", group_id=1, - group_name='g1', - engineer_display='', - audit_auth_groups='some_group', + group_name="g1", + engineer_display="", + audit_auth_groups="some_group", create_time=datetime.now() - timedelta(days=1), - status='workflow_finish', + status="workflow_finish", is_backup=True, instance=self.ins, - db_name='some_db', - syntax_type=1 + db_name="some_db", + syntax_type=1, ) SqlWorkflowContent.objects.create(workflow=self.wf) self.sys_config = SysConfig() @@ -1134,8 +1381,8 @@ def tearDown(self): SqlWorkflow.objects.all().delete() SqlWorkflowContent.objects.all().delete() - @patch('cx_Oracle.makedsn') - @patch('cx_Oracle.connect') + @patch("cx_Oracle.makedsn") + @patch("cx_Oracle.connect") def test_get_connection(self, _connect, _makedsn): # 填写 sid 测试 new_engine = OracleEngine(self.ins) @@ -1145,8 +1392,8 @@ def test_get_connection(self, _connect, _makedsn): # 填写 service_name 测试 _connect.reset_mock() _makedsn.reset_mock() - self.ins.service_name = 'some_service' - self.ins.sid = '' + self.ins.service_name = "some_service" + self.ins.sid = "" self.ins.save() new_engine = OracleEngine(self.ins) new_engine.get_connection() @@ -1155,364 +1402,475 @@ def test_get_connection(self, _connect, _makedsn): # 都不填写, 检测 ValueError _connect.reset_mock() _makedsn.reset_mock() - self.ins.service_name = '' - self.ins.sid = '' + self.ins.service_name = "" + self.ins.sid = "" self.ins.save() new_engine = OracleEngine(self.ins) with self.assertRaises(ValueError): new_engine.get_connection() - @patch('cx_Oracle.connect') + @patch("cx_Oracle.connect") def test_engine_base_info(self, _conn): new_engine = OracleEngine(instance=self.ins) - self.assertEqual(new_engine.name, 'Oracle') - self.assertEqual(new_engine.info, 'Oracle engine') - _conn.return_value.version = '12.1.0.2.0' - self.assertTupleEqual(new_engine.server_version, ('12', '1', '0')) - - @patch('cx_Oracle.connect.cursor.execute') - @patch('cx_Oracle.connect.cursor') - @patch('cx_Oracle.connect') + self.assertEqual(new_engine.name, "Oracle") + self.assertEqual(new_engine.info, "Oracle engine") + _conn.return_value.version = "12.1.0.2.0" + self.assertTupleEqual(new_engine.server_version, ("12", "1", "0")) + + @patch("cx_Oracle.connect.cursor.execute") + @patch("cx_Oracle.connect.cursor") + @patch("cx_Oracle.connect") def test_query(self, _conn, _cursor, _execute): _conn.return_value.cursor.return_value.fetchmany.return_value = [(1,)] new_engine = OracleEngine(instance=self.ins) - query_result = new_engine.query(db_name='archery', sql='select 1', limit_num=100) + query_result = new_engine.query( + db_name="archery", sql="select 1", limit_num=100 + ) self.assertIsInstance(query_result, ResultSet) self.assertListEqual(query_result.rows, [(1,)]) - @patch('cx_Oracle.connect.cursor.execute') - @patch('cx_Oracle.connect.cursor') - @patch('cx_Oracle.connect') + @patch("cx_Oracle.connect.cursor.execute") + @patch("cx_Oracle.connect.cursor") + @patch("cx_Oracle.connect") def test_query_not_limit(self, _conn, _cursor, _execute): _conn.return_value.cursor.return_value.fetchall.return_value = [(1,)] new_engine = OracleEngine(instance=self.ins) - query_result = new_engine.query(db_name=0, sql='select 1', limit_num=0) + query_result = new_engine.query(db_name=0, sql="select 1", limit_num=0) self.assertIsInstance(query_result, ResultSet) self.assertListEqual(query_result.rows, [(1,)]) - @patch('sql.engines.oracle.OracleEngine.query', - return_value=ResultSet(rows=[('AUD_SYS',), ('archery',), ('ANONYMOUS',)])) + @patch( + "sql.engines.oracle.OracleEngine.query", + return_value=ResultSet(rows=[("AUD_SYS",), ("archery",), ("ANONYMOUS",)]), + ) def test_get_all_databases(self, _query): new_engine = OracleEngine(instance=self.ins) dbs = new_engine.get_all_databases() - self.assertListEqual(dbs.rows, ['archery']) + self.assertListEqual(dbs.rows, ["archery"]) - @patch('sql.engines.oracle.OracleEngine.query', - return_value=ResultSet(rows=[('AUD_SYS',), ('archery',), ('ANONYMOUS',)])) + @patch( + "sql.engines.oracle.OracleEngine.query", + return_value=ResultSet(rows=[("AUD_SYS",), ("archery",), ("ANONYMOUS",)]), + ) def test__get_all_databases(self, _query): new_engine = OracleEngine(instance=self.ins) dbs = new_engine._get_all_databases() - self.assertListEqual(dbs.rows, ['AUD_SYS', 'archery', 'ANONYMOUS']) + self.assertListEqual(dbs.rows, ["AUD_SYS", "archery", "ANONYMOUS"]) - @patch('sql.engines.oracle.OracleEngine.query', - return_value=ResultSet(rows=[('archery',)])) + @patch( + "sql.engines.oracle.OracleEngine.query", + return_value=ResultSet(rows=[("archery",)]), + ) def test__get_all_instances(self, _query): new_engine = OracleEngine(instance=self.ins) dbs = new_engine._get_all_instances() - self.assertListEqual(dbs.rows, ['archery']) + self.assertListEqual(dbs.rows, ["archery"]) - @patch('sql.engines.oracle.OracleEngine.query', - return_value=ResultSet(rows=[('ANONYMOUS',), ('archery',), ('SYSTEM',)])) + @patch( + "sql.engines.oracle.OracleEngine.query", + return_value=ResultSet(rows=[("ANONYMOUS",), ("archery",), ("SYSTEM",)]), + ) def test_get_all_schemas(self, _query): new_engine = OracleEngine(instance=self.ins) schemas = new_engine._get_all_schemas() - self.assertListEqual(schemas.rows, ['archery']) + self.assertListEqual(schemas.rows, ["archery"]) - @patch('sql.engines.oracle.OracleEngine.query', return_value=ResultSet(rows=[('test',), ('test2',)])) + @patch( + "sql.engines.oracle.OracleEngine.query", + return_value=ResultSet(rows=[("test",), ("test2",)]), + ) def test_get_all_tables(self, _query): new_engine = OracleEngine(instance=self.ins) - tables = new_engine.get_all_tables(db_name='archery') - self.assertListEqual(tables.rows, ['test2']) + tables = new_engine.get_all_tables(db_name="archery") + self.assertListEqual(tables.rows, ["test2"]) - @patch('sql.engines.oracle.OracleEngine.query', - return_value=ResultSet(rows=[('id',), ('name',)])) + @patch( + "sql.engines.oracle.OracleEngine.query", + return_value=ResultSet(rows=[("id",), ("name",)]), + ) def test_get_all_columns_by_tb(self, _query): new_engine = OracleEngine(instance=self.ins) - columns = new_engine.get_all_columns_by_tb(db_name='archery', tb_name='test2') - self.assertListEqual(columns.rows, ['id', 'name']) + columns = new_engine.get_all_columns_by_tb(db_name="archery", tb_name="test2") + self.assertListEqual(columns.rows, ["id", "name"]) - @patch('sql.engines.oracle.OracleEngine.query', - return_value=ResultSet(rows=[('archery',), ('template1',), ('template0',)])) + @patch( + "sql.engines.oracle.OracleEngine.query", + return_value=ResultSet(rows=[("archery",), ("template1",), ("template0",)]), + ) def test_describe_table(self, _query): new_engine = OracleEngine(instance=self.ins) - describe = new_engine.describe_table(db_name='archery', tb_name='text') + describe = new_engine.describe_table(db_name="archery", tb_name="text") self.assertIsInstance(describe, ResultSet) def test_query_check_disable_sql(self): sql = "update xxx set a=1;" new_engine = OracleEngine(instance=self.ins) - check_result = new_engine.query_check(db_name='archery', sql=sql) - self.assertDictEqual(check_result, - {'msg': '不支持语法!', 'bad_query': True, 'filtered_sql': sql.strip(';'), - 'has_star': False}) + check_result = new_engine.query_check(db_name="archery", sql=sql) + self.assertDictEqual( + check_result, + { + "msg": "不支持语法!", + "bad_query": True, + "filtered_sql": sql.strip(";"), + "has_star": False, + }, + ) - @patch('sql.engines.oracle.OracleEngine.explain_check', return_value={'msg': '', 'rows': 0}) + @patch( + "sql.engines.oracle.OracleEngine.explain_check", + return_value={"msg": "", "rows": 0}, + ) def test_query_check_star_sql(self, _explain_check): sql = "select * from xx;" new_engine = OracleEngine(instance=self.ins) - check_result = new_engine.query_check(db_name='archery', sql=sql) - self.assertDictEqual(check_result, - {'msg': '禁止使用 * 关键词\n', 'bad_query': False, 'filtered_sql': sql.strip(';'), - 'has_star': True}) + check_result = new_engine.query_check(db_name="archery", sql=sql) + self.assertDictEqual( + check_result, + { + "msg": "禁止使用 * 关键词\n", + "bad_query": False, + "filtered_sql": sql.strip(";"), + "has_star": True, + }, + ) def test_query_check_IndexError(self): sql = "" new_engine = OracleEngine(instance=self.ins) - check_result = new_engine.query_check(db_name='archery', sql=sql) - self.assertDictEqual(check_result, - {'msg': '没有有效的SQL语句', 'bad_query': True, 'filtered_sql': sql.strip(), 'has_star': False}) + check_result = new_engine.query_check(db_name="archery", sql=sql) + self.assertDictEqual( + check_result, + { + "msg": "没有有效的SQL语句", + "bad_query": True, + "filtered_sql": sql.strip(), + "has_star": False, + }, + ) def test_filter_sql_with_delimiter(self): sql = "select * from xx;" new_engine = OracleEngine(instance=self.ins) check_result = new_engine.filter_sql(sql=sql, limit_num=100) - self.assertEqual(check_result, "select sql_audit.* from (select * from xx) sql_audit where rownum <= 100") + self.assertEqual( + check_result, + "select sql_audit.* from (select * from xx) sql_audit where rownum <= 100", + ) def test_filter_sql_with_delimiter_and_where(self): sql = "select * from xx where id>1;" new_engine = OracleEngine(instance=self.ins) check_result = new_engine.filter_sql(sql=sql, limit_num=100) - self.assertEqual(check_result, - "select sql_audit.* from (select * from xx where id>1) sql_audit where rownum <= 100") + self.assertEqual( + check_result, + "select sql_audit.* from (select * from xx where id>1) sql_audit where rownum <= 100", + ) def test_filter_sql_without_delimiter(self): sql = "select * from xx;" new_engine = OracleEngine(instance=self.ins) check_result = new_engine.filter_sql(sql=sql, limit_num=100) - self.assertEqual(check_result, "select sql_audit.* from (select * from xx) sql_audit where rownum <= 100") + self.assertEqual( + check_result, + "select sql_audit.* from (select * from xx) sql_audit where rownum <= 100", + ) def test_filter_sql_with_limit(self): sql = "select * from xx limit 10;" new_engine = OracleEngine(instance=self.ins) check_result = new_engine.filter_sql(sql=sql, limit_num=1) - self.assertEqual(check_result, - "select sql_audit.* from (select * from xx limit 10) sql_audit where rownum <= 1") + self.assertEqual( + check_result, + "select sql_audit.* from (select * from xx limit 10) sql_audit where rownum <= 1", + ) def test_query_masking(self): query_result = ResultSet() new_engine = OracleEngine(instance=self.ins) - masking_result = new_engine.query_masking(sql='select 1 from dual', resultset=query_result) + masking_result = new_engine.query_masking( + sql="select 1 from dual", resultset=query_result + ) self.assertEqual(masking_result, query_result) def test_execute_check_select_sql(self): - sql = 'select * from user;' - row = ReviewResult(id=1, errlevel=2, - stagestatus='驳回不支持语句', - errormessage='仅支持DML和DDL语句,查询语句请使用SQL查询功能!', - sql=sqlparse.format(sql, strip_comments=True, reindent=True, keyword_case='lower')) + sql = "select * from user;" + row = ReviewResult( + id=1, + errlevel=2, + stagestatus="驳回不支持语句", + errormessage="仅支持DML和DDL语句,查询语句请使用SQL查询功能!", + sql=sqlparse.format( + sql, strip_comments=True, reindent=True, keyword_case="lower" + ), + ) new_engine = OracleEngine(instance=self.ins) - check_result = new_engine.execute_check(db_name='archery', sql=sql) + check_result = new_engine.execute_check(db_name="archery", sql=sql) self.assertIsInstance(check_result, ReviewSet) self.assertEqual(check_result.rows[0].__dict__, row.__dict__) def test_execute_check_critical_sql(self): - self.sys_config.set('critical_ddl_regex', '^|update') + self.sys_config.set("critical_ddl_regex", "^|update") self.sys_config.get_all_config() - sql = 'update user set id=1' - row = ReviewResult(id=1, errlevel=2, - stagestatus='驳回高危SQL', - errormessage='禁止提交匹配' + '^|update' + '条件的语句!', - sql=sqlparse.format(sql, strip_comments=True, reindent=True, keyword_case='lower')) + sql = "update user set id=1" + row = ReviewResult( + id=1, + errlevel=2, + stagestatus="驳回高危SQL", + errormessage="禁止提交匹配" + "^|update" + "条件的语句!", + sql=sqlparse.format( + sql, strip_comments=True, reindent=True, keyword_case="lower" + ), + ) new_engine = OracleEngine(instance=self.ins) - check_result = new_engine.execute_check(db_name='archery', sql=sql) + check_result = new_engine.execute_check(db_name="archery", sql=sql) self.assertIsInstance(check_result, ReviewSet) self.assertEqual(check_result.rows[0].__dict__, row.__dict__) - @patch('sql.engines.oracle.OracleEngine.explain_check', return_value={'msg': '', 'rows': 0}) - @patch('sql.engines.oracle.OracleEngine.get_sql_first_object_name', return_value='tb') - @patch('sql.engines.oracle.OracleEngine.object_name_check', return_value=True) - def test_execute_check_normal_sql(self, _explain_check, _get_sql_first_object_name, _object_name_check): + @patch( + "sql.engines.oracle.OracleEngine.explain_check", + return_value={"msg": "", "rows": 0}, + ) + @patch( + "sql.engines.oracle.OracleEngine.get_sql_first_object_name", return_value="tb" + ) + @patch("sql.engines.oracle.OracleEngine.object_name_check", return_value=True) + def test_execute_check_normal_sql( + self, _explain_check, _get_sql_first_object_name, _object_name_check + ): self.sys_config.purge() - sql = 'alter table tb set id=1' - row = ReviewResult(id=1, - errlevel=1, - stagestatus='当前平台,此语法不支持审核!', - errormessage='当前平台,此语法不支持审核!', - sql=sqlparse.format(sql, strip_comments=True, reindent=True, keyword_case='lower'), - affected_rows=0, - execute_time=0, - stmt_type='SQL', - object_owner='', - object_type='', - object_name='', - ) + sql = "alter table tb set id=1" + row = ReviewResult( + id=1, + errlevel=1, + stagestatus="当前平台,此语法不支持审核!", + errormessage="当前平台,此语法不支持审核!", + sql=sqlparse.format( + sql, strip_comments=True, reindent=True, keyword_case="lower" + ), + affected_rows=0, + execute_time=0, + stmt_type="SQL", + object_owner="", + object_type="", + object_name="", + ) new_engine = OracleEngine(instance=self.ins) - check_result = new_engine.execute_check(db_name='archery', sql=sql) + check_result = new_engine.execute_check(db_name="archery", sql=sql) self.assertIsInstance(check_result, ReviewSet) self.assertEqual(check_result.rows[0].__dict__, row.__dict__) - @patch('cx_Oracle.connect.cursor.execute') - @patch('cx_Oracle.connect.cursor') - @patch('cx_Oracle.connect') + @patch("cx_Oracle.connect.cursor.execute") + @patch("cx_Oracle.connect.cursor") + @patch("cx_Oracle.connect") def test_execute_workflow_success(self, _conn, _cursor, _execute): - sql = 'update user set id=1' - review_row = ReviewResult(id=1, - errlevel=0, - stagestatus='Execute Successfully', - errormessage='None', - sql=sql, - affected_rows=0, - execute_time=0, - stmt_type='SQL', - object_owner='', - object_type='', - object_name='', ) - execute_row = ReviewResult(id=1, - errlevel=0, - stagestatus='Execute Successfully', - errormessage='None', - sql=sql, - affected_rows=0, - execute_time=0) + sql = "update user set id=1" + review_row = ReviewResult( + id=1, + errlevel=0, + stagestatus="Execute Successfully", + errormessage="None", + sql=sql, + affected_rows=0, + execute_time=0, + stmt_type="SQL", + object_owner="", + object_type="", + object_name="", + ) + execute_row = ReviewResult( + id=1, + errlevel=0, + stagestatus="Execute Successfully", + errormessage="None", + sql=sql, + affected_rows=0, + execute_time=0, + ) wf = SqlWorkflow.objects.create( - workflow_name='some_name', + workflow_name="some_name", group_id=1, - group_name='g1', - engineer_display='', - audit_auth_groups='some_group', + group_name="g1", + engineer_display="", + audit_auth_groups="some_group", create_time=datetime.now() - timedelta(days=1), - status='workflow_finish', + status="workflow_finish", is_backup=True, instance=self.ins, - db_name='some_db', - syntax_type=1 + db_name="some_db", + syntax_type=1, + ) + SqlWorkflowContent.objects.create( + workflow=wf, + sql_content=sql, + review_content=ReviewSet(rows=[review_row]).json(), ) - SqlWorkflowContent.objects.create(workflow=wf, sql_content=sql, - review_content=ReviewSet(rows=[review_row]).json()) new_engine = OracleEngine(instance=self.ins) execute_result = new_engine.execute_workflow(workflow=wf) self.assertIsInstance(execute_result, ReviewSet) - self.assertEqual(execute_result.rows[0].__dict__.keys(), execute_row.__dict__.keys()) + self.assertEqual( + execute_result.rows[0].__dict__.keys(), execute_row.__dict__.keys() + ) - @patch('cx_Oracle.connect.cursor.execute') - @patch('cx_Oracle.connect.cursor') - @patch('cx_Oracle.connect', return_value=RuntimeError) + @patch("cx_Oracle.connect.cursor.execute") + @patch("cx_Oracle.connect.cursor") + @patch("cx_Oracle.connect", return_value=RuntimeError) def test_execute_workflow_exception(self, _conn, _cursor, _execute): - sql = 'update user set id=1' - row = ReviewResult(id=1, - errlevel=2, - stagestatus='Execute Failed', - errormessage=f'异常信息:{f"Oracle命令执行报错,语句:{sql}"}', - sql=sql, - affected_rows=0, - execute_time=0, - stmt_type='SQL', - object_owner='', - object_type='', - object_name='', - ) + sql = "update user set id=1" + row = ReviewResult( + id=1, + errlevel=2, + stagestatus="Execute Failed", + errormessage=f'异常信息:{f"Oracle命令执行报错,语句:{sql}"}', + sql=sql, + affected_rows=0, + execute_time=0, + stmt_type="SQL", + object_owner="", + object_type="", + object_name="", + ) wf = SqlWorkflow.objects.create( - workflow_name='some_name', + workflow_name="some_name", group_id=1, - group_name='g1', - engineer_display='', - audit_auth_groups='some_group', + group_name="g1", + engineer_display="", + audit_auth_groups="some_group", create_time=datetime.now() - timedelta(days=1), - status='workflow_finish', + status="workflow_finish", is_backup=True, instance=self.ins, - db_name='some_db', - syntax_type=1 + db_name="some_db", + syntax_type=1, + ) + SqlWorkflowContent.objects.create( + workflow=wf, sql_content=sql, review_content=ReviewSet(rows=[row]).json() ) - SqlWorkflowContent.objects.create(workflow=wf, sql_content=sql, review_content=ReviewSet(rows=[row]).json()) with self.assertRaises(AttributeError): new_engine = OracleEngine(instance=self.ins) execute_result = new_engine.execute_workflow(workflow=wf) self.assertIsInstance(execute_result, ReviewSet) - self.assertEqual(execute_result.rows[0].__dict__.keys(), row.__dict__.keys()) + self.assertEqual( + execute_result.rows[0].__dict__.keys(), row.__dict__.keys() + ) class MongoTest(TestCase): def setUp(self) -> None: - self.ins = Instance.objects.create(instance_name='some_ins', type='slave', db_type='mongo', - host='some_host', port=3306, user='ins_user') + self.ins = Instance.objects.create( + instance_name="some_ins", + type="slave", + db_type="mongo", + host="some_host", + port=3306, + user="ins_user", + ) self.engine = MongoEngine(instance=self.ins) def tearDown(self) -> None: self.ins.delete() - @patch('sql.engines.mongo.pymongo') + @patch("sql.engines.mongo.pymongo") def test_get_connection(self, mock_pymongo): _ = self.engine.get_connection() mock_pymongo.MongoClient.assert_called_once() - @patch('sql.engines.mongo.MongoEngine.get_connection') + @patch("sql.engines.mongo.MongoEngine.get_connection") def test_query(self, mock_get_connection): # TODO 正常查询还没做 test_sql = """db.job.find().count()""" - self.assertIsInstance(self.engine.query('some_db', test_sql), ResultSet) + self.assertIsInstance(self.engine.query("some_db", test_sql), ResultSet) - @patch('sql.engines.mongo.MongoEngine.get_all_tables') + @patch("sql.engines.mongo.MongoEngine.get_all_tables") def test_query_check(self, mock_get_all_tables): test_sql = """db.job.find().count()""" - mock_get_all_tables.return_value.rows = ("job") - check_result = self.engine.query_check('some_db', sql=test_sql) + mock_get_all_tables.return_value.rows = "job" + check_result = self.engine.query_check("some_db", sql=test_sql) mock_get_all_tables.assert_called_once() - self.assertEqual(False, check_result.get('bad_query')) + self.assertEqual(False, check_result.get("bad_query")) - @patch('sql.engines.mongo.MongoEngine.get_connection') + @patch("sql.engines.mongo.MongoEngine.get_connection") def test_get_all_databases(self, mock_get_connection): db_list = self.engine.get_all_databases() self.assertIsInstance(db_list, ResultSet) # mock_get_connection.return_value.list_database_names.assert_called_once() - @patch('sql.engines.mongo.MongoEngine.get_connection') + @patch("sql.engines.mongo.MongoEngine.get_connection") def test_get_all_tables(self, mock_get_connection): mock_db = Mock() # 下面是查表示例返回结果 - mock_db.list_collection_names.return_value = ['u', 'v', 'w'] - mock_get_connection.return_value = {'some_db': mock_db} - table_list = self.engine.get_all_tables('some_db') + mock_db.list_collection_names.return_value = ["u", "v", "w"] + mock_get_connection.return_value = {"some_db": mock_db} + table_list = self.engine.get_all_tables("some_db") mock_db.list_collection_names.assert_called_once() - self.assertEqual(table_list.rows, ['u', 'v', 'w']) + self.assertEqual(table_list.rows, ["u", "v", "w"]) def test_filter_sql(self): sql = """explain db.job.find().count()""" check_result = self.engine.filter_sql(sql, 0) - self.assertEqual(check_result, 'db.job.find().count().explain()') + self.assertEqual(check_result, "db.job.find().count().explain()") - @patch('sql.engines.mongo.MongoEngine.exec_cmd') + @patch("sql.engines.mongo.MongoEngine.exec_cmd") def test_get_slave(self, mock_exec_cmd): mock_exec_cmd.return_value = "172.30.2.123:27017" flag = self.engine.get_slave() self.assertEqual(True, flag) - @patch('sql.engines.mongo.MongoEngine.get_all_columns_by_tb') + @patch("sql.engines.mongo.MongoEngine.get_all_columns_by_tb") def test_parse_tuple(self, mock_get_all_columns_by_tb): cols = ["_id", "title", "tags", "likes"] mock_get_all_columns_by_tb.return_value.rows = cols - cursor = [{'_id': {'$oid': '5f10162029684728e70045ab'}, 'title': 'MongoDB', 'tags': 'mongodb', 'likes': 100}] - rows, columns = self.engine.parse_tuple(cursor, 'some_db', 'job') - alldata = json.dumps(cursor[0], ensure_ascii=False, indent=2, separators=(",", ":")) - rerows = (alldata, "ObjectId('5f10162029684728e70045ab')", 'MongoDB', 'mongodb', '100') - self.assertEqual(columns, ['mongodballdata', '_id', 'title', 'tags', 'likes']) + cursor = [ + { + "_id": {"$oid": "5f10162029684728e70045ab"}, + "title": "MongoDB", + "tags": "mongodb", + "likes": 100, + } + ] + rows, columns = self.engine.parse_tuple(cursor, "some_db", "job") + alldata = json.dumps( + cursor[0], ensure_ascii=False, indent=2, separators=(",", ":") + ) + rerows = ( + alldata, + "ObjectId('5f10162029684728e70045ab')", + "MongoDB", + "mongodb", + "100", + ) + self.assertEqual(columns, ["mongodballdata", "_id", "title", "tags", "likes"]) self.assertEqual(rows[0], rerows) - @patch('sql.engines.mongo.MongoEngine.get_table_conut') - @patch('sql.engines.mongo.MongoEngine.get_all_tables') + @patch("sql.engines.mongo.MongoEngine.get_table_conut") + @patch("sql.engines.mongo.MongoEngine.get_all_tables") def test_execute_check(self, mock_get_all_tables, mock_get_table_conut): - sql = '''db.job.createIndex({"skuId":1},{background:true});''' - mock_get_all_tables.return_value.rows = ("job") + sql = """db.job.createIndex({"skuId":1},{background:true});""" + mock_get_all_tables.return_value.rows = "job" mock_get_table_conut.return_value = 1000 - row = ReviewResult(id=1, errlevel=0, - stagestatus='Audit completed', - errormessage='检测通过', - affected_rows=1000, - sql=sql, - execute_time=0) - check_result = self.engine.execute_check('some_db', sql) - self.assertEqual(check_result.rows[0].__dict__["errormessage"], row.__dict__["errormessage"]) - - @patch('sql.engines.mongo.MongoEngine.exec_cmd') - @patch('sql.engines.mongo.MongoEngine.get_master') + row = ReviewResult( + id=1, + errlevel=0, + stagestatus="Audit completed", + errormessage="检测通过", + affected_rows=1000, + sql=sql, + execute_time=0, + ) + check_result = self.engine.execute_check("some_db", sql) + self.assertEqual( + check_result.rows[0].__dict__["errormessage"], row.__dict__["errormessage"] + ) + + @patch("sql.engines.mongo.MongoEngine.exec_cmd") + @patch("sql.engines.mongo.MongoEngine.get_master") def test_execute(self, mock_get_master, mock_exec_cmd): - sql = '''db.job.find().createIndex({"skuId":1},{background:true})''' - mock_exec_cmd.return_value = '''{ + sql = """db.job.find().createIndex({"skuId":1},{background:true})""" + mock_exec_cmd.return_value = """{ "createdCollectionAutomatically" : false, "numIndexesBefore" : 2, "numIndexesAfter" : 3, "ok" : 1 - }''' + }""" check_result = self.engine.execute("some_db", sql) mock_get_master.assert_called_once() @@ -1520,74 +1878,92 @@ def test_execute(self, mock_get_master, mock_exec_cmd): def test_fill_query_columns(self): columns = ["_id", "title", "tags", "likes"] - cursor = [{"_id": {"$oid": "5f10162029684728e70045ab"}, "title": "MongoDB", "text": "archery", "likes": 100}, - {"_id": {"$oid": "7f10162029684728e70045ab"}, "author": "archery"}] + cursor = [ + { + "_id": {"$oid": "5f10162029684728e70045ab"}, + "title": "MongoDB", + "text": "archery", + "likes": 100, + }, + {"_id": {"$oid": "7f10162029684728e70045ab"}, "author": "archery"}, + ] cols = self.engine.fill_query_columns(cursor, columns=columns) self.assertEqual(cols, ["_id", "title", "tags", "likes", "text", "author"]) - @patch('sql.engines.mongo.MongoEngine.get_connection') + @patch("sql.engines.mongo.MongoEngine.get_connection") def test_current_op(self, mock_get_connection): - class Aggregate(): + class Aggregate: def __enter__(self): - yield {'client':'single_client'} - yield {'clientMetadata':{'mongos':{'client':'sharding_client'}}} - def __exit__(self,*arg, **kwargs): + yield {"client": "single_client"} + yield {"clientMetadata": {"mongos": {"client": "sharding_client"}}} + + def __exit__(self, *arg, **kwargs): pass + mock_conn = Mock() mock_conn.admin.aggregate.return_value = Aggregate() mock_get_connection.return_value = mock_conn - command_types = ['Full','All','Inner','Active'] + command_types = ["Full", "All", "Inner", "Active"] for command_type in command_types: result_set = self.engine.current_op(command_type) self.assertIsInstance(result_set, ResultSet) - @patch('sql.engines.mongo.MongoEngine.get_connection') - def test_get_kill_command(self, mock_get_connection): - class Aggregate(): + @patch("sql.engines.mongo.MongoEngine.get_connection") + def test_get_kill_command(self, mock_get_connection): + class Aggregate: def __enter__(self): - yield {'opid': 111} - yield {'opid': 'shard1: 111'} - def __exit__(self,*arg, **kwargs): + yield {"opid": 111} + yield {"opid": "shard1: 111"} + + def __exit__(self, *arg, **kwargs): pass + mock_conn = Mock() mock_conn.admin.aggregate.return_value = Aggregate() mock_get_connection.return_value = mock_conn - kill_command1 = self.engine.get_kill_command([111,222]) - kill_command2 = self.engine.get_kill_command(['shard1: 111','shard2: 222']) - self.assertEqual(kill_command1, 'db.killOp(111);') + kill_command1 = self.engine.get_kill_command([111, 222]) + kill_command2 = self.engine.get_kill_command(["shard1: 111", "shard2: 222"]) + self.assertEqual(kill_command1, "db.killOp(111);") self.assertEqual(kill_command2, 'db.killOp("shard1: 111");') - @patch('sql.engines.mongo.MongoEngine.get_connection') + @patch("sql.engines.mongo.MongoEngine.get_connection") def test_kill_op(self, mock_get_connection): def command(self, *arg, **kwargs): pass + mock_conn = Mock() mock_conn.admin.command.return_value = command mock_get_connection.return_value = mock_conn - self.engine.kill_op([111,222]) - self.engine.kill_op(['shards: 111','shards: 222']) + self.engine.kill_op([111, 222]) + self.engine.kill_op(["shards: 111", "shards: 222"]) mock_conn.admin.command.assert_called() class TestClickHouse(TestCase): - def setUp(self): - self.ins1 = Instance(instance_name='some_ins', type='slave', db_type='clickhouse', host='some_host', - port=9000, user='ins_user', password='some_str') + self.ins1 = Instance( + instance_name="some_ins", + type="slave", + db_type="clickhouse", + host="some_host", + port=9000, + user="ins_user", + password="some_str", + ) self.ins1.save() self.sys_config = SysConfig() self.wf = SqlWorkflow.objects.create( - workflow_name='some_name', + workflow_name="some_name", group_id=1, - group_name='g1', - engineer_display='', - audit_auth_groups='some_group', + group_name="g1", + engineer_display="", + audit_auth_groups="some_group", create_time=datetime.now() - timedelta(days=1), - status='workflow_finish', + status="workflow_finish", is_backup=False, instance=self.ins1, - db_name='some_db', - syntax_type=1 + db_name="some_db", + syntax_type=1, ) SqlWorkflowContent.objects.create(workflow=self.wf) @@ -1597,205 +1973,234 @@ def tearDown(self): SqlWorkflow.objects.all().delete() SqlWorkflowContent.objects.all().delete() - @patch.object(ClickHouseEngine, 'query') + @patch.object(ClickHouseEngine, "query") def test_server_version(self, mock_query): result = ResultSet() - result.rows = [('ClickHouse 22.1.3.7',)] + result.rows = [("ClickHouse 22.1.3.7",)] mock_query.return_value = result new_engine = ClickHouseEngine(instance=self.ins1) server_version = new_engine.server_version self.assertTupleEqual(server_version, (22, 1, 3)) - @patch.object(ClickHouseEngine, 'query') + @patch.object(ClickHouseEngine, "query") def test_table_engine(self, mock_query): - table_name = 'default.tb_test' + table_name = "default.tb_test" result = ResultSet() - result.rows = [('MergeTree',)] + result.rows = [("MergeTree",)] mock_query.return_value = result new_engine = ClickHouseEngine(instance=self.ins1) table_engine = new_engine.get_table_engine(table_name) - self.assertDictEqual(table_engine, {'status': 1, 'engine': 'MergeTree'}) + self.assertDictEqual(table_engine, {"status": 1, "engine": "MergeTree"}) - @patch('clickhouse_driver.connect') + @patch("clickhouse_driver.connect") def test_engine_base_info(self, _conn): new_engine = ClickHouseEngine(instance=self.ins1) - self.assertEqual(new_engine.name, 'ClickHouse') - self.assertEqual(new_engine.info, 'ClickHouse engine') + self.assertEqual(new_engine.name, "ClickHouse") + self.assertEqual(new_engine.info, "ClickHouse engine") - @patch.object(ClickHouseEngine, 'get_connection') + @patch.object(ClickHouseEngine, "get_connection") def testGetConnection(self, connect): new_engine = ClickHouseEngine(instance=self.ins1) new_engine.get_connection() connect.assert_called_once() - @patch.object(ClickHouseEngine, 'query') + @patch.object(ClickHouseEngine, "query") def testQuery(self, mock_query): result = ResultSet() - result.rows = [('v1', 'v2'), ] + result.rows = [ + ("v1", "v2"), + ] mock_query.return_value = result new_engine = ClickHouseEngine(instance=self.ins1) - query_result = new_engine.query(sql='some_sql', limit_num=100) - self.assertListEqual(query_result.rows, [('v1', 'v2'), ]) + query_result = new_engine.query(sql="some_sql", limit_num=100) + self.assertListEqual( + query_result.rows, + [ + ("v1", "v2"), + ], + ) - @patch.object(ClickHouseEngine, 'query') + @patch.object(ClickHouseEngine, "query") def testAllDb(self, mock_query): db_result = ResultSet() - db_result.rows = [('db_1',), ('db_2',)] + db_result.rows = [("db_1",), ("db_2",)] mock_query.return_value = db_result new_engine = ClickHouseEngine(instance=self.ins1) dbs = new_engine.get_all_databases() - self.assertEqual(dbs.rows, ['db_1', 'db_2']) + self.assertEqual(dbs.rows, ["db_1", "db_2"]) - @patch.object(ClickHouseEngine, 'query') + @patch.object(ClickHouseEngine, "query") def testAllTables(self, mock_query): table_result = ResultSet() - table_result.rows = [('tb_1', 'some_des'), ('tb_2', 'some_des')] + table_result.rows = [("tb_1", "some_des"), ("tb_2", "some_des")] mock_query.return_value = table_result new_engine = ClickHouseEngine(instance=self.ins1) - tables = new_engine.get_all_tables('some_db') - mock_query.assert_called_once_with(db_name='some_db', sql=ANY) - self.assertEqual(tables.rows, ['tb_1', 'tb_2']) + tables = new_engine.get_all_tables("some_db") + mock_query.assert_called_once_with(db_name="some_db", sql=ANY) + self.assertEqual(tables.rows, ["tb_1", "tb_2"]) - @patch.object(ClickHouseEngine, 'query') + @patch.object(ClickHouseEngine, "query") def testAllColumns(self, mock_query): db_result = ResultSet() - db_result.rows = [('col_1', 'type'), ('col_2', 'type2')] + db_result.rows = [("col_1", "type"), ("col_2", "type2")] mock_query.return_value = db_result new_engine = ClickHouseEngine(instance=self.ins1) - dbs = new_engine.get_all_columns_by_tb('some_db', 'some_tb') - self.assertEqual(dbs.rows, ['col_1', 'col_2']) + dbs = new_engine.get_all_columns_by_tb("some_db", "some_tb") + self.assertEqual(dbs.rows, ["col_1", "col_2"]) - @patch.object(ClickHouseEngine, 'query') + @patch.object(ClickHouseEngine, "query") def testDescribe(self, mock_query): new_engine = ClickHouseEngine(instance=self.ins1) - new_engine.describe_table('some_db', 'some_db') + new_engine.describe_table("some_db", "some_db") mock_query.assert_called_once() def test_query_check_wrong_sql(self): new_engine = ClickHouseEngine(instance=self.ins1) - wrong_sql = '-- 测试' - check_result = new_engine.query_check(db_name='some_db', sql=wrong_sql) - self.assertDictEqual(check_result, - {'msg': '不支持的查询语法类型!', 'bad_query': True, 'filtered_sql': '-- 测试', 'has_star': False}) + wrong_sql = "-- 测试" + check_result = new_engine.query_check(db_name="some_db", sql=wrong_sql) + self.assertDictEqual( + check_result, + { + "msg": "不支持的查询语法类型!", + "bad_query": True, + "filtered_sql": "-- 测试", + "has_star": False, + }, + ) def test_query_check_update_sql(self): new_engine = ClickHouseEngine(instance=self.ins1) - update_sql = 'update user set id=0' - check_result = new_engine.query_check(db_name='some_db', sql=update_sql) - self.assertDictEqual(check_result, - {'msg': '不支持的查询语法类型!', 'bad_query': True, 'filtered_sql': 'update user set id=0', - 'has_star': False}) + update_sql = "update user set id=0" + check_result = new_engine.query_check(db_name="some_db", sql=update_sql) + self.assertDictEqual( + check_result, + { + "msg": "不支持的查询语法类型!", + "bad_query": True, + "filtered_sql": "update user set id=0", + "has_star": False, + }, + ) - @patch.object(ClickHouseEngine, 'query') + @patch.object(ClickHouseEngine, "query") def test_explain_check(self, mock_query): result = ResultSet() - result.rows = [('ClickHouse 20.1.3.7',)] + result.rows = [("ClickHouse 20.1.3.7",)] mock_query.return_value = result new_engine = ClickHouseEngine(instance=self.ins1) server_version = new_engine.server_version sql = "insert into tb_test(note) values ('xbb');" check_result = ReviewSet(full_sql=sql) - explain_result = new_engine.explain_check(check_result, db_name='some_db', line=1, statement=sql) + explain_result = new_engine.explain_check( + check_result, db_name="some_db", line=1, statement=sql + ) self.assertEqual(explain_result.stagestatus, "Audit completed") def test_execute_check_select_sql(self): new_engine = ClickHouseEngine(instance=self.ins1) - select_sql = 'select id,name from tb_test' - check_result = new_engine.execute_check(db_name='some_db', sql=select_sql) - self.assertEqual(check_result.rows[0].errormessage, "仅支持DML和DDL语句,查询语句请使用SQL查询功能!") + select_sql = "select id,name from tb_test" + check_result = new_engine.execute_check(db_name="some_db", sql=select_sql) + self.assertEqual( + check_result.rows[0].errormessage, "仅支持DML和DDL语句,查询语句请使用SQL查询功能!" + ) - @patch.object(ClickHouseEngine, 'query') + @patch.object(ClickHouseEngine, "query") def test_execute_check_alter_sql(self, mock_query): - table_name = 'default.tb_test' + table_name = "default.tb_test" result = ResultSet() - result.rows = [('Log',)] + result.rows = [("Log",)] mock_query.return_value = result new_engine = ClickHouseEngine(instance=self.ins1) table_engine = new_engine.get_table_engine(table_name) alter_sql = "alter table default.tb_test add column remark String" - check_result = new_engine.execute_check(db_name='some_db', sql=alter_sql) - self.assertEqual(check_result.rows[0].errormessage, "ALTER TABLE仅支持*MergeTree,Merge以及Distributed等引擎表!") + check_result = new_engine.execute_check(db_name="some_db", sql=alter_sql) + self.assertEqual( + check_result.rows[0].errormessage, + "ALTER TABLE仅支持*MergeTree,Merge以及Distributed等引擎表!", + ) def test_filter_sql_with_delimiter(self): new_engine = ClickHouseEngine(instance=self.ins1) - sql_without_limit = 'select user from usertable;' + sql_without_limit = "select user from usertable;" check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=100) - self.assertEqual(check_result, 'select user from usertable limit 100;') + self.assertEqual(check_result, "select user from usertable limit 100;") def test_filter_sql_without_delimiter(self): new_engine = ClickHouseEngine(instance=self.ins1) - sql_without_limit = 'select user from usertable' + sql_without_limit = "select user from usertable" check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=100) - self.assertEqual(check_result, 'select user from usertable limit 100;') + self.assertEqual(check_result, "select user from usertable limit 100;") def test_filter_sql_with_limit(self): new_engine = ClickHouseEngine(instance=self.ins1) - sql_without_limit = 'select user from usertable limit 10' + sql_without_limit = "select user from usertable limit 10" check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=1) - self.assertEqual(check_result, 'select user from usertable limit 1;') + self.assertEqual(check_result, "select user from usertable limit 1;") def test_filter_sql_with_limit_min(self): new_engine = ClickHouseEngine(instance=self.ins1) - sql_without_limit = 'select user from usertable limit 10' + sql_without_limit = "select user from usertable limit 10" check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=100) - self.assertEqual(check_result, 'select user from usertable limit 10;') + self.assertEqual(check_result, "select user from usertable limit 10;") def test_filter_sql_with_limit_offset(self): new_engine = ClickHouseEngine(instance=self.ins1) - sql_without_limit = 'select user from usertable limit 10 offset 100' + sql_without_limit = "select user from usertable limit 10 offset 100" check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=1) - self.assertEqual(check_result, 'select user from usertable limit 1 offset 100;') + self.assertEqual(check_result, "select user from usertable limit 1 offset 100;") def test_filter_sql_with_limit_nn(self): new_engine = ClickHouseEngine(instance=self.ins1) - sql_without_limit = 'select user from usertable limit 10, 100' + sql_without_limit = "select user from usertable limit 10, 100" check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=1) - self.assertEqual(check_result, 'select user from usertable limit 10,1;') + self.assertEqual(check_result, "select user from usertable limit 10,1;") def test_filter_sql_upper(self): new_engine = ClickHouseEngine(instance=self.ins1) - sql_without_limit = 'SELECT USER FROM usertable LIMIT 10, 100' + sql_without_limit = "SELECT USER FROM usertable LIMIT 10, 100" check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=1) - self.assertEqual(check_result, 'SELECT USER FROM usertable limit 10,1;') + self.assertEqual(check_result, "SELECT USER FROM usertable limit 10,1;") def test_filter_sql_not_select(self): new_engine = ClickHouseEngine(instance=self.ins1) - sql_without_limit = 'show create table usertable;' + sql_without_limit = "show create table usertable;" check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=1) - self.assertEqual(check_result, 'show create table usertable;') + self.assertEqual(check_result, "show create table usertable;") - @patch('clickhouse_driver.connect.cursor.execute') - @patch('clickhouse_driver.connect.cursor') - @patch('clickhouse_driver.connect') + @patch("clickhouse_driver.connect.cursor.execute") + @patch("clickhouse_driver.connect.cursor") + @patch("clickhouse_driver.connect") def test_execute(self, _connect, _cursor, _execute): new_engine = ClickHouseEngine(instance=self.ins1) execute_result = new_engine.execute(self.wf) self.assertIsInstance(execute_result, ResultSet) - @patch('clickhouse_driver.connect.cursor.execute') - @patch('clickhouse_driver.connect.cursor') - @patch('clickhouse_driver.connect') + @patch("clickhouse_driver.connect.cursor.execute") + @patch("clickhouse_driver.connect.cursor") + @patch("clickhouse_driver.connect") def test_execute_workflow_success(self, _conn, _cursor, _execute): sql = "insert into tb_test values('test')" - row = ReviewResult(id=1, - errlevel=0, - stagestatus='Execute Successfully', - errormessage='None', - sql=sql, - affected_rows=0, - execute_time=0) + row = ReviewResult( + id=1, + errlevel=0, + stagestatus="Execute Successfully", + errormessage="None", + sql=sql, + affected_rows=0, + execute_time=0, + ) wf = SqlWorkflow.objects.create( - workflow_name='some_name', + workflow_name="some_name", group_id=1, - group_name='g1', - engineer_display='', - audit_auth_groups='some_group', + group_name="g1", + engineer_display="", + audit_auth_groups="some_group", create_time=datetime.now() - timedelta(days=1), - status='workflow_finish', + status="workflow_finish", is_backup=False, instance=self.ins1, - db_name='some_db', - syntax_type=1 + db_name="some_db", + syntax_type=1, ) SqlWorkflowContent.objects.create(workflow=wf, sql_content=sql) new_engine = ClickHouseEngine(instance=self.ins1) @@ -1806,22 +2211,29 @@ def test_execute_workflow_success(self, _conn, _cursor, _execute): class ODPSTest(TestCase): def setUp(self) -> None: - self.ins = Instance.objects.create(instance_name='some_ins', type='slave', db_type='odps', - host='some_host', port=9200, user='ins_user', db_name='some_db') + self.ins = Instance.objects.create( + instance_name="some_ins", + type="slave", + db_type="odps", + host="some_host", + port=9200, + user="ins_user", + db_name="some_db", + ) self.engine = ODPSEngine(instance=self.ins) def tearDown(self) -> None: self.ins.delete() - @patch('sql.engines.odps.ODPSEngine.get_connection') + @patch("sql.engines.odps.ODPSEngine.get_connection") def test_get_connection(self, mock_odps): _ = self.engine.get_connection() mock_odps.assert_called_once() - @patch('sql.engines.odps.ODPSEngine.get_connection') + @patch("sql.engines.odps.ODPSEngine.get_connection") def test_query(self, mock_get_connection): test_sql = """select 123""" - self.assertIsInstance(self.engine.query('some_db', test_sql), ResultSet) + self.assertIsInstance(self.engine.query("some_db", test_sql), ResultSet) def test_query_check(self): test_sql = """select 123; -- this is comment @@ -1843,20 +2255,20 @@ def test_query_check_error(self): self.assertIsInstance(check_result, dict) self.assertEqual(True, check_result.get("bad_query")) - @patch('sql.engines.odps.ODPSEngine.get_connection') + @patch("sql.engines.odps.ODPSEngine.get_connection") def test_get_all_databases(self, mock_get_connection): mock_conn = Mock() mock_conn.exist_project.return_value = True - mock_conn.project = 'some_db' + mock_conn.project = "some_db" mock_get_connection.return_value = mock_conn result = self.engine.get_all_databases() self.assertIsInstance(result, ResultSet) - self.assertEqual(result.rows, ['some_db']) + self.assertEqual(result.rows, ["some_db"]) - @patch('sql.engines.odps.ODPSEngine.get_connection') + @patch("sql.engines.odps.ODPSEngine.get_connection") def test_get_all_tables(self, mock_get_connection): # 下面是查表示例返回结果 class T: @@ -1864,33 +2276,35 @@ def __init__(self, name): self.name = name mock_conn = Mock() - mock_conn.list_tables.return_value = [T('u'), T('v'), T('w')] + mock_conn.list_tables.return_value = [T("u"), T("v"), T("w")] mock_get_connection.return_value = mock_conn - table_list = self.engine.get_all_tables('some_db') + table_list = self.engine.get_all_tables("some_db") - self.assertEqual(table_list.rows, ['u', 'v', 'w']) + self.assertEqual(table_list.rows, ["u", "v", "w"]) - @patch('sql.engines.odps.ODPSEngine.get_all_columns_by_tb') + @patch("sql.engines.odps.ODPSEngine.get_all_columns_by_tb") def test_describe_table(self, mock_get_all_columns_by_tb): - self.engine.describe_table('some_db', 'some_table') + self.engine.describe_table("some_db", "some_table") mock_get_all_columns_by_tb.assert_called_once() - @patch('sql.engines.odps.ODPSEngine.get_connection') + @patch("sql.engines.odps.ODPSEngine.get_connection") def test_get_all_columns_by_tb(self, mock_get_connection): mock_conn = Mock() mock_cols = Mock() mock_col = Mock() - mock_col.name, mock_col.type, mock_col.comment = 'XiaoMing', 'string', 'name' + mock_col.name, mock_col.type, mock_col.comment = "XiaoMing", "string", "name" mock_cols.schema.columns = [mock_col] mock_conn.get_table.return_value = mock_cols mock_get_connection.return_value = mock_conn - result = self.engine.get_all_columns_by_tb('some_db', 'some_table') + result = self.engine.get_all_columns_by_tb("some_db", "some_table") mock_get_connection.assert_called_once() mock_conn.get_table.assert_called_once() - self.assertEqual(result.rows, [['XiaoMing', 'string', 'name']]) - self.assertEqual(result.column_list, ['COLUMN_NAME', 'COLUMN_TYPE', 'COLUMN_COMMENT']) + self.assertEqual(result.rows, [["XiaoMing", "string", "name"]]) + self.assertEqual( + result.column_list, ["COLUMN_NAME", "COLUMN_TYPE", "COLUMN_COMMENT"] + ) diff --git a/sql/form.py b/sql/form.py index 7459dd9f6f..3187b08833 100644 --- a/sql/form.py +++ b/sql/form.py @@ -18,16 +18,18 @@ class Meta: model = Tunnel fields = "__all__" widgets = { - 'PKey': Textarea(attrs={'cols': 40, 'rows': 8}), + "PKey": Textarea(attrs={"cols": 40, "rows": 8}), } def clean(self): cleaned_data = super().clean() - if cleaned_data.get('pkey_path'): + if cleaned_data.get("pkey_path"): try: - pkey_path = cleaned_data.get('pkey_path').read() + pkey_path = cleaned_data.get("pkey_path").read() if pkey_path: - cleaned_data['pkey'] = str(pkey_path, 'utf-8').replace(r'\r', '').replace(r'\n', '') + cleaned_data["pkey"] = ( + str(pkey_path, "utf-8").replace(r"\r", "").replace(r"\n", "") + ) except IOError: raise ValidationError("秘钥文件不存在, 请勾选秘钥路径的清除选项再进行保存") @@ -35,4 +37,7 @@ def clean(self): class InstanceForm(ModelForm): class Media: model = Instance - js = ('jquery/jquery.min.js', 'dist/js/utils.js', ) + js = ( + "jquery/jquery.min.js", + "dist/js/utils.js", + ) diff --git a/sql/instance.py b/sql/instance.py index 6b75676e0a..9e31fca3ec 100644 --- a/sql/instance.py +++ b/sql/instance.py @@ -19,30 +19,30 @@ from .models import Instance, ParamTemplate, ParamHistory -@permission_required('sql.menu_instance_list', raise_exception=True) +@permission_required("sql.menu_instance_list", raise_exception=True) def lists(request): """获取实例列表""" - limit = int(request.POST.get('limit')) - offset = int(request.POST.get('offset')) - type = request.POST.get('type') - db_type = request.POST.get('db_type') - tags = request.POST.getlist('tags[]') + limit = int(request.POST.get("limit")) + offset = int(request.POST.get("offset")) + type = request.POST.get("type") + db_type = request.POST.get("db_type") + tags = request.POST.getlist("tags[]") limit = offset + limit - search = request.POST.get('search', '') - sortName = str(request.POST.get('sortName')) - sortOrder = str(request.POST.get('sortOrder')).lower() + search = request.POST.get("search", "") + sortName = str(request.POST.get("sortName")) + sortOrder = str(request.POST.get("sortOrder")).lower() # 组合筛选项 filter_dict = dict() # 过滤搜索 if search: - filter_dict['instance_name__icontains'] = search + filter_dict["instance_name__icontains"] = search # 过滤实例类型 if type: - filter_dict['type'] = type + filter_dict["type"] = type # 过滤数据库类型 if db_type: - filter_dict['db_type'] = db_type + filter_dict["db_type"] = db_type instances = Instance.objects.filter(**filter_dict) # 过滤标签,返回同时包含全部标签的实例,TODO 循环会生成多表JOIN,如果数据量大会存在效率问题 @@ -51,41 +51,57 @@ def lists(request): instances = instances.filter(instance_tag=tag, instance_tag__active=True) count = instances.count() - if sortName == 'instance_name': - instances = instances.order_by(getattr(Convert(sortName, 'gbk'), sortOrder)())[offset:limit] + if sortName == "instance_name": + instances = instances.order_by(getattr(Convert(sortName, "gbk"), sortOrder)())[ + offset:limit + ] else: - instances = instances.order_by('-' + sortName if sortOrder == 'desc' else sortName)[offset:limit] - instances = instances.values("id", "instance_name", "db_type", "type", "host", "port", "user") + instances = instances.order_by( + "-" + sortName if sortOrder == "desc" else sortName + )[offset:limit] + instances = instances.values( + "id", "instance_name", "db_type", "type", "host", "port", "user" + ) # QuerySet 序列化 rows = [row for row in instances] result = {"total": count, "rows": rows} - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) -@permission_required('sql.param_view', raise_exception=True) +@permission_required("sql.param_view", raise_exception=True) def param_list(request): """ 获取实例参数列表 :param request: :return: """ - instance_id = request.POST.get('instance_id') - editable = True if request.POST.get('editable') else False - search = request.POST.get('search', '') + instance_id = request.POST.get("instance_id") + editable = True if request.POST.get("editable") else False + search = request.POST.get("search", "") try: ins = Instance.objects.get(id=instance_id) except Instance.DoesNotExist: - result = {'status': 1, 'msg': '实例不存在', 'data': []} - return HttpResponse(json.dumps(result), content_type='application/json') + result = {"status": 1, "msg": "实例不存在", "data": []} + return HttpResponse(json.dumps(result), content_type="application/json") # 获取已配置参数列表 cnf_params = dict() - for param in ParamTemplate.objects.filter(db_type=ins.db_type, variable_name__contains=search).values( - 'id', 'variable_name', 'default_value', 'valid_values', 'description', 'editable'): - param['variable_name'] = param['variable_name'].lower() - cnf_params[param['variable_name']] = param + for param in ParamTemplate.objects.filter( + db_type=ins.db_type, variable_name__contains=search + ).values( + "id", + "variable_name", + "default_value", + "valid_values", + "description", + "editable", + ): + param["variable_name"] = param["variable_name"].lower() + cnf_params[param["variable_name"]] = param # 获取实例参数列表 engine = get_engine(instance=ins) ins_variables = engine.get_variables() @@ -94,73 +110,85 @@ def param_list(request): for variable in ins_variables.rows: variable_name = variable[0].lower() row = { - 'variable_name': variable_name, - 'runtime_value': variable[1], - 'editable': False, + "variable_name": variable_name, + "runtime_value": variable[1], + "editable": False, } if variable_name in cnf_params.keys(): row = dict(row, **cnf_params[variable_name]) rows.append(row) # 过滤参数 if editable: - rows = [row for row in rows if row['editable']] + rows = [row for row in rows if row["editable"]] else: - rows = [row for row in rows if not row['editable']] - return HttpResponse(json.dumps(rows, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + rows = [row for row in rows if not row["editable"]] + return HttpResponse( + json.dumps(rows, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) -@permission_required('sql.param_view', raise_exception=True) +@permission_required("sql.param_view", raise_exception=True) def param_history(request): """实例参数修改历史""" - limit = int(request.POST.get('limit')) - offset = int(request.POST.get('offset')) + limit = int(request.POST.get("limit")) + offset = int(request.POST.get("offset")) limit = offset + limit - instance_id = request.POST.get('instance_id') - search = request.POST.get('search', '') + instance_id = request.POST.get("instance_id") + search = request.POST.get("search", "") phs = ParamHistory.objects.filter(instance__id=instance_id) # 过滤搜索条件 if search: phs = ParamHistory.objects.filter(variable_name__contains=search) count = phs.count() - phs = phs[offset:limit].values("instance__instance_name", "variable_name", "old_var", "new_var", - "user_display", "create_time") + phs = phs[offset:limit].values( + "instance__instance_name", + "variable_name", + "old_var", + "new_var", + "user_display", + "create_time", + ) # QuerySet 序列化 rows = [row for row in phs] result = {"total": count, "rows": rows} - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) -@permission_required('sql.param_edit', raise_exception=True) +@permission_required("sql.param_edit", raise_exception=True) def param_edit(request): user = request.user - instance_id = request.POST.get('instance_id') - variable_name = request.POST.get('variable_name') - variable_value = request.POST.get('runtime_value') + instance_id = request.POST.get("instance_id") + variable_name = request.POST.get("variable_name") + variable_value = request.POST.get("runtime_value") try: ins = Instance.objects.get(id=instance_id) except Instance.DoesNotExist: - result = {'status': 1, 'msg': '实例不存在', 'data': []} - return HttpResponse(json.dumps(result), content_type='application/json') + result = {"status": 1, "msg": "实例不存在", "data": []} + return HttpResponse(json.dumps(result), content_type="application/json") # 修改参数 engine = get_engine(instance=ins) # 校验是否配置模板 if not ParamTemplate.objects.filter(variable_name=variable_name).exists(): - result = {'status': 1, 'msg': '请先在参数模板中配置该参数!', 'data': []} - return HttpResponse(json.dumps(result), content_type='application/json') + result = {"status": 1, "msg": "请先在参数模板中配置该参数!", "data": []} + return HttpResponse(json.dumps(result), content_type="application/json") # 获取当前运行参数值 runtime_value = engine.get_variables(variables=[variable_name]).rows[0][1] if variable_value == runtime_value: - result = {'status': 1, 'msg': '参数值与实际运行值一致,未调整!', 'data': []} - return HttpResponse(json.dumps(result), content_type='application/json') - set_result = engine.set_variable(variable_name=variable_name, variable_value=variable_value) + result = {"status": 1, "msg": "参数值与实际运行值一致,未调整!", "data": []} + return HttpResponse(json.dumps(result), content_type="application/json") + set_result = engine.set_variable( + variable_name=variable_name, variable_value=variable_value + ) if set_result.error: - result = {'status': 1, 'msg': f'设置错误,错误信息:{set_result.error}', 'data': []} - return HttpResponse(json.dumps(result), content_type='application/json') + result = {"status": 1, "msg": f"设置错误,错误信息:{set_result.error}", "data": []} + return HttpResponse(json.dumps(result), content_type="application/json") # 修改成功的保存修改记录 else: ParamHistory.objects.create( @@ -170,27 +198,31 @@ def param_edit(request): new_var=variable_value, set_sql=set_result.full_sql, user_name=user.username, - user_display=user.display + user_display=user.display, ) - result = {'status': 0, 'msg': '修改成功,请手动持久化到配置文件!', 'data': []} - return HttpResponse(json.dumps(result), content_type='application/json') + result = {"status": 0, "msg": "修改成功,请手动持久化到配置文件!", "data": []} + return HttpResponse(json.dumps(result), content_type="application/json") -@permission_required('sql.menu_schemasync', raise_exception=True) +@permission_required("sql.menu_schemasync", raise_exception=True) def schemasync(request): """对比实例schema信息""" - instance_name = request.POST.get('instance_name') - db_name = request.POST.get('db_name') - target_instance_name = request.POST.get('target_instance_name') - target_db_name = request.POST.get('target_db_name') - sync_auto_inc = True if request.POST.get('sync_auto_inc') == 'true' else False - sync_comments = True if request.POST.get('sync_comments') == 'true' else False - result = {'status': 0, 'msg': 'ok', 'data': {'diff_stdout': '', 'patch_stdout': '', 'revert_stdout': ''}} + instance_name = request.POST.get("instance_name") + db_name = request.POST.get("db_name") + target_instance_name = request.POST.get("target_instance_name") + target_db_name = request.POST.get("target_db_name") + sync_auto_inc = True if request.POST.get("sync_auto_inc") == "true" else False + sync_comments = True if request.POST.get("sync_comments") == "true" else False + result = { + "status": 0, + "msg": "ok", + "data": {"diff_stdout": "", "patch_stdout": "", "revert_stdout": ""}, + } # 循环对比全部数据库 - if db_name == 'all' or target_db_name == 'all': - db_name = '*' - target_db_name = '*' + if db_name == "all" or target_db_name == "all": + db_name = "*" + target_db_name = "*" # 取出该实例的连接方式 instance_info = Instance.objects.get(instance_name=instance_name) @@ -200,7 +232,7 @@ def schemasync(request): schema_sync = SchemaSync() # 准备参数 tag = int(time.time()) - output_directory = os.path.join(settings.BASE_DIR, 'downloads/schemasync/') + output_directory = os.path.join(settings.BASE_DIR, "downloads/schemasync/") os.makedirs(output_directory, exist_ok=True) db_name = shlex.quote(db_name) target_db_name = shlex.quote(target_db_name) @@ -209,50 +241,74 @@ def schemasync(request): "sync-comments": sync_comments, "tag": tag, "output-directory": output_directory, - "source": r"mysql://{user}:{pwd}@{host}:{port}/{database}".format(user=shlex.quote(str(instance_info.user)), - pwd=shlex.quote(str(instance_info.password)), - host=shlex.quote(str(instance_info.host)), - port=shlex.quote(str(instance_info.port)), - database=db_name), - "target": r"mysql://{user}:{pwd}@{host}:{port}/{database}".format(user=shlex.quote(str(target_instance_info.user)), - pwd=shlex.quote(str(target_instance_info.password)), - host=shlex.quote(str(target_instance_info.host)), - port=shlex.quote(str(target_instance_info.port)), - database=target_db_name) + "source": r"mysql://{user}:{pwd}@{host}:{port}/{database}".format( + user=shlex.quote(str(instance_info.user)), + pwd=shlex.quote(str(instance_info.password)), + host=shlex.quote(str(instance_info.host)), + port=shlex.quote(str(instance_info.port)), + database=db_name, + ), + "target": r"mysql://{user}:{pwd}@{host}:{port}/{database}".format( + user=shlex.quote(str(target_instance_info.user)), + pwd=shlex.quote(str(target_instance_info.password)), + host=shlex.quote(str(target_instance_info.host)), + port=shlex.quote(str(target_instance_info.port)), + database=target_db_name, + ), } # 参数检查 args_check_result = schema_sync.check_args(args) - if args_check_result['status'] == 1: - return HttpResponse(json.dumps(args_check_result), content_type='application/json') + if args_check_result["status"] == 1: + return HttpResponse( + json.dumps(args_check_result), content_type="application/json" + ) # 参数转换 cmd_args = schema_sync.generate_args2cmd(args, shell=True) # 执行命令 try: stdout, stderr = schema_sync.execute_cmd(cmd_args, shell=True).communicate() - diff_stdout = f'{stdout}{stderr}' + diff_stdout = f"{stdout}{stderr}" except RuntimeError as e: diff_stdout = str(e) # 非全部数据库对比可以读取对比结果并在前端展示 - if db_name != '*': + if db_name != "*": date = time.strftime("%Y%m%d", time.localtime()) - patch_sql_file = '%s%s_%s.%s.patch.sql' % (output_directory, target_db_name, tag, date) - revert_sql_file = '%s%s_%s.%s.revert.sql' % (output_directory, target_db_name, tag, date) + patch_sql_file = "%s%s_%s.%s.patch.sql" % ( + output_directory, + target_db_name, + tag, + date, + ) + revert_sql_file = "%s%s_%s.%s.revert.sql" % ( + output_directory, + target_db_name, + tag, + date, + ) try: - with open(patch_sql_file, 'r') as f: + with open(patch_sql_file, "r") as f: patch_sql = f.read() except FileNotFoundError as e: patch_sql = str(e) try: - with open(revert_sql_file, 'r') as f: + with open(revert_sql_file, "r") as f: revert_sql = f.read() except FileNotFoundError as e: revert_sql = str(e) - result['data'] = {'diff_stdout': diff_stdout, 'patch_stdout': patch_sql, 'revert_stdout': revert_sql} + result["data"] = { + "diff_stdout": diff_stdout, + "patch_stdout": patch_sql, + "revert_stdout": revert_sql, + } else: - result['data'] = {'diff_stdout': diff_stdout, 'patch_stdout': '', 'revert_stdout': ''} + result["data"] = { + "diff_stdout": diff_stdout, + "patch_stdout": "", + "revert_stdout": "", + } - return HttpResponse(json.dumps(result), content_type='application/json') + return HttpResponse(json.dumps(result), content_type="application/json") @cache_page(60 * 5, key_prefix="insRes") @@ -262,73 +318,79 @@ def instance_resource(request): :param request: :return: """ - instance_id = request.GET.get('instance_id') - instance_name = request.GET.get('instance_name') - db_name = request.GET.get('db_name', '') - schema_name = request.GET.get('schema_name', '') - tb_name = request.GET.get('tb_name', '') + instance_id = request.GET.get("instance_id") + instance_name = request.GET.get("instance_name") + db_name = request.GET.get("db_name", "") + schema_name = request.GET.get("schema_name", "") + tb_name = request.GET.get("tb_name", "") - resource_type = request.GET.get('resource_type') + resource_type = request.GET.get("resource_type") if instance_id: instance = Instance.objects.get(id=instance_id) else: try: instance = Instance.objects.get(instance_name=instance_name) except Instance.DoesNotExist: - result = {'status': 1, 'msg': '实例不存在', 'data': []} - return HttpResponse(json.dumps(result), content_type='application/json') - result = {'status': 0, 'msg': 'ok', 'data': []} + result = {"status": 1, "msg": "实例不存在", "data": []} + return HttpResponse(json.dumps(result), content_type="application/json") + result = {"status": 0, "msg": "ok", "data": []} try: # escape - db_name = MySQLdb.escape_string(db_name).decode('utf-8') - schema_name = MySQLdb.escape_string(schema_name).decode('utf-8') - tb_name = MySQLdb.escape_string(tb_name).decode('utf-8') + db_name = MySQLdb.escape_string(db_name).decode("utf-8") + schema_name = MySQLdb.escape_string(schema_name).decode("utf-8") + tb_name = MySQLdb.escape_string(tb_name).decode("utf-8") query_engine = get_engine(instance=instance) - if resource_type == 'database': + if resource_type == "database": resource = query_engine.get_all_databases() - elif resource_type == 'schema' and db_name: + elif resource_type == "schema" and db_name: resource = query_engine.get_all_schemas(db_name=db_name) - elif resource_type == 'table' and db_name: - resource = query_engine.get_all_tables(db_name=db_name, schema_name=schema_name) - elif resource_type == 'column' and db_name and tb_name: - resource = query_engine.get_all_columns_by_tb(db_name=db_name, tb_name=tb_name, schema_name=schema_name) + elif resource_type == "table" and db_name: + resource = query_engine.get_all_tables( + db_name=db_name, schema_name=schema_name + ) + elif resource_type == "column" and db_name and tb_name: + resource = query_engine.get_all_columns_by_tb( + db_name=db_name, tb_name=tb_name, schema_name=schema_name + ) else: - raise TypeError('不支持的资源类型或者参数不完整!') + raise TypeError("不支持的资源类型或者参数不完整!") except Exception as msg: - result['status'] = 1 - result['msg'] = str(msg) + result["status"] = 1 + result["msg"] = str(msg) else: if resource.error: - result['status'] = 1 - result['msg'] = resource.error + result["status"] = 1 + result["msg"] = resource.error else: - result['data'] = resource.rows - return HttpResponse(json.dumps(result), content_type='application/json') + result["data"] = resource.rows + return HttpResponse(json.dumps(result), content_type="application/json") def describe(request): """获取表结构""" - instance_name = request.POST.get('instance_name') + instance_name = request.POST.get("instance_name") try: instance = Instance.objects.get(instance_name=instance_name) except Instance.DoesNotExist: - result = {'status': 1, 'msg': '实例不存在', 'data': []} - return HttpResponse(json.dumps(result), content_type='application/json') - db_name = request.POST.get('db_name') - schema_name = request.POST.get('schema_name') - tb_name = request.POST.get('tb_name') - result = {'status': 0, 'msg': 'ok', 'data': []} + result = {"status": 1, "msg": "实例不存在", "data": []} + return HttpResponse(json.dumps(result), content_type="application/json") + db_name = request.POST.get("db_name") + schema_name = request.POST.get("schema_name") + tb_name = request.POST.get("tb_name") + result = {"status": 0, "msg": "ok", "data": []} try: query_engine = get_engine(instance=instance) - query_result = query_engine.describe_table(db_name, tb_name, schema_name=schema_name) - result['data'] = query_result.__dict__ + query_result = query_engine.describe_table( + db_name, tb_name, schema_name=schema_name + ) + result["data"] = query_result.__dict__ except Exception as msg: - result['status'] = 1 - result['msg'] = str(msg) - if result['data']['error']: - result['status'] = 1 - result['msg'] = result['data']['error'] - return HttpResponse(json.dumps(result), content_type='application/json') + result["status"] = 1 + result["msg"] = str(msg) + if result["data"]["error"]: + result["status"] = 1 + result["msg"] = result["data"]["error"] + return HttpResponse(json.dumps(result), content_type="application/json") diff --git a/sql/instance_account.py b/sql/instance_account.py index 7af01a3436..06b71fe961 100644 --- a/sql/instance_account.py +++ b/sql/instance_account.py @@ -13,23 +13,25 @@ from .models import Instance, InstanceAccount -@permission_required('sql.menu_instance_account', raise_exception=True) +@permission_required("sql.menu_instance_account", raise_exception=True) def users(request): """获取实例用户列表""" - instance_id = request.POST.get('instance_id') - saved = True if request.POST.get('saved') == 'true' else False # 平台是否保存 + instance_id = request.POST.get("instance_id") + saved = True if request.POST.get("saved") == "true" else False # 平台是否保存 if not instance_id: - return JsonResponse({'status': 0, 'msg': '', 'data': []}) + return JsonResponse({"status": 0, "msg": "", "data": []}) try: - instance = user_instances(request.user, db_type=['mysql']).get(id=instance_id) + instance = user_instances(request.user, db_type=["mysql"]).get(id=instance_id) except Instance.DoesNotExist: - return JsonResponse({'status': 1, 'msg': '你所在组未关联该实例', 'data': []}) + return JsonResponse({"status": 1, "msg": "你所在组未关联该实例", "data": []}) # 获取已录入用户 cnf_users = dict() - for user in InstanceAccount.objects.filter(instance=instance).values('id', 'user', 'host', 'remark'): - user['saved'] = True + for user in InstanceAccount.objects.filter(instance=instance).values( + "id", "user", "host", "remark" + ): + user["saved"] = True cnf_users[f"`{user['user']}`@`{user['host']}`"] = user # 获取所有用户 query_engine = get_engine(instance=instance) @@ -39,21 +41,23 @@ def users(request): sql_get_user = "select concat('`', user, '`', '@', '`', host,'`') as query,user,host,account_locked from mysql.user;" else: sql_get_user = "select concat('`', user, '`', '@', '`', host,'`') as query,user,host from mysql.user;" - query_result = query_engine.query('mysql', sql_get_user) + query_result = query_engine.query("mysql", sql_get_user) if not query_result.error: db_users = query_result.rows # 获取用户权限信息 rows = [] for db_user in db_users: user_host = db_user[0] - user_priv = query_engine.query('mysql', 'show grants for {};'.format(user_host), close_conn=False).rows + user_priv = query_engine.query( + "mysql", "show grants for {};".format(user_host), close_conn=False + ).rows row = { - 'user_host': user_host, - 'user': db_user[1], - 'host': db_user[2], - 'privileges': user_priv, - 'saved': False, - 'is_locked': db_user[3] if server_version >= (5, 7, 6) else None + "user_host": user_host, + "user": db_user[1], + "host": db_user[2], + "privileges": user_priv, + "saved": False, + "is_locked": db_user[3] if server_version >= (5, 7, 6) else None, } # 合并数据 if user_host in cnf_users.keys(): @@ -61,123 +65,129 @@ def users(request): rows.append(row) # 过滤参数 if saved: - rows = [row for row in rows if row['saved']] + rows = [row for row in rows if row["saved"]] - result = {'status': 0, 'msg': 'ok', 'rows': rows} + result = {"status": 0, "msg": "ok", "rows": rows} else: - result = {'status': 1, 'msg': query_result.error} + result = {"status": 1, "msg": query_result.error} # 关闭连接 query_engine.close() - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) -@permission_required('sql.instance_account_manage', raise_exception=True) +@permission_required("sql.instance_account_manage", raise_exception=True) def create(request): """创建数据库账号""" - instance_id = request.POST.get('instance_id', 0) - user = request.POST.get('user') - host = request.POST.get('host') - password1 = request.POST.get('password1') - password2 = request.POST.get('password2') - remark = request.POST.get('remark', '') + instance_id = request.POST.get("instance_id", 0) + user = request.POST.get("user") + host = request.POST.get("host") + password1 = request.POST.get("password1") + password2 = request.POST.get("password2") + remark = request.POST.get("remark", "") try: - instance = user_instances(request.user, db_type=['mysql']).get(id=instance_id) + instance = user_instances(request.user, db_type=["mysql"]).get(id=instance_id) except Instance.DoesNotExist: - return JsonResponse({'status': 1, 'msg': '你所在组未关联该实例', 'data': []}) + return JsonResponse({"status": 1, "msg": "你所在组未关联该实例", "data": []}) if not all([user, host, password1, password2]): - return JsonResponse({'status': 1, 'msg': '参数不完整,请确认后提交', 'data': []}) + return JsonResponse({"status": 1, "msg": "参数不完整,请确认后提交", "data": []}) if password1 != password2: - return JsonResponse({'status': 1, 'msg': '两次输入密码不一致', 'data': []}) + return JsonResponse({"status": 1, "msg": "两次输入密码不一致", "data": []}) # TODO 目前使用系统自带验证,后续实现验证器校验 try: validate_password(password1, user=None, password_validators=None) except ValidationError as msg: - return JsonResponse({'status': 1, 'msg': f'{msg}', 'data': []}) + return JsonResponse({"status": 1, "msg": f"{msg}", "data": []}) # escape - user = MySQLdb.escape_string(user).decode('utf-8') - host = MySQLdb.escape_string(host).decode('utf-8') - password1 = MySQLdb.escape_string(password1).decode('utf-8') + user = MySQLdb.escape_string(user).decode("utf-8") + host = MySQLdb.escape_string(host).decode("utf-8") + password1 = MySQLdb.escape_string(password1).decode("utf-8") engine = get_engine(instance=instance) # 在一个事务内执行 hosts = host.split("|") - create_user_cmd = '' + create_user_cmd = "" accounts = [] for host in hosts: create_user_cmd += f"create user '{user}'@'{host}' identified by '{password1}';" - accounts.append(InstanceAccount(instance=instance, user=user, host=host, password=password1, remark=remark)) - exec_result = engine.execute(db_name='mysql', sql=create_user_cmd) + accounts.append( + InstanceAccount( + instance=instance, + user=user, + host=host, + password=password1, + remark=remark, + ) + ) + exec_result = engine.execute(db_name="mysql", sql=create_user_cmd) if exec_result.error: - return JsonResponse({'status': 1, 'msg': exec_result.error}) + return JsonResponse({"status": 1, "msg": exec_result.error}) # 保存到数据库 else: InstanceAccount.objects.bulk_create(accounts) - return JsonResponse({'status': 0, 'msg': '', 'data': []}) + return JsonResponse({"status": 0, "msg": "", "data": []}) -@permission_required('sql.instance_account_manage', raise_exception=True) +@permission_required("sql.instance_account_manage", raise_exception=True) def edit(request): """修改、录入数据库账号""" - instance_id = request.POST.get('instance_id', 0) - user = request.POST.get('user') - host = request.POST.get('host') - password = request.POST.get('password') - remark = request.POST.get('remark', '') + instance_id = request.POST.get("instance_id", 0) + user = request.POST.get("user") + host = request.POST.get("host") + password = request.POST.get("password") + remark = request.POST.get("remark", "") try: - instance = user_instances(request.user, db_type=['mysql']).get(id=instance_id) + instance = user_instances(request.user, db_type=["mysql"]).get(id=instance_id) except Instance.DoesNotExist: - return JsonResponse({'status': 1, 'msg': '你所在组未关联该实例', 'data': []}) + return JsonResponse({"status": 1, "msg": "你所在组未关联该实例", "data": []}) if not all([user, host]): - return JsonResponse({'status': 1, 'msg': '参数不完整,请确认后提交', 'data': []}) + return JsonResponse({"status": 1, "msg": "参数不完整,请确认后提交", "data": []}) # 保存到数据库 if password: InstanceAccount.objects.update_or_create( - instance=instance, user=user, host=host, - defaults={ - "password": password, - "remark": remark - } + instance=instance, + user=user, + host=host, + defaults={"password": password, "remark": remark}, ) else: InstanceAccount.objects.update_or_create( - instance=instance, user=user, host=host, - defaults={ - "remark": remark - } + instance=instance, user=user, host=host, defaults={"remark": remark} ) - return JsonResponse({'status': 0, 'msg': '', 'data': []}) + return JsonResponse({"status": 0, "msg": "", "data": []}) -@permission_required('sql.instance_account_manage', raise_exception=True) +@permission_required("sql.instance_account_manage", raise_exception=True) def grant(request): """获取用户权限变更语句,并执行权限变更""" - instance_id = request.POST.get('instance_id', 0) - user_host = request.POST.get('user_host') - op_type = int(request.POST.get('op_type')) - priv_type = int(request.POST.get('priv_type')) - privs = json.loads(request.POST.get('privs')) - grant_sql = '' + instance_id = request.POST.get("instance_id", 0) + user_host = request.POST.get("user_host") + op_type = int(request.POST.get("op_type")) + priv_type = int(request.POST.get("priv_type")) + privs = json.loads(request.POST.get("privs")) + grant_sql = "" # escape - user_host = MySQLdb.escape_string(user_host).decode('utf-8') + user_host = MySQLdb.escape_string(user_host).decode("utf-8") # 全局权限 if priv_type == 0: - global_privs = privs['global_privs'] + global_privs = privs["global_privs"] if not all([global_privs]): - return JsonResponse({'status': 1, 'msg': '信息不完整,请确认后提交', 'data': []}) - global_privs = ['GRANT OPTION' if g == 'GRANT' else g for g in global_privs] + return JsonResponse({"status": 1, "msg": "信息不完整,请确认后提交", "data": []}) + global_privs = ["GRANT OPTION" if g == "GRANT" else g for g in global_privs] if op_type == 0: grant_sql = f"GRANT {','.join(global_privs)} ON *.* TO {user_host};" elif op_type == 1: @@ -185,37 +195,41 @@ def grant(request): # 库权限 elif priv_type == 1: - db_privs = privs['db_privs'] - db_name = request.POST.getlist('db_name[]') + db_privs = privs["db_privs"] + db_name = request.POST.getlist("db_name[]") if not all([db_privs, db_name]): - return JsonResponse({'status': 1, 'msg': '信息不完整,请确认后提交', 'data': []}) + return JsonResponse({"status": 1, "msg": "信息不完整,请确认后提交", "data": []}) for db in db_name: - db_privs = ['GRANT OPTION' if d == 'GRANT' else d for d in db_privs] + db_privs = ["GRANT OPTION" if d == "GRANT" else d for d in db_privs] if op_type == 0: grant_sql += f"GRANT {','.join(db_privs)} ON `{db}`.* TO {user_host};" elif op_type == 1: - grant_sql += f"REVOKE {','.join(db_privs)} ON `{db}`.* FROM {user_host};" + grant_sql += ( + f"REVOKE {','.join(db_privs)} ON `{db}`.* FROM {user_host};" + ) # 表权限 elif priv_type == 2: - tb_privs = privs['tb_privs'] - db_name = request.POST.get('db_name') - tb_name = request.POST.getlist('tb_name[]') + tb_privs = privs["tb_privs"] + db_name = request.POST.get("db_name") + tb_name = request.POST.getlist("tb_name[]") if not all([tb_privs, db_name, tb_name]): - return JsonResponse({'status': 1, 'msg': '信息不完整,请确认后提交', 'data': []}) + return JsonResponse({"status": 1, "msg": "信息不完整,请确认后提交", "data": []}) for tb in tb_name: - tb_privs = ['GRANT OPTION' if t == 'GRANT' else t for t in tb_privs] + tb_privs = ["GRANT OPTION" if t == "GRANT" else t for t in tb_privs] if op_type == 0: - grant_sql += f"GRANT {','.join(tb_privs)} ON `{db_name}`.`{tb}` TO {user_host};" + grant_sql += ( + f"GRANT {','.join(tb_privs)} ON `{db_name}`.`{tb}` TO {user_host};" + ) elif op_type == 1: grant_sql += f"REVOKE {','.join(tb_privs)} ON `{db_name}`.`{tb}` FROM {user_host};" # 列权限 elif priv_type == 3: - col_privs = privs['col_privs'] - db_name = request.POST.get('db_name') - tb_name = request.POST.get('tb_name') - col_name = request.POST.getlist('col_name[]') + col_privs = privs["col_privs"] + db_name = request.POST.get("db_name") + tb_name = request.POST.get("tb_name") + col_name = request.POST.getlist("col_name[]") if not all([col_privs, db_name, tb_name, col_name]): - return JsonResponse({'status': 1, 'msg': '信息不完整,请确认后提交', 'data': []}) + return JsonResponse({"status": 1, "msg": "信息不完整,请确认后提交", "data": []}) for priv in col_privs: if op_type == 0: grant_sql += f"GRANT {priv}(`{'`,`'.join(col_name)}`) ON `{db_name}`.`{tb_name}` TO {user_host};" @@ -224,116 +238,118 @@ def grant(request): # 执行变更语句 try: - instance = user_instances(request.user, db_type=['mysql']).get(id=instance_id) + instance = user_instances(request.user, db_type=["mysql"]).get(id=instance_id) except Instance.DoesNotExist: - return JsonResponse({'status': 1, 'msg': '你所在组未关联该实例', 'data': []}) + return JsonResponse({"status": 1, "msg": "你所在组未关联该实例", "data": []}) engine = get_engine(instance=instance) - exec_result = engine.execute(db_name='mysql', sql=grant_sql) + exec_result = engine.execute(db_name="mysql", sql=grant_sql) if exec_result.error: - return JsonResponse({'status': 1, 'msg': exec_result.error}) - return JsonResponse({'status': 0, 'msg': '', 'data': grant_sql}) + return JsonResponse({"status": 1, "msg": exec_result.error}) + return JsonResponse({"status": 0, "msg": "", "data": grant_sql}) -@permission_required('sql.instance_account_manage', raise_exception=True) +@permission_required("sql.instance_account_manage", raise_exception=True) def reset_pwd(request): """创建数据库账号""" - instance_id = request.POST.get('instance_id', 0) - user_host = request.POST.get('user_host') - user = request.POST.get('user') - host = request.POST.get('host') - reset_pwd1 = request.POST.get('reset_pwd1') - reset_pwd2 = request.POST.get('reset_pwd2') + instance_id = request.POST.get("instance_id", 0) + user_host = request.POST.get("user_host") + user = request.POST.get("user") + host = request.POST.get("host") + reset_pwd1 = request.POST.get("reset_pwd1") + reset_pwd2 = request.POST.get("reset_pwd2") if not all([user, host, reset_pwd1, reset_pwd2]): - return JsonResponse({'status': 1, 'msg': '参数不完整,请确认后提交', 'data': []}) + return JsonResponse({"status": 1, "msg": "参数不完整,请确认后提交", "data": []}) if reset_pwd1 != reset_pwd2: - return JsonResponse({'status': 1, 'msg': '两次输入密码不一致', 'data': []}) + return JsonResponse({"status": 1, "msg": "两次输入密码不一致", "data": []}) try: - instance = user_instances(request.user, db_type=['mysql']).get(id=instance_id) + instance = user_instances(request.user, db_type=["mysql"]).get(id=instance_id) except Instance.DoesNotExist: - return JsonResponse({'status': 1, 'msg': '你所在组未关联该实例', 'data': []}) + return JsonResponse({"status": 1, "msg": "你所在组未关联该实例", "data": []}) # escape - user_host = MySQLdb.escape_string(user_host).decode('utf-8') - reset_pwd1 = MySQLdb.escape_string(reset_pwd1).decode('utf-8') + user_host = MySQLdb.escape_string(user_host).decode("utf-8") + reset_pwd1 = MySQLdb.escape_string(reset_pwd1).decode("utf-8") # TODO 目前使用系统自带验证,后续实现验证器校验 try: validate_password(reset_pwd1, user=None, password_validators=None) except ValidationError as msg: - return JsonResponse({'status': 1, 'msg': f'{msg}', 'data': []}) + return JsonResponse({"status": 1, "msg": f"{msg}", "data": []}) engine = get_engine(instance=instance) - exec_result = engine.execute(db_name='mysql', - sql=f"ALTER USER {user_host} IDENTIFIED BY '{reset_pwd1}';") + exec_result = engine.execute( + db_name="mysql", sql=f"ALTER USER {user_host} IDENTIFIED BY '{reset_pwd1}';" + ) if exec_result.error: - result = {'status': 1, 'msg': exec_result.error} - return HttpResponse(json.dumps(result), content_type='application/json') + result = {"status": 1, "msg": exec_result.error} + return HttpResponse(json.dumps(result), content_type="application/json") # 保存到数据库 else: - InstanceAccount.objects.update_or_create(instance=instance, user=user, host=host, - defaults={'password': reset_pwd1}) - return JsonResponse({'status': 0, 'msg': '', 'data': []}) + InstanceAccount.objects.update_or_create( + instance=instance, user=user, host=host, defaults={"password": reset_pwd1} + ) + return JsonResponse({"status": 0, "msg": "", "data": []}) -@permission_required('sql.instance_account_manage', raise_exception=True) +@permission_required("sql.instance_account_manage", raise_exception=True) def lock(request): """锁定/解锁账号""" - instance_id = request.POST.get('instance_id', 0) - user_host = request.POST.get('user_host') - is_locked = request.POST.get('is_locked') - lock_sql = '' + instance_id = request.POST.get("instance_id", 0) + user_host = request.POST.get("user_host") + is_locked = request.POST.get("is_locked") + lock_sql = "" if not all([user_host]): - return JsonResponse({'status': 1, 'msg': '参数不完整,请确认后提交', 'data': []}) + return JsonResponse({"status": 1, "msg": "参数不完整,请确认后提交", "data": []}) try: - instance = user_instances(request.user, db_type=['mysql']).get(id=instance_id) + instance = user_instances(request.user, db_type=["mysql"]).get(id=instance_id) except Instance.DoesNotExist: - return JsonResponse({'status': 1, 'msg': '你所在组未关联该实例', 'data': []}) + return JsonResponse({"status": 1, "msg": "你所在组未关联该实例", "data": []}) # escape - user_host = MySQLdb.escape_string(user_host).decode('utf-8') + user_host = MySQLdb.escape_string(user_host).decode("utf-8") - if is_locked == 'N': + if is_locked == "N": lock_sql = f"ALTER USER {user_host} ACCOUNT LOCK;" - elif is_locked == 'Y': + elif is_locked == "Y": lock_sql = f"ALTER USER {user_host} ACCOUNT UNLOCK;" engine = get_engine(instance=instance) - exec_result = engine.execute(db_name='mysql', sql=lock_sql) + exec_result = engine.execute(db_name="mysql", sql=lock_sql) if exec_result.error: - return JsonResponse({'status': 1, 'msg': exec_result.error}) - return JsonResponse({'status': 0, 'msg': '', 'data': []}) + return JsonResponse({"status": 1, "msg": exec_result.error}) + return JsonResponse({"status": 0, "msg": "", "data": []}) -@permission_required('sql.instance_account_manage', raise_exception=True) +@permission_required("sql.instance_account_manage", raise_exception=True) def delete(request): """删除账号""" - instance_id = request.POST.get('instance_id', 0) - user_host = request.POST.get('user_host') - user = request.POST.get('user') - host = request.POST.get('host') + instance_id = request.POST.get("instance_id", 0) + user_host = request.POST.get("user_host") + user = request.POST.get("user") + host = request.POST.get("host") if not all([user_host]): - return JsonResponse({'status': 1, 'msg': '参数不完整,请确认后提交', 'data': []}) + return JsonResponse({"status": 1, "msg": "参数不完整,请确认后提交", "data": []}) try: - instance = user_instances(request.user, db_type=['mysql']).get(id=instance_id) + instance = user_instances(request.user, db_type=["mysql"]).get(id=instance_id) except Instance.DoesNotExist: - return JsonResponse({'status': 1, 'msg': '你所在组未关联该实例', 'data': []}) + return JsonResponse({"status": 1, "msg": "你所在组未关联该实例", "data": []}) # escape - user_host = MySQLdb.escape_string(user_host).decode('utf-8') + user_host = MySQLdb.escape_string(user_host).decode("utf-8") engine = get_engine(instance=instance) - exec_result = engine.execute(db_name='mysql', sql=f"DROP USER {user_host};") + exec_result = engine.execute(db_name="mysql", sql=f"DROP USER {user_host};") if exec_result.error: - return JsonResponse({'status': 1, 'msg': exec_result.error}) + return JsonResponse({"status": 1, "msg": exec_result.error}) # 删除数据库对应记录 else: InstanceAccount.objects.filter(instance=instance, user=user, host=host).delete() - return JsonResponse({'status': 0, 'msg': '', 'data': []}) + return JsonResponse({"status": 0, "msg": "", "data": []}) diff --git a/sql/instance_database.py b/sql/instance_database.py index 3deb528197..15d1572c35 100644 --- a/sql/instance_database.py +++ b/sql/instance_database.py @@ -17,28 +17,29 @@ from sql.models import Instance, InstanceDatabase, Users from sql.utils.resource_group import user_instances -__author__ = 'hhyo' +__author__ = "hhyo" -@permission_required('sql.menu_database', raise_exception=True) +@permission_required("sql.menu_database", raise_exception=True) def databases(request): """获取实例数据库列表""" - instance_id = request.POST.get('instance_id') - saved = True if request.POST.get('saved') == 'true' else False # 平台是否保存 + instance_id = request.POST.get("instance_id") + saved = True if request.POST.get("saved") == "true" else False # 平台是否保存 if not instance_id: - return JsonResponse({'status': 0, 'msg': '', 'data': []}) + return JsonResponse({"status": 0, "msg": "", "data": []}) try: - instance = user_instances(request.user, db_type=['mysql']).get(id=instance_id) + instance = user_instances(request.user, db_type=["mysql"]).get(id=instance_id) except Instance.DoesNotExist: - return JsonResponse({'status': 1, 'msg': '你所在组未关联该实例', 'data': []}) + return JsonResponse({"status": 1, "msg": "你所在组未关联该实例", "data": []}) # 获取已录入数据库 cnf_dbs = dict() - for db in InstanceDatabase.objects.filter( - instance=instance).values('id', 'db_name', 'owner', 'owner_display', 'remark'): - db['saved'] = True + for db in InstanceDatabase.objects.filter(instance=instance).values( + "id", "db_name", "owner", "owner_display", "remark" + ): + db["saved"] = True cnf_dbs[f"{db['db_name']}"] = db # 获取所有数据库 @@ -46,7 +47,9 @@ def databases(request): FROM information_schema.SCHEMATA WHERE SCHEMA_NAME NOT IN ('information_schema', 'performance_schema', 'mysql', 'test', 'sys');""" query_engine = get_engine(instance=instance) - query_result = query_engine.query('information_schema', sql_get_db, close_conn=False) + query_result = query_engine.query( + "information_schema", sql_get_db, close_conn=False + ) if not query_result.error: dbs = query_result.rows # 获取数据库关联用户信息 @@ -57,13 +60,15 @@ def databases(request): from information_schema.SCHEMA_PRIVILEGES where TABLE_SCHEMA='{db_name}' group by TABLE_SCHEMA;""" - bind_users = query_engine.query('information_schema', sql_get_bind_users, close_conn=False).rows + bind_users = query_engine.query( + "information_schema", sql_get_bind_users, close_conn=False + ).rows row = { - 'db_name': db_name, - 'charset': db[1], - 'collation': db[2], - 'grantees': bind_users[0][0].split(',') if bind_users else [], - 'saved': False + "db_name": db_name, + "charset": db[1], + "collation": db[2], + "grantees": bind_users[0][0].split(",") if bind_users else [], + "saved": False, } # 合并数据 if db_name in cnf_dbs.keys(): @@ -71,81 +76,91 @@ def databases(request): rows.append(row) # 过滤参数 if saved: - rows = [row for row in rows if row['saved']] + rows = [row for row in rows if row["saved"]] - result = {'status': 0, 'msg': 'ok', 'rows': rows} + result = {"status": 0, "msg": "ok", "rows": rows} else: - result = {'status': 1, 'msg': query_result.error} + result = {"status": 1, "msg": query_result.error} # 关闭连接 query_engine.close() - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) -@permission_required('sql.menu_database', raise_exception=True) +@permission_required("sql.menu_database", raise_exception=True) def create(request): """创建数据库""" - instance_id = request.POST.get('instance_id', 0) - db_name = request.POST.get('db_name') - owner = request.POST.get('owner', '') - remark = request.POST.get('remark', '') + instance_id = request.POST.get("instance_id", 0) + db_name = request.POST.get("db_name") + owner = request.POST.get("owner", "") + remark = request.POST.get("remark", "") if not all([db_name]): - return JsonResponse({'status': 1, 'msg': '参数不完整,请确认后提交', 'data': []}) + return JsonResponse({"status": 1, "msg": "参数不完整,请确认后提交", "data": []}) try: - instance = user_instances(request.user, db_type=['mysql']).get(id=instance_id) + instance = user_instances(request.user, db_type=["mysql"]).get(id=instance_id) except Instance.DoesNotExist: - return JsonResponse({'status': 1, 'msg': '你所在组未关联该实例', 'data': []}) + return JsonResponse({"status": 1, "msg": "你所在组未关联该实例", "data": []}) try: owner_display = Users.objects.get(username=owner).display except Users.DoesNotExist: - return JsonResponse({'status': 1, 'msg': '负责人不存在', 'data': []}) + return JsonResponse({"status": 1, "msg": "负责人不存在", "data": []}) # escape - db_name = MySQLdb.escape_string(db_name).decode('utf-8') + db_name = MySQLdb.escape_string(db_name).decode("utf-8") engine = get_engine(instance=instance) - exec_result = engine.execute(db_name='information_schema', sql=f"create database {db_name};") + exec_result = engine.execute( + db_name="information_schema", sql=f"create database {db_name};" + ) if exec_result.error: - return JsonResponse({'status': 1, 'msg': exec_result.error}) + return JsonResponse({"status": 1, "msg": exec_result.error}) # 保存到数据库 else: InstanceDatabase.objects.create( - instance=instance, db_name=db_name, owner=owner, owner_display=owner_display, remark=remark) + instance=instance, + db_name=db_name, + owner=owner, + owner_display=owner_display, + remark=remark, + ) # 清空实例资源缓存 r = get_redis_connection("default") - for key in r.scan_iter(match='*insRes*', count=2000): + for key in r.scan_iter(match="*insRes*", count=2000): r.delete(key) - return JsonResponse({'status': 0, 'msg': '', 'data': []}) + return JsonResponse({"status": 0, "msg": "", "data": []}) -@permission_required('sql.menu_database', raise_exception=True) +@permission_required("sql.menu_database", raise_exception=True) def edit(request): """编辑/录入数据库""" - instance_id = request.POST.get('instance_id', 0) - db_name = request.POST.get('db_name') - owner = request.POST.get('owner', '') - remark = request.POST.get('remark', '') + instance_id = request.POST.get("instance_id", 0) + db_name = request.POST.get("db_name") + owner = request.POST.get("owner", "") + remark = request.POST.get("remark", "") if not all([db_name]): - return JsonResponse({'status': 1, 'msg': '参数不完整,请确认后提交', 'data': []}) + return JsonResponse({"status": 1, "msg": "参数不完整,请确认后提交", "data": []}) try: - instance = user_instances(request.user, db_type=['mysql']).get(id=instance_id) + instance = user_instances(request.user, db_type=["mysql"]).get(id=instance_id) except Instance.DoesNotExist: - return JsonResponse({'status': 1, 'msg': '你所在组未关联该实例', 'data': []}) + return JsonResponse({"status": 1, "msg": "你所在组未关联该实例", "data": []}) try: owner_display = Users.objects.get(username=owner).display except Users.DoesNotExist: - return JsonResponse({'status': 1, 'msg': '负责人不存在', 'data': []}) + return JsonResponse({"status": 1, "msg": "负责人不存在", "data": []}) # 更新或者录入信息 InstanceDatabase.objects.update_or_create( instance=instance, db_name=db_name, - defaults={"owner": owner, "owner_display": owner_display, "remark": remark}) - return JsonResponse({'status': 0, 'msg': '', 'data': []}) + defaults={"owner": owner, "owner_display": owner_display, "remark": remark}, + ) + return JsonResponse({"status": 0, "msg": "", "data": []}) diff --git a/sql/models.py b/sql/models.py index d998bacb09..84c0377369 100755 --- a/sql/models.py +++ b/sql/models.py @@ -10,15 +10,16 @@ class ResourceGroup(models.Model): """ 资源组 """ - group_id = models.AutoField('组ID', primary_key=True) - group_name = models.CharField('组名称', max_length=100, unique=True) - group_parent_id = models.BigIntegerField('父级id', default=0) - group_sort = models.IntegerField('排序', default=1) - group_level = models.IntegerField('层级', default=1) - ding_webhook = models.CharField('钉钉webhook地址', max_length=255, blank=True) - feishu_webhook = models.CharField('飞书webhook地址', max_length=255, blank=True) - qywx_webhook = models.CharField('企业微信webhook地址', max_length=255, blank=True) - is_deleted = models.IntegerField('是否删除', choices=((0, '否'), (1, '是')), default=0) + + group_id = models.AutoField("组ID", primary_key=True) + group_name = models.CharField("组名称", max_length=100, unique=True) + group_parent_id = models.BigIntegerField("父级id", default=0) + group_sort = models.IntegerField("排序", default=1) + group_level = models.IntegerField("层级", default=1) + ding_webhook = models.CharField("钉钉webhook地址", max_length=255, blank=True) + feishu_webhook = models.CharField("飞书webhook地址", max_length=255, blank=True) + qywx_webhook = models.CharField("企业微信webhook地址", max_length=255, blank=True) + is_deleted = models.IntegerField("是否删除", choices=((0, "否"), (1, "是")), default=0) create_time = models.DateTimeField(auto_now_add=True) sys_time = models.DateTimeField(auto_now=True) @@ -27,22 +28,25 @@ def __str__(self): class Meta: managed = True - db_table = 'resource_group' - verbose_name = u'资源组管理' - verbose_name_plural = u'资源组管理' + db_table = "resource_group" + verbose_name = "资源组管理" + verbose_name_plural = "资源组管理" class Users(AbstractUser): """ 用户信息扩展 """ - display = models.CharField('显示的中文名', max_length=50, default='') - ding_user_id = models.CharField('钉钉UserID', max_length=64, blank=True) - wx_user_id = models.CharField('企业微信UserID', max_length=64, blank=True) - feishu_open_id = models.CharField('飞书OpenID', max_length=64, blank=True) - failed_login_count = models.IntegerField('失败计数', default=0) - last_login_failed_at = models.DateTimeField('上次失败登录时间', blank=True, null=True) - resource_group = models.ManyToManyField(ResourceGroup, verbose_name='资源组', blank=True) + + display = models.CharField("显示的中文名", max_length=50, default="") + ding_user_id = models.CharField("钉钉UserID", max_length=64, blank=True) + wx_user_id = models.CharField("企业微信UserID", max_length=64, blank=True) + feishu_open_id = models.CharField("飞书OpenID", max_length=64, blank=True) + failed_login_count = models.IntegerField("失败计数", default=0) + last_login_failed_at = models.DateTimeField("上次失败登录时间", blank=True, null=True) + resource_group = models.ManyToManyField( + ResourceGroup, verbose_name="资源组", blank=True + ) def save(self, *args, **kwargs): self.failed_login_count = min(127, self.failed_login_count) @@ -56,24 +60,31 @@ def __str__(self): class Meta: managed = True - db_table = 'sql_users' - verbose_name = u'用户管理' - verbose_name_plural = u'用户管理' + db_table = "sql_users" + verbose_name = "用户管理" + verbose_name_plural = "用户管理" class TwoFactorAuthConfig(models.Model): """ 2fa配置信息 """ + auth_type_choice = ( - ('totp', 'Google身份验证器'), - ('sms', '短信验证码'), + ("totp", "Google身份验证器"), + ("sms", "短信验证码"), ) - username = fields.EncryptedCharField(verbose_name='用户名', max_length=200) - auth_type = fields.EncryptedCharField(verbose_name='认证类型', max_length=128, choices=auth_type_choice) - phone = fields.EncryptedCharField(verbose_name='手机号码', max_length=64, null=True, default='') - secret_key = fields.EncryptedCharField(verbose_name='用户密钥', max_length=256, null=True) + username = fields.EncryptedCharField(verbose_name="用户名", max_length=200) + auth_type = fields.EncryptedCharField( + verbose_name="认证类型", max_length=128, choices=auth_type_choice + ) + phone = fields.EncryptedCharField( + verbose_name="手机号码", max_length=64, null=True, default="" + ) + secret_key = fields.EncryptedCharField( + verbose_name="用户密钥", max_length=256, null=True + ) user = models.ForeignKey(Users, on_delete=models.CASCADE) def __int__(self): @@ -81,147 +92,193 @@ def __int__(self): class Meta: managed = True - db_table = '2fa_config' - verbose_name = u'2FA配置' - verbose_name_plural = u'2FA配置' - unique_together = ('user', 'auth_type') + db_table = "2fa_config" + verbose_name = "2FA配置" + verbose_name_plural = "2FA配置" + unique_together = ("user", "auth_type") class InstanceTag(models.Model): """实例标签配置""" - tag_code = models.CharField('标签代码', max_length=20, unique=True) - tag_name = models.CharField('标签名称', max_length=20, unique=True) - active = models.BooleanField('激活状态', default=True) - create_time = models.DateTimeField('创建时间', auto_now_add=True) + + tag_code = models.CharField("标签代码", max_length=20, unique=True) + tag_name = models.CharField("标签名称", max_length=20, unique=True) + active = models.BooleanField("激活状态", default=True) + create_time = models.DateTimeField("创建时间", auto_now_add=True) def __str__(self): return self.tag_name class Meta: managed = True - db_table = 'sql_instance_tag' - verbose_name = u'实例标签' - verbose_name_plural = u'实例标签' + db_table = "sql_instance_tag" + verbose_name = "实例标签" + verbose_name_plural = "实例标签" DB_TYPE_CHOICES = ( - ('mysql', 'MySQL'), - ('mssql', 'MsSQL'), - ('redis', 'Redis'), - ('pgsql', 'PgSQL'), - ('oracle', 'Oracle'), - ('mongo', 'Mongo'), - ('phoenix', 'Phoenix'), - ('odps', 'ODPS'), - ('clickhouse', 'ClickHouse'), - ('goinception', 'goInception')) + ("mysql", "MySQL"), + ("mssql", "MsSQL"), + ("redis", "Redis"), + ("pgsql", "PgSQL"), + ("oracle", "Oracle"), + ("mongo", "Mongo"), + ("phoenix", "Phoenix"), + ("odps", "ODPS"), + ("clickhouse", "ClickHouse"), + ("goinception", "goInception"), +) class Tunnel(models.Model): """ SSH隧道配置 """ - tunnel_name = models.CharField('隧道名称', max_length=50, unique=True) - host = models.CharField('隧道连接', max_length=200) - port = models.IntegerField('端口', default=0) - user = fields.EncryptedCharField(verbose_name='用户名', max_length=200, default='', blank=True, null=True) - password = fields.EncryptedCharField(verbose_name='密码', max_length=300, default='', blank=True, null=True) + + tunnel_name = models.CharField("隧道名称", max_length=50, unique=True) + host = models.CharField("隧道连接", max_length=200) + port = models.IntegerField("端口", default=0) + user = fields.EncryptedCharField( + verbose_name="用户名", max_length=200, default="", blank=True, null=True + ) + password = fields.EncryptedCharField( + verbose_name="密码", max_length=300, default="", blank=True, null=True + ) pkey = fields.EncryptedTextField(verbose_name="密钥", blank=True, null=True) - pkey_path = models.FileField(verbose_name="密钥地址", blank=True, null=True, upload_to='keys/') - pkey_password = fields.EncryptedCharField(verbose_name='密钥密码', max_length=300, default='', blank=True, null=True) - create_time = models.DateTimeField('创建时间', auto_now_add=True) - update_time = models.DateTimeField('更新时间', auto_now=True) + pkey_path = models.FileField( + verbose_name="密钥地址", blank=True, null=True, upload_to="keys/" + ) + pkey_password = fields.EncryptedCharField( + verbose_name="密钥密码", max_length=300, default="", blank=True, null=True + ) + create_time = models.DateTimeField("创建时间", auto_now_add=True) + update_time = models.DateTimeField("更新时间", auto_now=True) def __str__(self): return self.tunnel_name def short_pkey(self): if len(str(self.pkey)) > 20: - return '{}...'.format(str(self.pkey)[0:19]) + return "{}...".format(str(self.pkey)[0:19]) else: return str(self.pkey) class Meta: managed = True - db_table = 'ssh_tunnel' - verbose_name = u'隧道配置' - verbose_name_plural = u'隧道配置' + db_table = "ssh_tunnel" + verbose_name = "隧道配置" + verbose_name_plural = "隧道配置" class Instance(models.Model): """ 各个线上实例配置 """ - instance_name = models.CharField('实例名称', max_length=50, unique=True) - type = models.CharField('实例类型', max_length=6, choices=(('master', '主库'), ('slave', '从库'))) - db_type = models.CharField('数据库类型', max_length=20, choices=DB_TYPE_CHOICES) - mode = models.CharField('运行模式', max_length=10, default='', blank=True, choices=(('standalone', '单机'), ('cluster', '集群'))) - host = models.CharField('实例连接', max_length=200) - port = models.IntegerField('端口', default=0) - user = fields.EncryptedCharField(verbose_name='用户名', max_length=200, default='', blank=True) - password = fields.EncryptedCharField(verbose_name='密码', max_length=300, default='', blank=True) - db_name = models.CharField('数据库', max_length=64, default='', blank=True) - charset = models.CharField('字符集', max_length=20, default='', blank=True) - service_name = models.CharField('Oracle service name', max_length=50, null=True, blank=True) - sid = models.CharField('Oracle sid', max_length=50, null=True, blank=True) - resource_group = models.ManyToManyField(ResourceGroup, verbose_name='资源组', blank=True) - instance_tag = models.ManyToManyField(InstanceTag, verbose_name='实例标签', blank=True) - tunnel = models.ForeignKey(Tunnel, verbose_name='连接隧道', blank=True, null=True, on_delete=models.CASCADE, default=None) - create_time = models.DateTimeField('创建时间', auto_now_add=True) - update_time = models.DateTimeField('更新时间', auto_now=True) + + instance_name = models.CharField("实例名称", max_length=50, unique=True) + type = models.CharField( + "实例类型", max_length=6, choices=(("master", "主库"), ("slave", "从库")) + ) + db_type = models.CharField("数据库类型", max_length=20, choices=DB_TYPE_CHOICES) + mode = models.CharField( + "运行模式", + max_length=10, + default="", + blank=True, + choices=(("standalone", "单机"), ("cluster", "集群")), + ) + host = models.CharField("实例连接", max_length=200) + port = models.IntegerField("端口", default=0) + user = fields.EncryptedCharField( + verbose_name="用户名", max_length=200, default="", blank=True + ) + password = fields.EncryptedCharField( + verbose_name="密码", max_length=300, default="", blank=True + ) + db_name = models.CharField("数据库", max_length=64, default="", blank=True) + charset = models.CharField("字符集", max_length=20, default="", blank=True) + service_name = models.CharField( + "Oracle service name", max_length=50, null=True, blank=True + ) + sid = models.CharField("Oracle sid", max_length=50, null=True, blank=True) + resource_group = models.ManyToManyField( + ResourceGroup, verbose_name="资源组", blank=True + ) + instance_tag = models.ManyToManyField(InstanceTag, verbose_name="实例标签", blank=True) + tunnel = models.ForeignKey( + Tunnel, + verbose_name="连接隧道", + blank=True, + null=True, + on_delete=models.CASCADE, + default=None, + ) + create_time = models.DateTimeField("创建时间", auto_now_add=True) + update_time = models.DateTimeField("更新时间", auto_now=True) def __str__(self): return self.instance_name class Meta: managed = True - db_table = 'sql_instance' - verbose_name = u'实例配置' - verbose_name_plural = u'实例配置' + db_table = "sql_instance" + verbose_name = "实例配置" + verbose_name_plural = "实例配置" SQL_WORKFLOW_CHOICES = ( - ('workflow_finish', _('workflow_finish')), - ('workflow_abort', _('workflow_abort')), - ('workflow_manreviewing', _('workflow_manreviewing')), - ('workflow_review_pass', _('workflow_review_pass')), - ('workflow_timingtask', _('workflow_timingtask')), - ('workflow_queuing', _('workflow_queuing')), - ('workflow_executing', _('workflow_executing')), - ('workflow_autoreviewwrong', _('workflow_autoreviewwrong')), - ('workflow_exception', _('workflow_exception'))) + ("workflow_finish", _("workflow_finish")), + ("workflow_abort", _("workflow_abort")), + ("workflow_manreviewing", _("workflow_manreviewing")), + ("workflow_review_pass", _("workflow_review_pass")), + ("workflow_timingtask", _("workflow_timingtask")), + ("workflow_queuing", _("workflow_queuing")), + ("workflow_executing", _("workflow_executing")), + ("workflow_autoreviewwrong", _("workflow_autoreviewwrong")), + ("workflow_exception", _("workflow_exception")), +) class SqlWorkflow(models.Model): """ 存放各个SQL上线工单的基础内容 """ - workflow_name = models.CharField('工单内容', max_length=50) - demand_url = models.CharField('需求链接', max_length=500) - group_id = models.IntegerField('组ID') - group_name = models.CharField('组名称', max_length=100) + + workflow_name = models.CharField("工单内容", max_length=50) + demand_url = models.CharField("需求链接", max_length=500) + group_id = models.IntegerField("组ID") + group_name = models.CharField("组名称", max_length=100) instance = models.ForeignKey(Instance, on_delete=models.CASCADE) - db_name = models.CharField('数据库', max_length=64) - syntax_type = models.IntegerField('工单类型 0、未知,1、DDL,2、DML', choices=((0, '其他'), (1, 'DDL'), (2, 'DML')), default=0) - is_backup = models.BooleanField('是否备份', choices=((False, '否'), (True, '是'),), default=True) - engineer = models.CharField('发起人', max_length=30) - engineer_display = models.CharField('发起人中文名', max_length=50, default='') + db_name = models.CharField("数据库", max_length=64) + syntax_type = models.IntegerField( + "工单类型 0、未知,1、DDL,2、DML", choices=((0, "其他"), (1, "DDL"), (2, "DML")), default=0 + ) + is_backup = models.BooleanField( + "是否备份", + choices=( + (False, "否"), + (True, "是"), + ), + default=True, + ) + engineer = models.CharField("发起人", max_length=30) + engineer_display = models.CharField("发起人中文名", max_length=50, default="") status = models.CharField(max_length=50, choices=SQL_WORKFLOW_CHOICES) - audit_auth_groups = models.CharField('审批权限组列表', max_length=255) - run_date_start = models.DateTimeField('可执行起始时间', null=True, blank=True) - run_date_end = models.DateTimeField('可执行结束时间', null=True, blank=True) - create_time = models.DateTimeField('创建时间', auto_now_add=True) - finish_time = models.DateTimeField('结束时间', null=True, blank=True) - is_manual = models.IntegerField('是否原生执行', choices=((0, '否'), (1, '是')), default=0) + audit_auth_groups = models.CharField("审批权限组列表", max_length=255) + run_date_start = models.DateTimeField("可执行起始时间", null=True, blank=True) + run_date_end = models.DateTimeField("可执行结束时间", null=True, blank=True) + create_time = models.DateTimeField("创建时间", auto_now_add=True) + finish_time = models.DateTimeField("结束时间", null=True, blank=True) + is_manual = models.IntegerField("是否原生执行", choices=((0, "否"), (1, "是")), default=0) def __str__(self): return self.workflow_name class Meta: managed = True - db_table = 'sql_workflow' - verbose_name = u'SQL工单' - verbose_name_plural = u'SQL工单' + db_table = "sql_workflow" + verbose_name = "SQL工单" + verbose_name_plural = "SQL工单" class SqlWorkflowContent(models.Model): @@ -229,87 +286,91 @@ class SqlWorkflowContent(models.Model): 存放各个SQL上线工单的SQL|审核|执行内容 可定期归档或清理历史数据,也可通过``alter table sql_workflow_content row_format=compressed; ``来进行压缩 """ + workflow = models.OneToOneField(SqlWorkflow, on_delete=models.CASCADE) - sql_content = models.TextField('具体sql内容') - review_content = models.TextField('自动审核内容的JSON格式') - execute_result = models.TextField('执行结果的JSON格式', blank=True) + sql_content = models.TextField("具体sql内容") + review_content = models.TextField("自动审核内容的JSON格式") + execute_result = models.TextField("执行结果的JSON格式", blank=True) def __str__(self): return self.workflow.workflow_name class Meta: managed = True - db_table = 'sql_workflow_content' - verbose_name = u'SQL工单内容' - verbose_name_plural = u'SQL工单内容' + db_table = "sql_workflow_content" + verbose_name = "SQL工单内容" + verbose_name_plural = "SQL工单内容" -workflow_type_choices = ((1, _('sql_query')), (2, _('sql_review'))) -workflow_status_choices = ((0, '待审核'), (1, '审核通过'), (2, '审核不通过'), (3, '审核取消')) +workflow_type_choices = ((1, _("sql_query")), (2, _("sql_review"))) +workflow_status_choices = ((0, "待审核"), (1, "审核通过"), (2, "审核不通过"), (3, "审核取消")) class WorkflowAudit(models.Model): """ 工作流审核状态表 """ + audit_id = models.AutoField(primary_key=True) - group_id = models.IntegerField('组ID') - group_name = models.CharField('组名称', max_length=100) - workflow_id = models.BigIntegerField('关联业务id') - workflow_type = models.IntegerField('申请类型', choices=workflow_type_choices) - workflow_title = models.CharField('申请标题', max_length=50) - workflow_remark = models.CharField('申请备注', default='', max_length=140, blank=True) - audit_auth_groups = models.CharField('审批权限组列表', max_length=255) - current_audit = models.CharField('当前审批权限组', max_length=20) - next_audit = models.CharField('下级审批权限组', max_length=20) - current_status = models.IntegerField('审核状态', choices=workflow_status_choices) - create_user = models.CharField('申请人', max_length=30) - create_user_display = models.CharField('申请人中文名', max_length=50, default='') - create_time = models.DateTimeField('申请时间', auto_now_add=True) - sys_time = models.DateTimeField('系统时间', auto_now=True) + group_id = models.IntegerField("组ID") + group_name = models.CharField("组名称", max_length=100) + workflow_id = models.BigIntegerField("关联业务id") + workflow_type = models.IntegerField("申请类型", choices=workflow_type_choices) + workflow_title = models.CharField("申请标题", max_length=50) + workflow_remark = models.CharField("申请备注", default="", max_length=140, blank=True) + audit_auth_groups = models.CharField("审批权限组列表", max_length=255) + current_audit = models.CharField("当前审批权限组", max_length=20) + next_audit = models.CharField("下级审批权限组", max_length=20) + current_status = models.IntegerField("审核状态", choices=workflow_status_choices) + create_user = models.CharField("申请人", max_length=30) + create_user_display = models.CharField("申请人中文名", max_length=50, default="") + create_time = models.DateTimeField("申请时间", auto_now_add=True) + sys_time = models.DateTimeField("系统时间", auto_now=True) def __int__(self): return self.audit_id class Meta: managed = True - db_table = 'workflow_audit' - unique_together = ('workflow_id', 'workflow_type') - verbose_name = u'工作流审批列表' - verbose_name_plural = u'工作流审批列表' + db_table = "workflow_audit" + unique_together = ("workflow_id", "workflow_type") + verbose_name = "工作流审批列表" + verbose_name_plural = "工作流审批列表" class WorkflowAuditDetail(models.Model): """ 审批明细表 """ + audit_detail_id = models.AutoField(primary_key=True) - audit_id = models.IntegerField('审核主表id') - audit_user = models.CharField('审核人', max_length=30) - audit_time = models.DateTimeField('审核时间') - audit_status = models.IntegerField('审核状态', choices=workflow_status_choices) - remark = models.CharField('审核备注', default='', max_length=1000) - sys_time = models.DateTimeField('系统时间', auto_now=True) + audit_id = models.IntegerField("审核主表id") + audit_user = models.CharField("审核人", max_length=30) + audit_time = models.DateTimeField("审核时间") + audit_status = models.IntegerField("审核状态", choices=workflow_status_choices) + remark = models.CharField("审核备注", default="", max_length=1000) + sys_time = models.DateTimeField("系统时间", auto_now=True) def __int__(self): return self.audit_detail_id class Meta: managed = True - db_table = 'workflow_audit_detail' - verbose_name = u'工作流审批明细' - verbose_name_plural = u'工作流审批明细' + db_table = "workflow_audit_detail" + verbose_name = "工作流审批明细" + verbose_name_plural = "工作流审批明细" class WorkflowAuditSetting(models.Model): """ 审批配置表 """ + audit_setting_id = models.AutoField(primary_key=True) - group_id = models.IntegerField('组ID') - group_name = models.CharField('组名称', max_length=100) - workflow_type = models.IntegerField('审批类型', choices=workflow_type_choices) - audit_auth_groups = models.CharField('审批权限组列表', max_length=255) + group_id = models.IntegerField("组ID") + group_name = models.CharField("组名称", max_length=100) + workflow_type = models.IntegerField("审批类型", choices=workflow_type_choices) + audit_auth_groups = models.CharField("审批权限组列表", max_length=255) create_time = models.DateTimeField(auto_now_add=True) sys_time = models.DateTimeField(auto_now=True) @@ -318,33 +379,34 @@ def __int__(self): class Meta: managed = True - db_table = 'workflow_audit_setting' - unique_together = ('group_id', 'workflow_type') - verbose_name = u'审批流程配置' - verbose_name_plural = u'审批流程配置' + db_table = "workflow_audit_setting" + unique_together = ("group_id", "workflow_type") + verbose_name = "审批流程配置" + verbose_name_plural = "审批流程配置" class WorkflowLog(models.Model): """ 工作流日志表 """ + operation_type_choices = ( - (0, '提交/待审核'), - (1, '审核通过'), - (2, '审核不通过'), - (3, '审核取消'), - (4, '定时执行'), - (5, '执行工单'), - (6, '执行结束'), + (0, "提交/待审核"), + (1, "审核通过"), + (2, "审核不通过"), + (3, "审核取消"), + (4, "定时执行"), + (5, "执行工单"), + (6, "执行结束"), ) id = models.AutoField(primary_key=True) - audit_id = models.IntegerField('工单审批id', db_index=True) - operation_type = models.SmallIntegerField('操作类型', choices=operation_type_choices) - operation_type_desc = models.CharField('操作类型描述', max_length=10) - operation_info = models.CharField('操作信息', max_length=1000) - operator = models.CharField('操作人', max_length=30) - operator_display = models.CharField('操作人中文名', max_length=50, default='') + audit_id = models.IntegerField("工单审批id", db_index=True) + operation_type = models.SmallIntegerField("操作类型", choices=operation_type_choices) + operation_type_desc = models.CharField("操作类型描述", max_length=10) + operation_info = models.CharField("操作信息", max_length=1000) + operator = models.CharField("操作人", max_length=30) + operator_display = models.CharField("操作人中文名", max_length=50, default="") operation_time = models.DateTimeField(auto_now_add=True) def __int__(self): @@ -352,30 +414,38 @@ def __int__(self): class Meta: managed = True - db_table = 'workflow_log' - verbose_name = u'工作流日志' - verbose_name_plural = u'工作流日志' + db_table = "workflow_log" + verbose_name = "工作流日志" + verbose_name_plural = "工作流日志" class QueryPrivilegesApply(models.Model): """ 查询权限申请记录表 """ + apply_id = models.AutoField(primary_key=True) - group_id = models.IntegerField('组ID') - group_name = models.CharField('组名称', max_length=100) - title = models.CharField('申请标题', max_length=50) + group_id = models.IntegerField("组ID") + group_name = models.CharField("组名称", max_length=100) + title = models.CharField("申请标题", max_length=50) # TODO user_name display 改为外键 - user_name = models.CharField('申请人', max_length=30) - user_display = models.CharField('申请人中文名', max_length=50, default='') + user_name = models.CharField("申请人", max_length=30) + user_display = models.CharField("申请人中文名", max_length=50, default="") instance = models.ForeignKey(Instance, on_delete=models.CASCADE) - db_list = models.TextField('数据库', default='') # 逗号分隔的数据库列表 - table_list = models.TextField('表', default='') # 逗号分隔的表列表 - valid_date = models.DateField('有效时间') - limit_num = models.IntegerField('行数限制', default=100) - priv_type = models.IntegerField('权限类型', choices=((1, 'DATABASE'), (2, 'TABLE'),), default=0) - status = models.IntegerField('审核状态', choices=workflow_status_choices) - audit_auth_groups = models.CharField('审批权限组列表', max_length=255) + db_list = models.TextField("数据库", default="") # 逗号分隔的数据库列表 + table_list = models.TextField("表", default="") # 逗号分隔的表列表 + valid_date = models.DateField("有效时间") + limit_num = models.IntegerField("行数限制", default=100) + priv_type = models.IntegerField( + "权限类型", + choices=( + (1, "DATABASE"), + (2, "TABLE"), + ), + default=0, + ) + status = models.IntegerField("审核状态", choices=workflow_status_choices) + audit_auth_groups = models.CharField("审批权限组列表", max_length=255) create_time = models.DateTimeField(auto_now_add=True) sys_time = models.DateTimeField(auto_now=True) @@ -384,25 +454,33 @@ def __int__(self): class Meta: managed = True - db_table = 'query_privileges_apply' - verbose_name = u'查询权限申请记录表' - verbose_name_plural = u'查询权限申请记录表' + db_table = "query_privileges_apply" + verbose_name = "查询权限申请记录表" + verbose_name_plural = "查询权限申请记录表" class QueryPrivileges(models.Model): """ 用户权限关系表 """ + privilege_id = models.AutoField(primary_key=True) - user_name = models.CharField('用户名', max_length=30) - user_display = models.CharField('申请人中文名', max_length=50, default='') + user_name = models.CharField("用户名", max_length=30) + user_display = models.CharField("申请人中文名", max_length=50, default="") instance = models.ForeignKey(Instance, on_delete=models.CASCADE) - db_name = models.CharField('数据库', max_length=64, default='') - table_name = models.CharField('表', max_length=64, default='') - valid_date = models.DateField('有效时间') - limit_num = models.IntegerField('行数限制', default=100) - priv_type = models.IntegerField('权限类型', choices=((1, 'DATABASE'), (2, 'TABLE'),), default=0) - is_deleted = models.IntegerField('是否删除', default=0) + db_name = models.CharField("数据库", max_length=64, default="") + table_name = models.CharField("表", max_length=64, default="") + valid_date = models.DateField("有效时间") + limit_num = models.IntegerField("行数限制", default=100) + priv_type = models.IntegerField( + "权限类型", + choices=( + (1, "DATABASE"), + (2, "TABLE"), + ), + default=0, + ) + is_deleted = models.IntegerField("是否删除", default=0) create_time = models.DateTimeField(auto_now_add=True) sys_time = models.DateTimeField(auto_now=True) @@ -411,245 +489,306 @@ def __int__(self): class Meta: managed = True - db_table = 'query_privileges' + db_table = "query_privileges" index_together = ["user_name", "instance", "db_name", "valid_date"] - verbose_name = u'查询权限记录' - verbose_name_plural = u'查询权限记录' + verbose_name = "查询权限记录" + verbose_name_plural = "查询权限记录" class QueryLog(models.Model): """ 记录在线查询sql的日志 """ + # TODO 改为实例外键 - instance_name = models.CharField('实例名称', max_length=50) - db_name = models.CharField('数据库名称', max_length=64) - sqllog = models.TextField('执行的查询语句') - effect_row = models.BigIntegerField('返回行数') - cost_time = models.CharField('执行耗时', max_length=10, default='') + instance_name = models.CharField("实例名称", max_length=50) + db_name = models.CharField("数据库名称", max_length=64) + sqllog = models.TextField("执行的查询语句") + effect_row = models.BigIntegerField("返回行数") + cost_time = models.CharField("执行耗时", max_length=10, default="") # TODO 改为user 外键 - username = models.CharField('操作人', max_length=30) - user_display = models.CharField('操作人中文名', max_length=50, default='') - priv_check = models.BooleanField('查询权限是否正常校验', choices=((False, '跳过'), (True, '正常'),), default=False) - hit_rule = models.BooleanField('查询是否命中脱敏规则', choices=((False, '未命中/未知'), (True, '命中')), default=False) - masking = models.BooleanField('查询结果是否正常脱敏', choices=((False, '否'), (True, '是'),), default=False) - favorite = models.BooleanField('是否收藏', choices=((False, '否'), (True, '是'),), default=False) - alias = models.CharField('语句标识', max_length=64, default='', blank=True) - create_time = models.DateTimeField('操作时间', auto_now_add=True) + username = models.CharField("操作人", max_length=30) + user_display = models.CharField("操作人中文名", max_length=50, default="") + priv_check = models.BooleanField( + "查询权限是否正常校验", + choices=( + (False, "跳过"), + (True, "正常"), + ), + default=False, + ) + hit_rule = models.BooleanField( + "查询是否命中脱敏规则", choices=((False, "未命中/未知"), (True, "命中")), default=False + ) + masking = models.BooleanField( + "查询结果是否正常脱敏", + choices=( + (False, "否"), + (True, "是"), + ), + default=False, + ) + favorite = models.BooleanField( + "是否收藏", + choices=( + (False, "否"), + (True, "是"), + ), + default=False, + ) + alias = models.CharField("语句标识", max_length=64, default="", blank=True) + create_time = models.DateTimeField("操作时间", auto_now_add=True) sys_time = models.DateTimeField(auto_now=True) class Meta: managed = True - db_table = 'query_log' - verbose_name = u'查询日志' - verbose_name_plural = u'查询日志' + db_table = "query_log" + verbose_name = "查询日志" + verbose_name_plural = "查询日志" -rule_type_choices = ((1, '手机号'), (2, '证件号码'), (3, '银行卡'), (4, '邮箱'), (5, '金额'), (6, '其他')) +rule_type_choices = ( + (1, "手机号"), + (2, "证件号码"), + (3, "银行卡"), + (4, "邮箱"), + (5, "金额"), + (6, "其他"), +) class DataMaskingColumns(models.Model): """ 脱敏字段配置 """ - column_id = models.AutoField('字段id', primary_key=True) - rule_type = models.IntegerField('规则类型', choices=rule_type_choices) - active = models.BooleanField('激活状态', choices=((False, '未激活'), (True, '激活'))) + + column_id = models.AutoField("字段id", primary_key=True) + rule_type = models.IntegerField("规则类型", choices=rule_type_choices) + active = models.BooleanField("激活状态", choices=((False, "未激活"), (True, "激活"))) instance = models.ForeignKey(Instance, on_delete=models.CASCADE) - table_schema = models.CharField('字段所在库名', max_length=64) - table_name = models.CharField('字段所在表名', max_length=64) - column_name = models.CharField('字段名', max_length=64) - column_comment = models.CharField('字段描述', max_length=1024, default='', blank=True) + table_schema = models.CharField("字段所在库名", max_length=64) + table_name = models.CharField("字段所在表名", max_length=64) + column_name = models.CharField("字段名", max_length=64) + column_comment = models.CharField("字段描述", max_length=1024, default="", blank=True) create_time = models.DateTimeField(auto_now_add=True) sys_time = models.DateTimeField(auto_now=True) class Meta: managed = True - db_table = 'data_masking_columns' - verbose_name = u'脱敏字段配置' - verbose_name_plural = u'脱敏字段配置' + db_table = "data_masking_columns" + verbose_name = "脱敏字段配置" + verbose_name_plural = "脱敏字段配置" class DataMaskingRules(models.Model): """ 脱敏规则配置 """ - rule_type = models.IntegerField('规则类型', choices=rule_type_choices, unique=True) - rule_regex = models.CharField('规则脱敏所用的正则表达式,表达式必须分组,隐藏的组会使用****代替', max_length=255) - hide_group = models.IntegerField('需要隐藏的组') - rule_desc = models.CharField('规则描述', max_length=100, default='', blank=True) + + rule_type = models.IntegerField("规则类型", choices=rule_type_choices, unique=True) + rule_regex = models.CharField("规则脱敏所用的正则表达式,表达式必须分组,隐藏的组会使用****代替", max_length=255) + hide_group = models.IntegerField("需要隐藏的组") + rule_desc = models.CharField("规则描述", max_length=100, default="", blank=True) sys_time = models.DateTimeField(auto_now=True) class Meta: managed = True - db_table = 'data_masking_rules' - verbose_name = u'脱敏规则配置' - verbose_name_plural = u'脱敏规则配置' + db_table = "data_masking_rules" + verbose_name = "脱敏规则配置" + verbose_name_plural = "脱敏规则配置" class InstanceAccount(models.Model): """ 实例账号列表 """ + instance = models.ForeignKey(Instance, on_delete=models.CASCADE) - user = fields.EncryptedCharField(verbose_name='账号', max_length=128) - host = models.CharField(verbose_name='主机', max_length=64) - password = fields.EncryptedCharField(verbose_name='密码', max_length=128, default='', blank=True) - remark = models.CharField('备注', max_length=255) - sys_time = models.DateTimeField('系统修改时间', auto_now=True) + user = fields.EncryptedCharField(verbose_name="账号", max_length=128) + host = models.CharField(verbose_name="主机", max_length=64) + password = fields.EncryptedCharField( + verbose_name="密码", max_length=128, default="", blank=True + ) + remark = models.CharField("备注", max_length=255) + sys_time = models.DateTimeField("系统修改时间", auto_now=True) class Meta: managed = True - db_table = 'instance_account' - unique_together = ('instance', 'user', 'host') - verbose_name = '实例账号列表' - verbose_name_plural = '实例账号列表' + db_table = "instance_account" + unique_together = ("instance", "user", "host") + verbose_name = "实例账号列表" + verbose_name_plural = "实例账号列表" class InstanceDatabase(models.Model): """ 实例数据库列表 """ + instance = models.ForeignKey(Instance, on_delete=models.CASCADE) - db_name = models.CharField('数据库名', max_length=128) - owner = models.CharField('负责人', max_length=50, default='', blank=True) - owner_display = models.CharField('负责人中文名', max_length=50, default='', blank=True) - remark = models.CharField('备注', max_length=255, default='', blank=True) - sys_time = models.DateTimeField('系统修改时间', auto_now=True) + db_name = models.CharField("数据库名", max_length=128) + owner = models.CharField("负责人", max_length=50, default="", blank=True) + owner_display = models.CharField("负责人中文名", max_length=50, default="", blank=True) + remark = models.CharField("备注", max_length=255, default="", blank=True) + sys_time = models.DateTimeField("系统修改时间", auto_now=True) class Meta: managed = True - db_table = 'instance_database' - unique_together = ('instance', 'db_name') - verbose_name = '实例数据库' - verbose_name_plural = '实例数据库列表' + db_table = "instance_database" + unique_together = ("instance", "db_name") + verbose_name = "实例数据库" + verbose_name_plural = "实例数据库列表" class ParamTemplate(models.Model): """ 实例参数模板配置 """ - db_type = models.CharField('数据库类型', max_length=20, choices=DB_TYPE_CHOICES) - variable_name = models.CharField('参数名', max_length=64) - default_value = models.CharField('默认参数值', max_length=1024) - editable = models.BooleanField('是否支持修改', default=False) - valid_values = models.CharField('有效参数值,范围参数[1-65535],值参数[ON|OFF]', max_length=1024, blank=True) - description = models.CharField('参数描述', max_length=1024, blank=True) - create_time = models.DateTimeField('创建时间', auto_now_add=True) - sys_time = models.DateTimeField('系统时间修改', auto_now=True) + + db_type = models.CharField("数据库类型", max_length=20, choices=DB_TYPE_CHOICES) + variable_name = models.CharField("参数名", max_length=64) + default_value = models.CharField("默认参数值", max_length=1024) + editable = models.BooleanField("是否支持修改", default=False) + valid_values = models.CharField( + "有效参数值,范围参数[1-65535],值参数[ON|OFF]", max_length=1024, blank=True + ) + description = models.CharField("参数描述", max_length=1024, blank=True) + create_time = models.DateTimeField("创建时间", auto_now_add=True) + sys_time = models.DateTimeField("系统时间修改", auto_now=True) class Meta: managed = True - db_table = 'param_template' - unique_together = ('db_type', 'variable_name') - verbose_name = u'实例参数模板配置' - verbose_name_plural = u'实例参数模板配置' + db_table = "param_template" + unique_together = ("db_type", "variable_name") + verbose_name = "实例参数模板配置" + verbose_name_plural = "实例参数模板配置" class ParamHistory(models.Model): """ 可在线修改的动态参数配置 """ + instance = models.ForeignKey(Instance, on_delete=models.CASCADE) - variable_name = models.CharField('参数名', max_length=64) - old_var = models.CharField('修改前参数值', max_length=1024) - new_var = models.CharField('修改后参数值', max_length=1024) - set_sql = models.CharField('在线变更配置执行的SQL语句', max_length=1024) - user_name = models.CharField('修改人', max_length=30) - user_display = models.CharField('修改人中文名', max_length=50) - create_time = models.DateTimeField('参数被修改时间点', auto_now_add=True) + variable_name = models.CharField("参数名", max_length=64) + old_var = models.CharField("修改前参数值", max_length=1024) + new_var = models.CharField("修改后参数值", max_length=1024) + set_sql = models.CharField("在线变更配置执行的SQL语句", max_length=1024) + user_name = models.CharField("修改人", max_length=30) + user_display = models.CharField("修改人中文名", max_length=50) + create_time = models.DateTimeField("参数被修改时间点", auto_now_add=True) class Meta: managed = True - ordering = ['-create_time'] - db_table = 'param_history' - verbose_name = u'实例参数修改历史' - verbose_name_plural = u'实例参数修改历史' + ordering = ["-create_time"] + db_table = "param_history" + verbose_name = "实例参数修改历史" + verbose_name_plural = "实例参数修改历史" class ArchiveConfig(models.Model): """ 归档配置表 """ - title = models.CharField('归档配置说明', max_length=50) + + title = models.CharField("归档配置说明", max_length=50) resource_group = models.ForeignKey(ResourceGroup, on_delete=models.CASCADE) - audit_auth_groups = models.CharField('审批权限组列表', max_length=255, blank=True) - src_instance = models.ForeignKey(Instance, related_name='src_instance', on_delete=models.CASCADE) - src_db_name = models.CharField('源数据库', max_length=64) - src_table_name = models.CharField('源表', max_length=64) - dest_instance = models.ForeignKey(Instance, related_name='dest_instance', on_delete=models.CASCADE, - blank=True, null=True) - dest_db_name = models.CharField('目标数据库', max_length=64, blank=True, null=True) - dest_table_name = models.CharField('目标表', max_length=64, blank=True, null=True) - condition = models.CharField('归档条件,where条件', max_length=1000) - mode = models.CharField('归档模式', max_length=10, choices=(('file', '文件'), ('dest', '其他实例'), ('purge', '直接删除'))) - no_delete = models.BooleanField('是否保留源数据') - sleep = models.IntegerField('归档limit行后的休眠秒数', default=1) - status = models.IntegerField('审核状态', choices=workflow_status_choices, blank=True, default=1) - state = models.BooleanField('是否启用归档', default=True) - user_name = models.CharField('申请人', max_length=30, blank=True, default='') - user_display = models.CharField('申请人中文名', max_length=50, blank=True, default='') - create_time = models.DateTimeField('创建时间', auto_now_add=True) - last_archive_time = models.DateTimeField('最近归档时间', blank=True, null=True) - sys_time = models.DateTimeField('系统时间修改', auto_now=True) + audit_auth_groups = models.CharField("审批权限组列表", max_length=255, blank=True) + src_instance = models.ForeignKey( + Instance, related_name="src_instance", on_delete=models.CASCADE + ) + src_db_name = models.CharField("源数据库", max_length=64) + src_table_name = models.CharField("源表", max_length=64) + dest_instance = models.ForeignKey( + Instance, + related_name="dest_instance", + on_delete=models.CASCADE, + blank=True, + null=True, + ) + dest_db_name = models.CharField("目标数据库", max_length=64, blank=True, null=True) + dest_table_name = models.CharField("目标表", max_length=64, blank=True, null=True) + condition = models.CharField("归档条件,where条件", max_length=1000) + mode = models.CharField( + "归档模式", + max_length=10, + choices=(("file", "文件"), ("dest", "其他实例"), ("purge", "直接删除")), + ) + no_delete = models.BooleanField("是否保留源数据") + sleep = models.IntegerField("归档limit行后的休眠秒数", default=1) + status = models.IntegerField( + "审核状态", choices=workflow_status_choices, blank=True, default=1 + ) + state = models.BooleanField("是否启用归档", default=True) + user_name = models.CharField("申请人", max_length=30, blank=True, default="") + user_display = models.CharField("申请人中文名", max_length=50, blank=True, default="") + create_time = models.DateTimeField("创建时间", auto_now_add=True) + last_archive_time = models.DateTimeField("最近归档时间", blank=True, null=True) + sys_time = models.DateTimeField("系统时间修改", auto_now=True) class Meta: managed = True - db_table = 'archive_config' - verbose_name = u'归档配置表' - verbose_name_plural = u'归档配置表' + db_table = "archive_config" + verbose_name = "归档配置表" + verbose_name_plural = "归档配置表" class ArchiveLog(models.Model): """ 归档日志表 """ + archive = models.ForeignKey(ArchiveConfig, on_delete=models.CASCADE) - cmd = models.CharField('归档命令', max_length=2000) - condition = models.CharField('归档条件,where条件', max_length=1000) - mode = models.CharField('归档模式', max_length=10, choices=(('file', '文件'), ('dest', '其他实例'), ('purge', '直接删除'))) - no_delete = models.BooleanField('是否保留源数据') - sleep = models.IntegerField('归档limit行记录后的休眠秒数', default=0) - select_cnt = models.IntegerField('查询数量') - insert_cnt = models.IntegerField('插入数量') - delete_cnt = models.IntegerField('删除数量') - statistics = models.TextField('归档统计日志') - success = models.BooleanField('是否归档成功') - error_info = models.TextField('错误信息') - start_time = models.DateTimeField('开始时间') - end_time = models.DateTimeField('结束时间') - sys_time = models.DateTimeField('系统时间修改', auto_now=True) + cmd = models.CharField("归档命令", max_length=2000) + condition = models.CharField("归档条件,where条件", max_length=1000) + mode = models.CharField( + "归档模式", + max_length=10, + choices=(("file", "文件"), ("dest", "其他实例"), ("purge", "直接删除")), + ) + no_delete = models.BooleanField("是否保留源数据") + sleep = models.IntegerField("归档limit行记录后的休眠秒数", default=0) + select_cnt = models.IntegerField("查询数量") + insert_cnt = models.IntegerField("插入数量") + delete_cnt = models.IntegerField("删除数量") + statistics = models.TextField("归档统计日志") + success = models.BooleanField("是否归档成功") + error_info = models.TextField("错误信息") + start_time = models.DateTimeField("开始时间") + end_time = models.DateTimeField("结束时间") + sys_time = models.DateTimeField("系统时间修改", auto_now=True) class Meta: managed = True - db_table = 'archive_log' - verbose_name = u'归档日志表' - verbose_name_plural = u'归档日志表' + db_table = "archive_log" + verbose_name = "归档日志表" + verbose_name_plural = "归档日志表" class Config(models.Model): """ 配置信息表 """ - item = models.CharField('配置项', max_length=100, unique=True) - value = fields.EncryptedCharField(verbose_name='配置项值', max_length=500) - description = models.CharField('描述', max_length=200, default='', blank=True) + + item = models.CharField("配置项", max_length=100, unique=True) + value = fields.EncryptedCharField(verbose_name="配置项值", max_length=500) + description = models.CharField("描述", max_length=200, default="", blank=True) class Meta: managed = True - db_table = 'sql_config' - verbose_name = u'系统配置' - verbose_name_plural = u'系统配置' + db_table = "sql_config" + verbose_name = "系统配置" + verbose_name_plural = "系统配置" # 云服务认证信息配置 class CloudAccessKey(models.Model): - cloud_type_choices = (('aliyun', 'aliyun'),) + cloud_type_choices = (("aliyun", "aliyun"),) - type = models.CharField(max_length=20, default='', choices=cloud_type_choices) + type = models.CharField(max_length=20, default="", choices=cloud_type_choices) key_id = models.CharField(max_length=200) key_secret = models.CharField(max_length=200) - remark = models.CharField(max_length=50, default='', blank=True) + remark = models.CharField(max_length=50, default="", blank=True) def __init__(self, *args, **kwargs): self.c = Crypto() @@ -657,12 +796,12 @@ def __init__(self, *args, **kwargs): @property def raw_key_id(self): - """ 返回明文信息""" + """返回明文信息""" return self.c.decrypt(self.key_id) @property def raw_key_secret(self): - """ 返回明文信息""" + """返回明文信息""" return self.c.decrypt(self.key_secret) def save(self, *args, **kwargs): @@ -671,32 +810,35 @@ def save(self, *args, **kwargs): super(CloudAccessKey, self).save(*args, **kwargs) def __str__(self): - return f'{self.type}({self.remark})' + return f"{self.type}({self.remark})" class Meta: managed = True - db_table = 'cloud_access_key' - verbose_name = u'云服务认证信息配置' - verbose_name_plural = u'云服务认证信息配置' + db_table = "cloud_access_key" + verbose_name = "云服务认证信息配置" + verbose_name_plural = "云服务认证信息配置" class AliyunRdsConfig(models.Model): """ 阿里云rds配置信息 """ + instance = models.OneToOneField(Instance, on_delete=models.CASCADE) - rds_dbinstanceid = models.CharField('对应阿里云RDS实例ID', max_length=100) - ak = models.ForeignKey(CloudAccessKey, verbose_name='RDS实例对应的AK配置', on_delete=models.CASCADE) - is_enable = models.BooleanField('是否启用', default=False) + rds_dbinstanceid = models.CharField("对应阿里云RDS实例ID", max_length=100) + ak = models.ForeignKey( + CloudAccessKey, verbose_name="RDS实例对应的AK配置", on_delete=models.CASCADE + ) + is_enable = models.BooleanField("是否启用", default=False) def __int__(self): return self.rds_dbinstanceid class Meta: managed = True - db_table = 'aliyun_rds_config' - verbose_name = u'阿里云rds配置' - verbose_name_plural = u'阿里云rds配置' + db_table = "aliyun_rds_config" + verbose_name = "阿里云rds配置" + verbose_name_plural = "阿里云rds配置" class Permission(models.Model): @@ -707,58 +849,58 @@ class Permission(models.Model): class Meta: managed = True permissions = ( - ('menu_dashboard', '菜单 Dashboard'), - ('menu_sqlcheck', '菜单 SQL审核'), - ('menu_sqlworkflow', '菜单 SQL上线'), - ('menu_sqlanalyze', '菜单 SQL分析'), - ('menu_query', '菜单 SQL查询'), - ('menu_sqlquery', '菜单 在线查询'), - ('menu_queryapplylist', '菜单 权限管理'), - ('menu_sqloptimize', '菜单 SQL优化'), - ('menu_sqladvisor', '菜单 优化工具'), - ('menu_slowquery', '菜单 慢查日志'), - ('menu_instance', '菜单 实例管理'), - ('menu_instance_list', '菜单 实例列表'), - ('menu_dbdiagnostic', '菜单 会话管理'), - ('menu_database', '菜单 数据库管理'), - ('menu_instance_account', '菜单 实例账号管理'), - ('menu_param', '菜单 参数配置'), - ('menu_data_dictionary', '菜单 数据字典'), - ('menu_tools', '菜单 工具插件'), - ('menu_archive', '菜单 数据归档'), - ('menu_my2sql', '菜单 My2SQL'), - ('menu_schemasync', '菜单 SchemaSync'), - ('menu_system', '菜单 系统管理'), - ('menu_document', '菜单 相关文档'), - ('menu_openapi', '菜单 OpenAPI'), - ('sql_submit', '提交SQL上线工单'), - ('sql_review', '审核SQL上线工单'), - ('sql_execute_for_resource_group', '执行SQL上线工单(资源组粒度)'), - ('sql_execute', '执行SQL上线工单(仅自己提交的)'), - ('sql_analyze', '执行SQL分析'), - ('optimize_sqladvisor', '执行SQLAdvisor'), - ('optimize_sqltuning', '执行SQLTuning'), - ('optimize_soar', '执行SOAR'), - ('query_applypriv', '申请查询权限'), - ('query_mgtpriv', '管理查询权限'), - ('query_review', '审核查询权限'), - ('query_submit', '提交SQL查询'), - ('query_all_instances', '可查询所有实例'), - ('query_resource_group_instance', '可查询所在资源组内的所有实例'), - ('process_view', '查看会话'), - ('process_kill', '终止会话'), - ('tablespace_view', '查看表空间'), - ('trx_view', '查看事务信息'), - ('trxandlocks_view', '查看锁信息'), - ('instance_account_manage', '管理实例账号'), - ('param_view', '查看实例参数列表'), - ('param_edit', '修改实例参数'), - ('data_dictionary_export', '导出数据字典'), - ('archive_apply', '提交归档申请'), - ('archive_review', '审核归档申请'), - ('archive_mgt', '管理归档申请'), - ('audit_user','审计权限'), - ('query_download', '在线查询下载权限'), + ("menu_dashboard", "菜单 Dashboard"), + ("menu_sqlcheck", "菜单 SQL审核"), + ("menu_sqlworkflow", "菜单 SQL上线"), + ("menu_sqlanalyze", "菜单 SQL分析"), + ("menu_query", "菜单 SQL查询"), + ("menu_sqlquery", "菜单 在线查询"), + ("menu_queryapplylist", "菜单 权限管理"), + ("menu_sqloptimize", "菜单 SQL优化"), + ("menu_sqladvisor", "菜单 优化工具"), + ("menu_slowquery", "菜单 慢查日志"), + ("menu_instance", "菜单 实例管理"), + ("menu_instance_list", "菜单 实例列表"), + ("menu_dbdiagnostic", "菜单 会话管理"), + ("menu_database", "菜单 数据库管理"), + ("menu_instance_account", "菜单 实例账号管理"), + ("menu_param", "菜单 参数配置"), + ("menu_data_dictionary", "菜单 数据字典"), + ("menu_tools", "菜单 工具插件"), + ("menu_archive", "菜单 数据归档"), + ("menu_my2sql", "菜单 My2SQL"), + ("menu_schemasync", "菜单 SchemaSync"), + ("menu_system", "菜单 系统管理"), + ("menu_document", "菜单 相关文档"), + ("menu_openapi", "菜单 OpenAPI"), + ("sql_submit", "提交SQL上线工单"), + ("sql_review", "审核SQL上线工单"), + ("sql_execute_for_resource_group", "执行SQL上线工单(资源组粒度)"), + ("sql_execute", "执行SQL上线工单(仅自己提交的)"), + ("sql_analyze", "执行SQL分析"), + ("optimize_sqladvisor", "执行SQLAdvisor"), + ("optimize_sqltuning", "执行SQLTuning"), + ("optimize_soar", "执行SOAR"), + ("query_applypriv", "申请查询权限"), + ("query_mgtpriv", "管理查询权限"), + ("query_review", "审核查询权限"), + ("query_submit", "提交SQL查询"), + ("query_all_instances", "可查询所有实例"), + ("query_resource_group_instance", "可查询所在资源组内的所有实例"), + ("process_view", "查看会话"), + ("process_kill", "终止会话"), + ("tablespace_view", "查看表空间"), + ("trx_view", "查看事务信息"), + ("trxandlocks_view", "查看锁信息"), + ("instance_account_manage", "管理实例账号"), + ("param_view", "查看实例参数列表"), + ("param_edit", "修改实例参数"), + ("data_dictionary_export", "导出数据字典"), + ("archive_apply", "提交归档申请"), + ("archive_review", "审核归档申请"), + ("archive_mgt", "管理归档申请"), + ("audit_user", "审计权限"), + ("query_download", "在线查询下载权限"), ) @@ -766,6 +908,7 @@ class SlowQuery(models.Model): """ SlowQuery """ + checksum = models.CharField(max_length=32, primary_key=True) fingerprint = models.TextField() sample = models.TextField() @@ -777,143 +920,286 @@ class SlowQuery(models.Model): class Meta: managed = False - db_table = 'mysql_slow_query_review' - verbose_name = u'慢日志统计' - verbose_name_plural = u'慢日志统计' + db_table = "mysql_slow_query_review" + verbose_name = "慢日志统计" + verbose_name_plural = "慢日志统计" class SlowQueryHistory(models.Model): """ SlowQueryHistory """ + hostname_max = models.CharField(max_length=64, null=False) client_max = models.CharField(max_length=64, null=True) user_max = models.CharField(max_length=64, null=False) db_max = models.CharField(max_length=64, null=True, default=None) bytes_max = models.CharField(max_length=64, null=True) - checksum = models.ForeignKey(SlowQuery, db_constraint=False, to_field='checksum', db_column='checksum', - on_delete=models.CASCADE) + checksum = models.ForeignKey( + SlowQuery, + db_constraint=False, + to_field="checksum", + db_column="checksum", + on_delete=models.CASCADE, + ) sample = models.TextField() ts_min = models.DateTimeField(db_index=True) ts_max = models.DateTimeField() ts_cnt = models.FloatField(blank=True, null=True) - query_time_sum = models.FloatField(db_column='Query_time_sum', blank=True, null=True) - query_time_min = models.FloatField(db_column='Query_time_min', blank=True, null=True) - query_time_max = models.FloatField(db_column='Query_time_max', blank=True, null=True) - query_time_pct_95 = models.FloatField(db_column='Query_time_pct_95', blank=True, null=True) - query_time_stddev = models.FloatField(db_column='Query_time_stddev', blank=True, null=True) - query_time_median = models.FloatField(db_column='Query_time_median', blank=True, null=True) - lock_time_sum = models.FloatField(db_column='Lock_time_sum', blank=True, null=True) - lock_time_min = models.FloatField(db_column='Lock_time_min', blank=True, null=True) - lock_time_max = models.FloatField(db_column='Lock_time_max', blank=True, null=True) - lock_time_pct_95 = models.FloatField(db_column='Lock_time_pct_95', blank=True, null=True) - lock_time_stddev = models.FloatField(db_column='Lock_time_stddev', blank=True, null=True) - lock_time_median = models.FloatField(db_column='Lock_time_median', blank=True, null=True) - rows_sent_sum = models.FloatField(db_column='Rows_sent_sum', blank=True, null=True) - rows_sent_min = models.FloatField(db_column='Rows_sent_min', blank=True, null=True) - rows_sent_max = models.FloatField(db_column='Rows_sent_max', blank=True, null=True) - rows_sent_pct_95 = models.FloatField(db_column='Rows_sent_pct_95', blank=True, null=True) - rows_sent_stddev = models.FloatField(db_column='Rows_sent_stddev', blank=True, null=True) - rows_sent_median = models.FloatField(db_column='Rows_sent_median', blank=True, null=True) - rows_examined_sum = models.FloatField(db_column='Rows_examined_sum', blank=True, null=True) - rows_examined_min = models.FloatField(db_column='Rows_examined_min', blank=True, null=True) - rows_examined_max = models.FloatField(db_column='Rows_examined_max', blank=True, null=True) - rows_examined_pct_95 = models.FloatField(db_column='Rows_examined_pct_95', blank=True, null=True) - rows_examined_stddev = models.FloatField(db_column='Rows_examined_stddev', blank=True, null=True) - rows_examined_median = models.FloatField(db_column='Rows_examined_median', blank=True, null=True) - rows_affected_sum = models.FloatField(db_column='Rows_affected_sum', blank=True, null=True) - rows_affected_min = models.FloatField(db_column='Rows_affected_min', blank=True, null=True) - rows_affected_max = models.FloatField(db_column='Rows_affected_max', blank=True, null=True) - rows_affected_pct_95 = models.FloatField(db_column='Rows_affected_pct_95', blank=True, null=True) - rows_affected_stddev = models.FloatField(db_column='Rows_affected_stddev', blank=True, null=True) - rows_affected_median = models.FloatField(db_column='Rows_affected_median', blank=True, null=True) - rows_read_sum = models.FloatField(db_column='Rows_read_sum', blank=True, null=True) - rows_read_min = models.FloatField(db_column='Rows_read_min', blank=True, null=True) - rows_read_max = models.FloatField(db_column='Rows_read_max', blank=True, null=True) - rows_read_pct_95 = models.FloatField(db_column='Rows_read_pct_95', blank=True, null=True) - rows_read_stddev = models.FloatField(db_column='Rows_read_stddev', blank=True, null=True) - rows_read_median = models.FloatField(db_column='Rows_read_median', blank=True, null=True) - merge_passes_sum = models.FloatField(db_column='Merge_passes_sum', blank=True, null=True) - merge_passes_min = models.FloatField(db_column='Merge_passes_min', blank=True, null=True) - merge_passes_max = models.FloatField(db_column='Merge_passes_max', blank=True, null=True) - merge_passes_pct_95 = models.FloatField(db_column='Merge_passes_pct_95', blank=True, null=True) - merge_passes_stddev = models.FloatField(db_column='Merge_passes_stddev', blank=True, null=True) - merge_passes_median = models.FloatField(db_column='Merge_passes_median', blank=True, null=True) - innodb_io_r_ops_min = models.FloatField(db_column='InnoDB_IO_r_ops_min', blank=True, null=True) - innodb_io_r_ops_max = models.FloatField(db_column='InnoDB_IO_r_ops_max', blank=True, null=True) - innodb_io_r_ops_pct_95 = models.FloatField(db_column='InnoDB_IO_r_ops_pct_95', blank=True, null=True) - innodb_io_r_ops_stddev = models.FloatField(db_column='InnoDB_IO_r_ops_stddev', blank=True, null=True) - innodb_io_r_ops_median = models.FloatField(db_column='InnoDB_IO_r_ops_median', blank=True, null=True) - innodb_io_r_bytes_min = models.FloatField(db_column='InnoDB_IO_r_bytes_min', blank=True, null=True) - innodb_io_r_bytes_max = models.FloatField(db_column='InnoDB_IO_r_bytes_max', blank=True, null=True) - innodb_io_r_bytes_pct_95 = models.FloatField(db_column='InnoDB_IO_r_bytes_pct_95', blank=True, null=True) - innodb_io_r_bytes_stddev = models.FloatField(db_column='InnoDB_IO_r_bytes_stddev', blank=True, null=True) - innodb_io_r_bytes_median = models.FloatField(db_column='InnoDB_IO_r_bytes_median', blank=True, null=True) - innodb_io_r_wait_min = models.FloatField(db_column='InnoDB_IO_r_wait_min', blank=True, null=True) - innodb_io_r_wait_max = models.FloatField(db_column='InnoDB_IO_r_wait_max', blank=True, null=True) - innodb_io_r_wait_pct_95 = models.FloatField(db_column='InnoDB_IO_r_wait_pct_95', blank=True, null=True) - innodb_io_r_wait_stddev = models.FloatField(db_column='InnoDB_IO_r_wait_stddev', blank=True, null=True) - innodb_io_r_wait_median = models.FloatField(db_column='InnoDB_IO_r_wait_median', blank=True, null=True) - innodb_rec_lock_wait_min = models.FloatField(db_column='InnoDB_rec_lock_wait_min', blank=True, null=True) - innodb_rec_lock_wait_max = models.FloatField(db_column='InnoDB_rec_lock_wait_max', blank=True, null=True) - innodb_rec_lock_wait_pct_95 = models.FloatField(db_column='InnoDB_rec_lock_wait_pct_95', blank=True, null=True) - innodb_rec_lock_wait_stddev = models.FloatField(db_column='InnoDB_rec_lock_wait_stddev', blank=True, null=True) - innodb_rec_lock_wait_median = models.FloatField(db_column='InnoDB_rec_lock_wait_median', blank=True, null=True) - innodb_queue_wait_min = models.FloatField(db_column='InnoDB_queue_wait_min', blank=True, null=True) - innodb_queue_wait_max = models.FloatField(db_column='InnoDB_queue_wait_max', blank=True, null=True) - innodb_queue_wait_pct_95 = models.FloatField(db_column='InnoDB_queue_wait_pct_95', blank=True, null=True) - innodb_queue_wait_stddev = models.FloatField(db_column='InnoDB_queue_wait_stddev', blank=True, null=True) - innodb_queue_wait_median = models.FloatField(db_column='InnoDB_queue_wait_median', blank=True, null=True) - innodb_pages_distinct_min = models.FloatField(db_column='InnoDB_pages_distinct_min', blank=True, null=True) - innodb_pages_distinct_max = models.FloatField(db_column='InnoDB_pages_distinct_max', blank=True, null=True) - innodb_pages_distinct_pct_95 = models.FloatField(db_column='InnoDB_pages_distinct_pct_95', blank=True, null=True) - innodb_pages_distinct_stddev = models.FloatField(db_column='InnoDB_pages_distinct_stddev', blank=True, null=True) - innodb_pages_distinct_median = models.FloatField(db_column='InnoDB_pages_distinct_median', blank=True, null=True) - qc_hit_cnt = models.FloatField(db_column='QC_Hit_cnt', blank=True, null=True) - qc_hit_sum = models.FloatField(db_column='QC_Hit_sum', blank=True, null=True) - full_scan_cnt = models.FloatField(db_column='Full_scan_cnt', blank=True, null=True) - full_scan_sum = models.FloatField(db_column='Full_scan_sum', blank=True, null=True) - full_join_cnt = models.FloatField(db_column='Full_join_cnt', blank=True, null=True) - full_join_sum = models.FloatField(db_column='Full_join_sum', blank=True, null=True) - tmp_table_cnt = models.FloatField(db_column='Tmp_table_cnt', blank=True, null=True) - tmp_table_sum = models.FloatField(db_column='Tmp_table_sum', blank=True, null=True) - tmp_table_on_disk_cnt = models.FloatField(db_column='Tmp_table_on_disk_cnt', blank=True, null=True) - tmp_table_on_disk_sum = models.FloatField(db_column='Tmp_table_on_disk_sum', blank=True, null=True) - filesort_cnt = models.FloatField(db_column='Filesort_cnt', blank=True, null=True) - filesort_sum = models.FloatField(db_column='Filesort_sum', blank=True, null=True) - filesort_on_disk_cnt = models.FloatField(db_column='Filesort_on_disk_cnt', blank=True, null=True) - filesort_on_disk_sum = models.FloatField(db_column='Filesort_on_disk_sum', blank=True, null=True) + query_time_sum = models.FloatField( + db_column="Query_time_sum", blank=True, null=True + ) + query_time_min = models.FloatField( + db_column="Query_time_min", blank=True, null=True + ) + query_time_max = models.FloatField( + db_column="Query_time_max", blank=True, null=True + ) + query_time_pct_95 = models.FloatField( + db_column="Query_time_pct_95", blank=True, null=True + ) + query_time_stddev = models.FloatField( + db_column="Query_time_stddev", blank=True, null=True + ) + query_time_median = models.FloatField( + db_column="Query_time_median", blank=True, null=True + ) + lock_time_sum = models.FloatField(db_column="Lock_time_sum", blank=True, null=True) + lock_time_min = models.FloatField(db_column="Lock_time_min", blank=True, null=True) + lock_time_max = models.FloatField(db_column="Lock_time_max", blank=True, null=True) + lock_time_pct_95 = models.FloatField( + db_column="Lock_time_pct_95", blank=True, null=True + ) + lock_time_stddev = models.FloatField( + db_column="Lock_time_stddev", blank=True, null=True + ) + lock_time_median = models.FloatField( + db_column="Lock_time_median", blank=True, null=True + ) + rows_sent_sum = models.FloatField(db_column="Rows_sent_sum", blank=True, null=True) + rows_sent_min = models.FloatField(db_column="Rows_sent_min", blank=True, null=True) + rows_sent_max = models.FloatField(db_column="Rows_sent_max", blank=True, null=True) + rows_sent_pct_95 = models.FloatField( + db_column="Rows_sent_pct_95", blank=True, null=True + ) + rows_sent_stddev = models.FloatField( + db_column="Rows_sent_stddev", blank=True, null=True + ) + rows_sent_median = models.FloatField( + db_column="Rows_sent_median", blank=True, null=True + ) + rows_examined_sum = models.FloatField( + db_column="Rows_examined_sum", blank=True, null=True + ) + rows_examined_min = models.FloatField( + db_column="Rows_examined_min", blank=True, null=True + ) + rows_examined_max = models.FloatField( + db_column="Rows_examined_max", blank=True, null=True + ) + rows_examined_pct_95 = models.FloatField( + db_column="Rows_examined_pct_95", blank=True, null=True + ) + rows_examined_stddev = models.FloatField( + db_column="Rows_examined_stddev", blank=True, null=True + ) + rows_examined_median = models.FloatField( + db_column="Rows_examined_median", blank=True, null=True + ) + rows_affected_sum = models.FloatField( + db_column="Rows_affected_sum", blank=True, null=True + ) + rows_affected_min = models.FloatField( + db_column="Rows_affected_min", blank=True, null=True + ) + rows_affected_max = models.FloatField( + db_column="Rows_affected_max", blank=True, null=True + ) + rows_affected_pct_95 = models.FloatField( + db_column="Rows_affected_pct_95", blank=True, null=True + ) + rows_affected_stddev = models.FloatField( + db_column="Rows_affected_stddev", blank=True, null=True + ) + rows_affected_median = models.FloatField( + db_column="Rows_affected_median", blank=True, null=True + ) + rows_read_sum = models.FloatField(db_column="Rows_read_sum", blank=True, null=True) + rows_read_min = models.FloatField(db_column="Rows_read_min", blank=True, null=True) + rows_read_max = models.FloatField(db_column="Rows_read_max", blank=True, null=True) + rows_read_pct_95 = models.FloatField( + db_column="Rows_read_pct_95", blank=True, null=True + ) + rows_read_stddev = models.FloatField( + db_column="Rows_read_stddev", blank=True, null=True + ) + rows_read_median = models.FloatField( + db_column="Rows_read_median", blank=True, null=True + ) + merge_passes_sum = models.FloatField( + db_column="Merge_passes_sum", blank=True, null=True + ) + merge_passes_min = models.FloatField( + db_column="Merge_passes_min", blank=True, null=True + ) + merge_passes_max = models.FloatField( + db_column="Merge_passes_max", blank=True, null=True + ) + merge_passes_pct_95 = models.FloatField( + db_column="Merge_passes_pct_95", blank=True, null=True + ) + merge_passes_stddev = models.FloatField( + db_column="Merge_passes_stddev", blank=True, null=True + ) + merge_passes_median = models.FloatField( + db_column="Merge_passes_median", blank=True, null=True + ) + innodb_io_r_ops_min = models.FloatField( + db_column="InnoDB_IO_r_ops_min", blank=True, null=True + ) + innodb_io_r_ops_max = models.FloatField( + db_column="InnoDB_IO_r_ops_max", blank=True, null=True + ) + innodb_io_r_ops_pct_95 = models.FloatField( + db_column="InnoDB_IO_r_ops_pct_95", blank=True, null=True + ) + innodb_io_r_ops_stddev = models.FloatField( + db_column="InnoDB_IO_r_ops_stddev", blank=True, null=True + ) + innodb_io_r_ops_median = models.FloatField( + db_column="InnoDB_IO_r_ops_median", blank=True, null=True + ) + innodb_io_r_bytes_min = models.FloatField( + db_column="InnoDB_IO_r_bytes_min", blank=True, null=True + ) + innodb_io_r_bytes_max = models.FloatField( + db_column="InnoDB_IO_r_bytes_max", blank=True, null=True + ) + innodb_io_r_bytes_pct_95 = models.FloatField( + db_column="InnoDB_IO_r_bytes_pct_95", blank=True, null=True + ) + innodb_io_r_bytes_stddev = models.FloatField( + db_column="InnoDB_IO_r_bytes_stddev", blank=True, null=True + ) + innodb_io_r_bytes_median = models.FloatField( + db_column="InnoDB_IO_r_bytes_median", blank=True, null=True + ) + innodb_io_r_wait_min = models.FloatField( + db_column="InnoDB_IO_r_wait_min", blank=True, null=True + ) + innodb_io_r_wait_max = models.FloatField( + db_column="InnoDB_IO_r_wait_max", blank=True, null=True + ) + innodb_io_r_wait_pct_95 = models.FloatField( + db_column="InnoDB_IO_r_wait_pct_95", blank=True, null=True + ) + innodb_io_r_wait_stddev = models.FloatField( + db_column="InnoDB_IO_r_wait_stddev", blank=True, null=True + ) + innodb_io_r_wait_median = models.FloatField( + db_column="InnoDB_IO_r_wait_median", blank=True, null=True + ) + innodb_rec_lock_wait_min = models.FloatField( + db_column="InnoDB_rec_lock_wait_min", blank=True, null=True + ) + innodb_rec_lock_wait_max = models.FloatField( + db_column="InnoDB_rec_lock_wait_max", blank=True, null=True + ) + innodb_rec_lock_wait_pct_95 = models.FloatField( + db_column="InnoDB_rec_lock_wait_pct_95", blank=True, null=True + ) + innodb_rec_lock_wait_stddev = models.FloatField( + db_column="InnoDB_rec_lock_wait_stddev", blank=True, null=True + ) + innodb_rec_lock_wait_median = models.FloatField( + db_column="InnoDB_rec_lock_wait_median", blank=True, null=True + ) + innodb_queue_wait_min = models.FloatField( + db_column="InnoDB_queue_wait_min", blank=True, null=True + ) + innodb_queue_wait_max = models.FloatField( + db_column="InnoDB_queue_wait_max", blank=True, null=True + ) + innodb_queue_wait_pct_95 = models.FloatField( + db_column="InnoDB_queue_wait_pct_95", blank=True, null=True + ) + innodb_queue_wait_stddev = models.FloatField( + db_column="InnoDB_queue_wait_stddev", blank=True, null=True + ) + innodb_queue_wait_median = models.FloatField( + db_column="InnoDB_queue_wait_median", blank=True, null=True + ) + innodb_pages_distinct_min = models.FloatField( + db_column="InnoDB_pages_distinct_min", blank=True, null=True + ) + innodb_pages_distinct_max = models.FloatField( + db_column="InnoDB_pages_distinct_max", blank=True, null=True + ) + innodb_pages_distinct_pct_95 = models.FloatField( + db_column="InnoDB_pages_distinct_pct_95", blank=True, null=True + ) + innodb_pages_distinct_stddev = models.FloatField( + db_column="InnoDB_pages_distinct_stddev", blank=True, null=True + ) + innodb_pages_distinct_median = models.FloatField( + db_column="InnoDB_pages_distinct_median", blank=True, null=True + ) + qc_hit_cnt = models.FloatField(db_column="QC_Hit_cnt", blank=True, null=True) + qc_hit_sum = models.FloatField(db_column="QC_Hit_sum", blank=True, null=True) + full_scan_cnt = models.FloatField(db_column="Full_scan_cnt", blank=True, null=True) + full_scan_sum = models.FloatField(db_column="Full_scan_sum", blank=True, null=True) + full_join_cnt = models.FloatField(db_column="Full_join_cnt", blank=True, null=True) + full_join_sum = models.FloatField(db_column="Full_join_sum", blank=True, null=True) + tmp_table_cnt = models.FloatField(db_column="Tmp_table_cnt", blank=True, null=True) + tmp_table_sum = models.FloatField(db_column="Tmp_table_sum", blank=True, null=True) + tmp_table_on_disk_cnt = models.FloatField( + db_column="Tmp_table_on_disk_cnt", blank=True, null=True + ) + tmp_table_on_disk_sum = models.FloatField( + db_column="Tmp_table_on_disk_sum", blank=True, null=True + ) + filesort_cnt = models.FloatField(db_column="Filesort_cnt", blank=True, null=True) + filesort_sum = models.FloatField(db_column="Filesort_sum", blank=True, null=True) + filesort_on_disk_cnt = models.FloatField( + db_column="Filesort_on_disk_cnt", blank=True, null=True + ) + filesort_on_disk_sum = models.FloatField( + db_column="Filesort_on_disk_sum", blank=True, null=True + ) class Meta: managed = False - db_table = 'mysql_slow_query_review_history' - unique_together = ('checksum', 'ts_min', 'ts_max') - index_together = ('hostname_max', 'ts_min') - verbose_name = u'慢日志明细' - verbose_name_plural = u'慢日志明细' + db_table = "mysql_slow_query_review_history" + unique_together = ("checksum", "ts_min", "ts_max") + index_together = ("hostname_max", "ts_min") + verbose_name = "慢日志明细" + verbose_name_plural = "慢日志明细" class AuditEntry(models.Model): """ 登录审计日志 """ - user_id = models.IntegerField('用户ID') - user_name = models.CharField('用户名称', max_length=30, null=True) - user_display = models.CharField('用户中文名',max_length=50, null=True) - action = models.CharField('动作', max_length=255) - extra_info = models.TextField('额外的信息', null=True) - action_time = models.DateTimeField('操作时间', auto_now_add=True) + + user_id = models.IntegerField("用户ID") + user_name = models.CharField("用户名称", max_length=30, null=True) + user_display = models.CharField("用户中文名", max_length=50, null=True) + action = models.CharField("动作", max_length=255) + extra_info = models.TextField("额外的信息", null=True) + action_time = models.DateTimeField("操作时间", auto_now_add=True) class Meta: managed = True - db_table = 'audit_log' - verbose_name = u'审计日志' - verbose_name_plural = u'审计日志' + db_table = "audit_log" + verbose_name = "审计日志" + verbose_name_plural = "审计日志" def __unicode__(self): - return '{0} - {1} - {2} - {3} - {4}'.format(self.user_id, self.user_name, self.extra_info - , self.action, self.action_time) + return "{0} - {1} - {2} - {3} - {4}".format( + self.user_id, self.user_name, self.extra_info, self.action, self.action_time + ) def __str__(self): - return '{0} - {1} - {2} - {3} - {4}'.format(self.user_id, self.user_name, self.extra_info - , self.action, self.action_time) + return "{0} - {1} - {2} - {3} - {4}".format( + self.user_id, self.user_name, self.extra_info, self.action, self.action_time + ) diff --git a/sql/notify.py b/sql/notify.py index ca70c5fd0e..5728b4a431 100755 --- a/sql/notify.py +++ b/sql/notify.py @@ -5,7 +5,13 @@ from django.contrib.auth.models import Group from common.config import SysConfig -from sql.models import QueryPrivilegesApply, Users, SqlWorkflow, ResourceGroup, ArchiveConfig +from sql.models import ( + QueryPrivilegesApply, + Users, + SqlWorkflow, + ResourceGroup, + ArchiveConfig, +) from sql.utils.resource_group import auth_group_users from common.utils.sendmsg import MsgSender from common.utils.const import WorkflowDict @@ -13,22 +19,31 @@ import logging -logger = logging.getLogger('default') +logger = logging.getLogger("default") def __notify_cnf_status(): """返回消息通知开关""" sys_config = SysConfig() - mail_status = sys_config.get('mail') - ding_status = sys_config.get('ding_to_person') - ding_webhook_status = sys_config.get('ding') - wx_status = sys_config.get('wx') + mail_status = sys_config.get("mail") + ding_status = sys_config.get("ding_to_person") + ding_webhook_status = sys_config.get("ding") + wx_status = sys_config.get("wx") qywx_webhook_status = sys_config.get("qywx_webhook") feishu_webhook_status = sys_config.get("feishu_webhook") feishu_status = sys_config.get("feishu") - if not any([mail_status, ding_status, ding_webhook_status, wx_status, feishu_status, feishu_webhook_status, - qywx_webhook_status]): - logger.info('未开启任何消息通知,可在系统设置中开启') + if not any( + [ + mail_status, + ding_status, + ding_webhook_status, + wx_status, + feishu_status, + feishu_webhook_status, + qywx_webhook_status, + ] + ): + logger.info("未开启任何消息通知,可在系统设置中开启") return False else: return True @@ -45,30 +60,41 @@ def __send(msg_title, msg_content, msg_to, msg_cc=None, **kwargs): sys_config = SysConfig() msg_sender = MsgSender() msg_cc = msg_cc if msg_cc else [] - dingding_webhook = kwargs.get('dingding_webhook') - feishu_webhook = kwargs.get('feishu_webhook') - qywx_webhook = kwargs.get('qywx_webhook') + dingding_webhook = kwargs.get("dingding_webhook") + feishu_webhook = kwargs.get("feishu_webhook") + qywx_webhook = kwargs.get("qywx_webhook") msg_to_email = [user.email for user in msg_to if user.email] msg_cc_email = [user.email for user in msg_cc if user.email] - msg_to_ding_user = [user.ding_user_id for user in chain(msg_to, msg_cc) if user.ding_user_id] - msg_to_wx_user = [user.wx_user_id if user.wx_user_id else user.username for user in chain(msg_to, msg_cc)] - logger.info(f'{msg_to_email}{msg_cc_email}{msg_to_wx_user}{chain(msg_to, msg_cc)}') - if sys_config.get('mail'): - msg_sender.send_email(msg_title, msg_content, msg_to_email, list_cc_addr=msg_cc_email) - if sys_config.get('ding') and dingding_webhook: - msg_sender.send_ding(dingding_webhook, msg_title + '\n' + msg_content) - if sys_config.get('ding_to_person'): - msg_sender.send_ding2user(msg_to_ding_user, msg_title + '\n' + msg_content) - if sys_config.get('wx'): - msg_sender.send_wx2user(msg_title + '\n' + msg_content, msg_to_wx_user) + msg_to_ding_user = [ + user.ding_user_id for user in chain(msg_to, msg_cc) if user.ding_user_id + ] + msg_to_wx_user = [ + user.wx_user_id if user.wx_user_id else user.username + for user in chain(msg_to, msg_cc) + ] + logger.info(f"{msg_to_email}{msg_cc_email}{msg_to_wx_user}{chain(msg_to, msg_cc)}") + if sys_config.get("mail"): + msg_sender.send_email( + msg_title, msg_content, msg_to_email, list_cc_addr=msg_cc_email + ) + if sys_config.get("ding") and dingding_webhook: + msg_sender.send_ding(dingding_webhook, msg_title + "\n" + msg_content) + if sys_config.get("ding_to_person"): + msg_sender.send_ding2user(msg_to_ding_user, msg_title + "\n" + msg_content) + if sys_config.get("wx"): + msg_sender.send_wx2user(msg_title + "\n" + msg_content, msg_to_wx_user) if sys_config.get("feishu_webhook") and feishu_webhook: msg_sender.send_feishu_webhook(feishu_webhook, msg_title, msg_content) if sys_config.get("feishu"): - open_id = [user.feishu_open_id for user in chain(msg_to, msg_cc) if user.feishu_open_id] - user_mail = [user.email for user in chain(msg_to, msg_cc) if not user.feishu_open_id] + open_id = [ + user.feishu_open_id for user in chain(msg_to, msg_cc) if user.feishu_open_id + ] + user_mail = [ + user.email for user in chain(msg_to, msg_cc) if not user.feishu_open_id + ] msg_sender.send_feishu_user(msg_title, msg_content, open_id, user_mail) - if sys_config.get('qywx_webhook') and qywx_webhook: - msg_sender.send_qywx_webhook(qywx_webhook, msg_title + '\n' + msg_content) + if sys_config.get("qywx_webhook") and qywx_webhook: + msg_sender.send_qywx_webhook(qywx_webhook, msg_title + "\n" + msg_content) def notify_for_audit(audit_id, **kwargs): @@ -86,71 +112,90 @@ def notify_for_audit(audit_id, **kwargs): # 获取审核信息 audit_detail = Audit.detail(audit_id=audit_id) audit_id = audit_detail.audit_id - workflow_audit_remark = kwargs.get('audit_remark', '') - base_url = sys_config.get('archery_base_url', 'http://127.0.0.1:8000').rstrip('/') - workflow_url = "{base_url}/workflow/{audit_id}".format(base_url=base_url, audit_id=audit_detail.audit_id) + workflow_audit_remark = kwargs.get("audit_remark", "") + base_url = sys_config.get("archery_base_url", "http://127.0.0.1:8000").rstrip("/") + workflow_url = "{base_url}/workflow/{audit_id}".format( + base_url=base_url, audit_id=audit_detail.audit_id + ) workflow_id = audit_detail.workflow_id workflow_type = audit_detail.workflow_type status = audit_detail.current_status workflow_title = audit_detail.workflow_title workflow_from = audit_detail.create_user_display group_name = audit_detail.group_name - dingding_webhook = ResourceGroup.objects.get(group_id=audit_detail.group_id).ding_webhook - feishu_webhook = ResourceGroup.objects.get(group_id=audit_detail.group_id).feishu_webhook - qywx_webhook = ResourceGroup.objects.get(group_id=audit_detail.group_id).qywx_webhook + dingding_webhook = ResourceGroup.objects.get( + group_id=audit_detail.group_id + ).ding_webhook + feishu_webhook = ResourceGroup.objects.get( + group_id=audit_detail.group_id + ).feishu_webhook + qywx_webhook = ResourceGroup.objects.get( + group_id=audit_detail.group_id + ).qywx_webhook # 获取当前审批和审批流程 - workflow_auditors, current_workflow_auditors = Audit.review_info(audit_detail.workflow_id, - audit_detail.workflow_type) + workflow_auditors, current_workflow_auditors = Audit.review_info( + audit_detail.workflow_id, audit_detail.workflow_type + ) # 准备消息内容 - if workflow_type == WorkflowDict.workflow_type['query']: - workflow_type_display = WorkflowDict.workflow_type['query_display'] + if workflow_type == WorkflowDict.workflow_type["query"]: + workflow_type_display = WorkflowDict.workflow_type["query_display"] workflow_detail = QueryPrivilegesApply.objects.get(apply_id=workflow_id) instance = workflow_detail.instance.instance_name - db_name = ' ' + db_name = " " if workflow_detail.priv_type == 1: - workflow_content = '''数据库清单:{}\n授权截止时间:{}\n结果集:{}\n'''.format( + workflow_content = """数据库清单:{}\n授权截止时间:{}\n结果集:{}\n""".format( workflow_detail.db_list, - datetime.datetime.strftime(workflow_detail.valid_date, '%Y-%m-%d %H:%M:%S'), - workflow_detail.limit_num) + datetime.datetime.strftime( + workflow_detail.valid_date, "%Y-%m-%d %H:%M:%S" + ), + workflow_detail.limit_num, + ) elif workflow_detail.priv_type == 2: db_name = workflow_detail.db_list - workflow_content = '''数据库:{}\n表清单:{}\n授权截止时间:{}\n结果集:{}\n'''.format( + workflow_content = """数据库:{}\n表清单:{}\n授权截止时间:{}\n结果集:{}\n""".format( workflow_detail.db_list, workflow_detail.table_list, - datetime.datetime.strftime(workflow_detail.valid_date, '%Y-%m-%d %H:%M:%S'), - workflow_detail.limit_num) + datetime.datetime.strftime( + workflow_detail.valid_date, "%Y-%m-%d %H:%M:%S" + ), + workflow_detail.limit_num, + ) else: - workflow_content = '' - elif workflow_type == WorkflowDict.workflow_type['sqlreview']: - workflow_type_display = WorkflowDict.workflow_type['sqlreview_display'] + workflow_content = "" + elif workflow_type == WorkflowDict.workflow_type["sqlreview"]: + workflow_type_display = WorkflowDict.workflow_type["sqlreview_display"] workflow_detail = SqlWorkflow.objects.get(pk=workflow_id) instance = workflow_detail.instance.instance_name db_name = workflow_detail.db_name - workflow_content = re.sub('[\r\n\f]{2,}', '\n', - workflow_detail.sqlworkflowcontent.sql_content[0:500].replace('\r', '')) - elif workflow_type == WorkflowDict.workflow_type['archive']: - workflow_type_display = WorkflowDict.workflow_type['archive_display'] + workflow_content = re.sub( + "[\r\n\f]{2,}", + "\n", + workflow_detail.sqlworkflowcontent.sql_content[0:500].replace("\r", ""), + ) + elif workflow_type == WorkflowDict.workflow_type["archive"]: + workflow_type_display = WorkflowDict.workflow_type["archive_display"] workflow_detail = ArchiveConfig.objects.get(pk=workflow_id) instance = workflow_detail.src_instance.instance_name db_name = workflow_detail.src_db_name - workflow_content = '''归档表:{}\n归档模式:{}\n归档条件:{}\n'''.format( + workflow_content = """归档表:{}\n归档模式:{}\n归档条件:{}\n""".format( workflow_detail.src_table_name, workflow_detail.mode, - workflow_detail.condition) + workflow_detail.condition, + ) else: - raise Exception('工单类型不正确') + raise Exception("工单类型不正确") # 准备消息格式 - if status == WorkflowDict.workflow_status['audit_wait']: # 申请阶段 + if status == WorkflowDict.workflow_status["audit_wait"]: # 申请阶段 msg_title = "[{}]新的工单申请#{}".format(workflow_type_display, audit_id) # 接收人,发送给该资源组内对应权限组所有的用户 auth_group_names = Group.objects.get(id=audit_detail.current_audit).name msg_to = auth_group_users([auth_group_names], audit_detail.group_id) - msg_cc = Users.objects.filter(username__in=kwargs.get('cc_users', [])) + msg_cc = Users.objects.filter(username__in=kwargs.get("cc_users", [])) # 消息内容 - msg_content = '''发起时间:{}\n发起人:{}\n组:{}\n目标实例:{}\n数据库:{}\n审批流程:{}\n当前审批:{}\n工单名称:{}\n工单地址:{}\n工单详情预览:{}\n'''.format( - workflow_detail.create_time.strftime('%Y-%m-%d %H:%M:%S'), + msg_content = """发起时间:{}\n发起人:{}\n组:{}\n目标实例:{}\n数据库:{}\n审批流程:{}\n当前审批:{}\n工单名称:{}\n工单地址:{}\n工单详情预览:{}\n""".format( + workflow_detail.create_time.strftime("%Y-%m-%d %H:%M:%S"), workflow_from, group_name, instance, @@ -159,15 +204,16 @@ def notify_for_audit(audit_id, **kwargs): current_workflow_auditors, workflow_title, workflow_url, - workflow_content) - elif status == WorkflowDict.workflow_status['audit_success']: # 审核通过 + workflow_content, + ) + elif status == WorkflowDict.workflow_status["audit_success"]: # 审核通过 msg_title = "[{}]工单审核通过#{}".format(workflow_type_display, audit_id) # 接收人,仅发送给申请人 msg_to = [Users.objects.get(username=audit_detail.create_user)] - msg_cc = Users.objects.filter(username__in=kwargs.get('cc_users', [])) + msg_cc = Users.objects.filter(username__in=kwargs.get("cc_users", [])) # 消息内容 - msg_content = '''发起时间:{}\n发起人:{}\n组:{}\n目标实例:{}\n数据库:{}\n审批流程:{}\n工单名称:{}\n工单地址:{}\n工单详情预览:{}\n'''.format( - workflow_detail.create_time.strftime('%Y-%m-%d %H:%M:%S'), + msg_content = """发起时间:{}\n发起人:{}\n组:{}\n目标实例:{}\n数据库:{}\n审批流程:{}\n工单名称:{}\n工单地址:{}\n工单详情预览:{}\n""".format( + workflow_detail.create_time.strftime("%Y-%m-%d %H:%M:%S"), workflow_from, group_name, instance, @@ -175,43 +221,55 @@ def notify_for_audit(audit_id, **kwargs): workflow_auditors, workflow_title, workflow_url, - workflow_content) - elif status == WorkflowDict.workflow_status['audit_reject']: # 审核驳回 + workflow_content, + ) + elif status == WorkflowDict.workflow_status["audit_reject"]: # 审核驳回 msg_title = "[{}]工单被驳回#{}".format(workflow_type_display, audit_id) # 接收人,仅发送给申请人 msg_to = [Users.objects.get(username=audit_detail.create_user)] - msg_cc = Users.objects.filter(username__in=kwargs.get('cc_users', [])) + msg_cc = Users.objects.filter(username__in=kwargs.get("cc_users", [])) # 消息内容 - msg_content = '''发起时间:{}\n目标实例:{}\n数据库:{}\n工单名称:{}\n工单地址:{}\n驳回原因:{}\n提醒:此工单被审核不通过,请按照驳回原因进行修改!'''.format( - workflow_detail.create_time.strftime('%Y-%m-%d %H:%M:%S'), + msg_content = """发起时间:{}\n目标实例:{}\n数据库:{}\n工单名称:{}\n工单地址:{}\n驳回原因:{}\n提醒:此工单被审核不通过,请按照驳回原因进行修改!""".format( + workflow_detail.create_time.strftime("%Y-%m-%d %H:%M:%S"), instance, db_name, workflow_title, workflow_url, - re.sub('[\r\n\f]{2,}', '\n', workflow_audit_remark)) - elif status == WorkflowDict.workflow_status['audit_abort']: # 审核取消,通知所有审核人 + re.sub("[\r\n\f]{2,}", "\n", workflow_audit_remark), + ) + elif status == WorkflowDict.workflow_status["audit_abort"]: # 审核取消,通知所有审核人 msg_title = "[{}]提交人主动终止工单#{}".format(workflow_type_display, audit_id) # 接收人,发送给该资源组内对应权限组所有的用户 - auth_group_names = [Group.objects.get(id=auth_group_id).name for auth_group_id in - audit_detail.audit_auth_groups.split(',')] + auth_group_names = [ + Group.objects.get(id=auth_group_id).name + for auth_group_id in audit_detail.audit_auth_groups.split(",") + ] msg_to = auth_group_users(auth_group_names, audit_detail.group_id) - msg_cc = Users.objects.filter(username__in=kwargs.get('cc_users', [])) + msg_cc = Users.objects.filter(username__in=kwargs.get("cc_users", [])) # 消息内容 - msg_content = '''发起时间:{}\n发起人:{}\n组:{}\n目标实例:{}\n数据库:{}\n工单名称:{}\n工单地址:{}\n终止原因:{}'''.format( - workflow_detail.create_time.strftime('%Y-%m-%d %H:%M:%S'), + msg_content = """发起时间:{}\n发起人:{}\n组:{}\n目标实例:{}\n数据库:{}\n工单名称:{}\n工单地址:{}\n终止原因:{}""".format( + workflow_detail.create_time.strftime("%Y-%m-%d %H:%M:%S"), workflow_from, group_name, instance, db_name, workflow_title, workflow_url, - re.sub('[\r\n\f]{2,}', '\n', workflow_audit_remark)) + re.sub("[\r\n\f]{2,}", "\n", workflow_audit_remark), + ) else: - raise Exception('工单状态不正确') + raise Exception("工单状态不正确") logger.info(f"通知Debug{msg_to}{msg_cc}") # 发送通知 - __send(msg_title, msg_content, msg_to, msg_cc, feishu_webhook=feishu_webhook, dingding_webhook=dingding_webhook, - qywx_webhook=qywx_webhook) + __send( + msg_title, + msg_content, + msg_to, + msg_cc, + feishu_webhook=feishu_webhook, + dingding_webhook=dingding_webhook, + qywx_webhook=qywx_webhook, + ) def notify_for_execute(workflow): @@ -226,14 +284,17 @@ def notify_for_execute(workflow): sys_config = SysConfig() # 获取当前审批和审批流程 - base_url = sys_config.get('archery_base_url', 'http://127.0.0.1:8000').rstrip('/') + base_url = sys_config.get("archery_base_url", "http://127.0.0.1:8000").rstrip("/") audit_auth_group, current_audit_auth_group = Audit.review_info(workflow.id, 2) audit_id = Audit.detail_by_workflow_id(workflow.id, 2).audit_id url = "{base_url}/workflow/{audit_id}".format(base_url=base_url, audit_id=audit_id) - msg_title = "[{}]工单{}#{}".format(WorkflowDict.workflow_type['sqlreview_display'], - workflow.get_status_display(), audit_id) - msg_content = '''发起时间:{}\n发起人:{}\n组:{}\n目标实例:{}\n数据库:{}\n审批流程:{}\n工单名称:{}\n工单地址:{}\n工单详情预览:{}\n'''.format( - workflow.create_time.strftime('%Y-%m-%d %H:%M:%S'), + msg_title = "[{}]工单{}#{}".format( + WorkflowDict.workflow_type["sqlreview_display"], + workflow.get_status_display(), + audit_id, + ) + msg_content = """发起时间:{}\n发起人:{}\n组:{}\n目标实例:{}\n数据库:{}\n审批流程:{}\n工单名称:{}\n工单地址:{}\n工单详情预览:{}\n""".format( + workflow.create_time.strftime("%Y-%m-%d %H:%M:%S"), workflow.engineer_display, workflow.group_name, workflow.instance.instance_name, @@ -241,35 +302,54 @@ def notify_for_execute(workflow): audit_auth_group, workflow.workflow_name, url, - re.sub('[\r\n\f]{2,}', '\n', workflow.sqlworkflowcontent.sql_content[0:500].replace('\r', ''))) + re.sub( + "[\r\n\f]{2,}", + "\n", + workflow.sqlworkflowcontent.sql_content[0:500].replace("\r", ""), + ), + ) # 邮件通知申请人,抄送DBA msg_to = Users.objects.filter(username=workflow.engineer) - msg_cc = auth_group_users(auth_group_names=['DBA'], group_id=workflow.group_id) + msg_cc = auth_group_users(auth_group_names=["DBA"], group_id=workflow.group_id) # 处理接收人 - dingding_webhook = ResourceGroup.objects.get(group_id=workflow.group_id).ding_webhook - feishu_webhook = ResourceGroup.objects.get(group_id=workflow.group_id).feishu_webhook + dingding_webhook = ResourceGroup.objects.get( + group_id=workflow.group_id + ).ding_webhook + feishu_webhook = ResourceGroup.objects.get( + group_id=workflow.group_id + ).feishu_webhook qywx_webhook = ResourceGroup.objects.get(group_id=workflow.group_id).qywx_webhook # 发送通知 - __send(msg_title, msg_content, msg_to, msg_cc, dingding_webhook=dingding_webhook, feishu_webhook=feishu_webhook, - qywx_webhook=qywx_webhook) + __send( + msg_title, + msg_content, + msg_to, + msg_cc, + dingding_webhook=dingding_webhook, + feishu_webhook=feishu_webhook, + qywx_webhook=qywx_webhook, + ) # DDL通知 - if sys_config.get('ddl_notify_auth_group') and workflow.status == 'workflow_finish': + if sys_config.get("ddl_notify_auth_group") and workflow.status == "workflow_finish": # 判断上线语句是否存在DDL,存在则通知相关人员 if workflow.syntax_type == 1: # 消息内容通知 - msg_title = '[Archery]有新的DDL语句执行完成#{}'.format(audit_id) - msg_content = '''发起人:{}\n变更组:{}\n变更实例:{}\n变更数据库:{}\n工单名称:{}\n工单地址:{}\n工单预览:{}\n'''.format( + msg_title = "[Archery]有新的DDL语句执行完成#{}".format(audit_id) + msg_content = """发起人:{}\n变更组:{}\n变更实例:{}\n变更数据库:{}\n工单名称:{}\n工单地址:{}\n工单预览:{}\n""".format( Users.objects.get(username=workflow.engineer).display, workflow.group_name, workflow.instance.instance_name, workflow.db_name, workflow.workflow_name, url, - workflow.sqlworkflowcontent.sql_content[0:500]) + workflow.sqlworkflowcontent.sql_content[0:500], + ) # 获取通知成员ddl_notify_auth_group - ddl_notify_auth_group = sys_config.get('ddl_notify_auth_group', '').split(',') + ddl_notify_auth_group = sys_config.get("ddl_notify_auth_group", "").split( + "," + ) msg_to = Users.objects.filter(groups__name__in=ddl_notify_auth_group) # 发送通知 __send(msg_title, msg_content, msg_to, msg_cc) @@ -285,11 +365,11 @@ def notify_for_my2sql(task): if not __notify_cnf_status(): return None if task.success: - msg_title = '[Archery 通知]My2SQL执行结束' - msg_content = f'解析的SQL文件在{task.result[1]}目录下,请前往查看' + msg_title = "[Archery 通知]My2SQL执行结束" + msg_content = f"解析的SQL文件在{task.result[1]}目录下,请前往查看" else: - msg_title = '[Archery 通知]My2SQL执行失败' - msg_content = f'{task.result}' + msg_title = "[Archery 通知]My2SQL执行失败" + msg_content = f"{task.result}" # 发送 - msg_to = [task.kwargs['user']] + msg_to = [task.kwargs["user"]] __send(msg_title, msg_content, msg_to) diff --git a/sql/plugins/my2sql.py b/sql/plugins/my2sql.py index 7b3457ef23..5546cc5f99 100644 --- a/sql/plugins/my2sql.py +++ b/sql/plugins/my2sql.py @@ -5,9 +5,8 @@ class My2SQL(Plugin): - def __init__(self): - self.path = SysConfig().get('my2sql') + self.path = SysConfig().get("my2sql") self.required_args = [] self.disable_args = [] super(Plugin, self).__init__() @@ -19,34 +18,50 @@ def generate_args2cmd(self, args, shell): :param shell: :return: """ - conn_options = ['conn_options'] - args_options = ['work-type', 'threads', 'start-file', 'stop-file', 'start-pos', - 'stop-pos', 'databases', 'tables', 'sql', 'output-dir'] - no_args_options = ['output-toScreen', 'add-extraInfo', 'ignore-primaryKey-forInsert', - 'full-columns', 'do-not-add-prifixDb', 'file-per-table'] - datetime_options = ['start-datetime', 'stop-datetime'] + conn_options = ["conn_options"] + args_options = [ + "work-type", + "threads", + "start-file", + "stop-file", + "start-pos", + "stop-pos", + "databases", + "tables", + "sql", + "output-dir", + ] + no_args_options = [ + "output-toScreen", + "add-extraInfo", + "ignore-primaryKey-forInsert", + "full-columns", + "do-not-add-prifixDb", + "file-per-table", + ] + datetime_options = ["start-datetime", "stop-datetime"] if shell: - cmd_args = f'{shlex.quote(str(self.path))}' if self.path else '' + cmd_args = f"{shlex.quote(str(self.path))}" if self.path else "" for name, value in args.items(): if name in conn_options: - cmd_args += f' {value}' + cmd_args += f" {value}" elif name in args_options and value: - cmd_args += f' -{name} {shlex.quote(str(value))}' + cmd_args += f" -{name} {shlex.quote(str(value))}" elif name in datetime_options and value: cmd_args += f" -{name} '{shlex.quote(str(value))}'" elif name in no_args_options and value: - cmd_args += f' -{name}' + cmd_args += f" -{name}" else: cmd_args = [self.path] for name, value in args.items(): if name in conn_options: - cmd_args.append(f'{value}') + cmd_args.append(f"{value}") elif name in args_options: - cmd_args.append(f'-{name}') - cmd_args.append(f'{value}') + cmd_args.append(f"-{name}") + cmd_args.append(f"{value}") elif name in datetime_options: - cmd_args.append(f'-{name}') + cmd_args.append(f"-{name}") cmd_args.append(f"'{value}'") elif name in no_args_options: - cmd_args.append(f'-{name}') + cmd_args.append(f"-{name}") return cmd_args diff --git a/sql/plugins/plugin.py b/sql/plugins/plugin.py index 3cea862226..102b4c46d8 100644 --- a/sql/plugins/plugin.py +++ b/sql/plugins/plugin.py @@ -5,13 +5,13 @@ @file: plugin.py @time: 2019/03/04 """ -__author__ = 'hhyo' +__author__ = "hhyo" import logging import subprocess import traceback -logger = logging.getLogger('default') +logger = logging.getLogger("default") class Plugin: @@ -25,20 +25,28 @@ def check_args(self, args): 检查请求参数列表 :return: {'status': 0, 'msg': 'ok', 'data': {}} """ - args_check_result = {'status': 0, 'msg': 'ok', 'data': {}} + args_check_result = {"status": 0, "msg": "ok", "data": {}} # 检查路径 if self.path is None: - return {'status': 1, 'msg': '可执行文件路径不能为空!', 'data': {}} + return {"status": 1, "msg": "可执行文件路径不能为空!", "data": {}} # 检查禁用参数 for arg in args.keys(): if arg in self.disable_args: - return {'status': 1, 'msg': '{arg}参数已被禁用'.format(arg=arg), 'data': {}} + return {"status": 1, "msg": "{arg}参数已被禁用".format(arg=arg), "data": {}} # 检查必须参数 for req_arg in self.required_args: if req_arg not in args.keys(): - return {'status': 1, 'msg': '必须指定{arg}参数'.format(arg=req_arg), 'data': {}} - elif args[req_arg] is None or args[req_arg] == '': - return {'status': 1, 'msg': '{arg}参数值不能为空'.format(arg=req_arg), 'data': {}} + return { + "status": 1, + "msg": "必须指定{arg}参数".format(arg=req_arg), + "data": {}, + } + elif args[req_arg] is None or args[req_arg] == "": + return { + "status": 1, + "msg": "{arg}参数值不能为空".format(arg=req_arg), + "data": {}, + } return args_check_result def generate_args2cmd(self, args, shell): @@ -54,12 +62,14 @@ def execute_cmd(cmd_args, shell): :return: """ try: - p = subprocess.Popen(cmd_args, - shell=shell, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - universal_newlines=True) + p = subprocess.Popen( + cmd_args, + shell=shell, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + ) return p except Exception as e: logger.error("命令执行失败\n{}".format(traceback.format_exc())) - raise RuntimeError('命令执行失败,失败原因:%s' % str(e)) + raise RuntimeError("命令执行失败,失败原因:%s" % str(e)) diff --git a/sql/plugins/pt_archiver.py b/sql/plugins/pt_archiver.py index c5785c69b5..d84ed2254f 100644 --- a/sql/plugins/pt_archiver.py +++ b/sql/plugins/pt_archiver.py @@ -8,7 +8,7 @@ from common.config import SysConfig from sql.plugins.plugin import Plugin -__author__ = 'hhyo' +__author__ = "hhyo" class PtArchiver(Plugin): @@ -17,9 +17,9 @@ class PtArchiver(Plugin): """ def __init__(self): - self.path = 'pt-archiver' + self.path = "pt-archiver" self.required_args = [] - self.disable_args = ['analyze'] + self.disable_args = ["analyze"] super(Plugin, self).__init__() def generate_args2cmd(self, args, shell): @@ -29,24 +29,41 @@ def generate_args2cmd(self, args, shell): :param shell: :return: """ - k_options = ['no-version-check', 'statistics', 'bulk-insert', 'bulk-delete', 'purge', 'no-delete'] - kv_options = ['source', 'dest', 'file', 'where', 'progress', 'charset', 'limit', 'txn-size', 'sleep'] + k_options = [ + "no-version-check", + "statistics", + "bulk-insert", + "bulk-delete", + "purge", + "no-delete", + ] + kv_options = [ + "source", + "dest", + "file", + "where", + "progress", + "charset", + "limit", + "txn-size", + "sleep", + ] if shell: - cmd_args = self.path if self.path else '' + cmd_args = self.path if self.path else "" for name, value in args.items(): if name in k_options and value: - cmd_args += f' --{name}' + cmd_args += f" --{name}" elif name in kv_options: - if name == 'where': + if name == "where": cmd_args += f' --{name} "{value}"' else: - cmd_args += f' --{name} {value}' + cmd_args += f" --{name} {value}" else: cmd_args = [self.path] for name, value in args.items(): if name in k_options and value: - cmd_args.append(f'--{name}') + cmd_args.append(f"--{name}") elif name in kv_options: - cmd_args.append(f'--{name}') - cmd_args.append(f'{value}') + cmd_args.append(f"--{name}") + cmd_args.append(f"{value}") return cmd_args diff --git a/sql/plugins/schemasync.py b/sql/plugins/schemasync.py index c4dbf50a55..826299b9de 100644 --- a/sql/plugins/schemasync.py +++ b/sql/plugins/schemasync.py @@ -5,16 +5,15 @@ @file: schemasync.py @time: 2019/03/05 """ -__author__ = 'hhyo' +__author__ = "hhyo" import shlex from sql.plugins.plugin import Plugin class SchemaSync(Plugin): - def __init__(self): - self.path = 'schemasync' + self.path = "schemasync" self.required_args = [] self.disable_args = [] super(Plugin, self).__init__() @@ -26,26 +25,26 @@ def generate_args2cmd(self, args, shell): :param shell: :return: """ - k_options = ['sync-auto-inc', 'sync-comments'] - kv_options = ['tag', 'output-directory', 'log-directory'] - v_options = ['source', 'target'] + k_options = ["sync-auto-inc", "sync-comments"] + kv_options = ["tag", "output-directory", "log-directory"] + v_options = ["source", "target"] if shell: - cmd_args = self.path if self.path else '' + cmd_args = self.path if self.path else "" for name, value in args.items(): if name in k_options and value: - cmd_args += f' --{name}' + cmd_args += f" --{name}" elif name in kv_options: - cmd_args += f' --{name}={shlex.quote(str(value))}' + cmd_args += f" --{name}={shlex.quote(str(value))}" elif name in v_options: - cmd_args += f' {value}' + cmd_args += f" {value}" else: cmd_args = [self.path] for name, value in args.items(): if name in k_options and value: - cmd_args.append(f'--{name}') + cmd_args.append(f"--{name}") elif name in kv_options: - cmd_args.append(f'--{name}') - cmd_args.append(f'{value}') - elif name in ['source', 'target']: - cmd_args.append(f'{value}') + cmd_args.append(f"--{name}") + cmd_args.append(f"{value}") + elif name in ["source", "target"]: + cmd_args.append(f"{value}") return cmd_args diff --git a/sql/plugins/soar.py b/sql/plugins/soar.py index 6220c59b33..18844a8188 100644 --- a/sql/plugins/soar.py +++ b/sql/plugins/soar.py @@ -5,7 +5,7 @@ @file: soar.py @time: 2019/03/04 """ -__author__ = 'hhyo' +__author__ = "hhyo" import shlex from common.config import SysConfig @@ -13,10 +13,9 @@ class Soar(Plugin): - def __init__(self): - self.path = SysConfig().get('soar') - self.required_args = ['query'] + self.path = SysConfig().get("soar") + self.required_args = ["query"] self.disable_args = [] super(Plugin, self).__init__() @@ -28,14 +27,14 @@ def generate_args2cmd(self, args, shell): :return: """ if shell: - cmd_args = shlex.quote(str(self.path)) if self.path else '' + cmd_args = shlex.quote(str(self.path)) if self.path else "" for name, value in args.items(): cmd_args += f" -{name}={shlex.quote(str(value))}" else: cmd_args = [self.path] for name, value in args.items(): - cmd_args.append(f'-{name}') - cmd_args.append(f'{value}') + cmd_args.append(f"-{name}") + cmd_args.append(f"{value}") return cmd_args def fingerprint(self, sql): @@ -44,10 +43,7 @@ def fingerprint(self, sql): :param sql: :return: """ - args = { - "query": sql, - "report-type": "fingerprint" - } + args = {"query": sql, "report-type": "fingerprint"} cmd_args = self.generate_args2cmd(args, shell=True) return self.execute_cmd(cmd_args=cmd_args, shell=True) @@ -57,10 +53,7 @@ def compress(self, sql): :param sql: :return: """ - args = { - "query": sql, - "report-type": "compress" - } + args = {"query": sql, "report-type": "compress"} cmd_args = self.generate_args2cmd(args, shell=True) return self.execute_cmd(cmd_args=cmd_args, shell=True) @@ -73,7 +66,7 @@ def pretty(self, sql): args = { "query": sql, "max-pretty-sql-length": 100000, # 超出该长度的SQL会转换成指纹输出 (default 1024) - "report-type": "pretty" + "report-type": "pretty", } cmd_args = self.generate_args2cmd(args, shell=True) return self.execute_cmd(cmd_args=cmd_args, shell=True) @@ -84,10 +77,7 @@ def remove_comment(self, sql): :param sql: :return: """ - args = { - "query": sql, - "report-type": "remove-comment" - } + args = {"query": sql, "report-type": "remove-comment"} cmd_args = self.generate_args2cmd(args, shell=True) return self.execute_cmd(cmd_args=cmd_args, shell=True) @@ -98,17 +88,34 @@ def rewrite(self, sql, rewrite_rules=None): :param rewrite_rules: :return: """ - rewrite_type_list = ['dml2select', 'star2columns', 'insertcolumns', 'having', 'orderbynull', 'unionall', - 'or2in', 'dmlorderby', 'distinctstar', 'standard', 'mergealter', 'alwaystrue', - 'countstar', 'innodb', 'autoincrement', 'intwidth', 'truncate', 'rmparenthesis', - 'delimiter'] - rewrite_rules = rewrite_rules if rewrite_rules else ['dml2select'] + rewrite_type_list = [ + "dml2select", + "star2columns", + "insertcolumns", + "having", + "orderbynull", + "unionall", + "or2in", + "dmlorderby", + "distinctstar", + "standard", + "mergealter", + "alwaystrue", + "countstar", + "innodb", + "autoincrement", + "intwidth", + "truncate", + "rmparenthesis", + "delimiter", + ] + rewrite_rules = rewrite_rules if rewrite_rules else ["dml2select"] if set(rewrite_rules).issubset(set(rewrite_type_list)) is False: - raise RuntimeError(f'不支持的改写规则,仅支持{rewrite_type_list}') + raise RuntimeError(f"不支持的改写规则,仅支持{rewrite_type_list}") args = { "query": sql, "report-type": "rewrite", - "rewrite-rules": ','.join(rewrite_type_list) + "rewrite-rules": ",".join(rewrite_type_list), } cmd_args = self.generate_args2cmd(args, shell=True) return self.execute_cmd(cmd_args=cmd_args, shell=True) @@ -119,9 +126,6 @@ def query_tree(self, sql): :param sql: :return: """ - args = { - "query": sql, - "report-type": "ast-json" - } + args = {"query": sql, "report-type": "ast-json"} cmd_args = self.generate_args2cmd(args, shell=True) return self.execute_cmd(cmd_args=cmd_args, shell=True) diff --git a/sql/plugins/sqladvisor.py b/sql/plugins/sqladvisor.py index 744a11f2b4..6130661eb7 100644 --- a/sql/plugins/sqladvisor.py +++ b/sql/plugins/sqladvisor.py @@ -5,7 +5,7 @@ @file: sqladvisor.py @time: 2019/03/04 """ -__author__ = 'hhyo' +__author__ = "hhyo" import shlex from common.config import SysConfig @@ -14,8 +14,8 @@ class SQLAdvisor(Plugin): def __init__(self): - self.path = SysConfig().get('sqladvisor') - self.required_args = ['q'] + self.path = SysConfig().get("sqladvisor") + self.required_args = ["q"] self.disable_args = [] super(Plugin, self).__init__() @@ -27,12 +27,12 @@ def generate_args2cmd(self, args, shell): :return: """ if shell: - cmd_args = shlex.quote(str(self.path)) if self.path else '' + cmd_args = shlex.quote(str(self.path)) if self.path else "" for name, value in args.items(): cmd_args += f" -{name} {shlex.quote(str(value))}" else: cmd_args = [self.path] for name, value in args.items(): - cmd_args.append(f'-{name}') - cmd_args.append(f'{value}') + cmd_args.append(f"-{name}") + cmd_args.append(f"{value}") return cmd_args diff --git a/sql/plugins/tests.py b/sql/plugins/tests.py index 59d18db220..241a92403d 100644 --- a/sql/plugins/tests.py +++ b/sql/plugins/tests.py @@ -20,7 +20,7 @@ User = get_user_model() -__author__ = 'hhyo' +__author__ = "hhyo" class TestPlugin(TestCase): @@ -30,7 +30,7 @@ class TestPlugin(TestCase): @classmethod def setUpClass(cls): - cls.superuser = User(username='super', is_superuser=True) + cls.superuser = User(username="super", is_superuser=True) cls.superuser.save() cls.sys_config = SysConfig() cls.client = Client() @@ -46,74 +46,87 @@ def test_check_args_path(self): 测试路径 :return: """ - args = {"online-dsn": '', - "test-dsn": '', - "allow-online-as-test": "false", - "report-type": "markdown", - "query": "select 1;" - } - self.sys_config.set('soar', '') + args = { + "online-dsn": "", + "test-dsn": "", + "allow-online-as-test": "false", + "report-type": "markdown", + "query": "select 1;", + } + self.sys_config.set("soar", "") self.sys_config.get_all_config() soar = Soar() args_check_result = soar.check_args(args) - self.assertDictEqual(args_check_result, {'status': 1, 'msg': '可执行文件路径不能为空!', 'data': {}}) + self.assertDictEqual( + args_check_result, {"status": 1, "msg": "可执行文件路径不能为空!", "data": {}} + ) # 路径不为空 - self.sys_config.set('soar', '/opt/archery/src/plugins/soar') + self.sys_config.set("soar", "/opt/archery/src/plugins/soar") self.sys_config.get_all_config() soar = Soar() args_check_result = soar.check_args(args) - self.assertDictEqual(args_check_result, {'status': 0, 'msg': 'ok', 'data': {}}) + self.assertDictEqual(args_check_result, {"status": 0, "msg": "ok", "data": {}}) def test_check_args_disable(self): """ 测试禁用参数 :return: """ - args = {"online-dsn": '', - "test-dsn": '', - "allow-online-as-test": "false", - "report-type": "markdown", - "query": "select 1;" - } - self.sys_config.set('soar', '/opt/archery/src/plugins/soar') + args = { + "online-dsn": "", + "test-dsn": "", + "allow-online-as-test": "false", + "report-type": "markdown", + "query": "select 1;", + } + self.sys_config.set("soar", "/opt/archery/src/plugins/soar") self.sys_config.get_all_config() soar = Soar() - soar.disable_args = ['allow-online-as-test'] + soar.disable_args = ["allow-online-as-test"] args_check_result = soar.check_args(args) - self.assertDictEqual(args_check_result, {'status': 1, 'msg': 'allow-online-as-test参数已被禁用', 'data': {}}) + self.assertDictEqual( + args_check_result, + {"status": 1, "msg": "allow-online-as-test参数已被禁用", "data": {}}, + ) def test_check_args_required(self): """ 测试必选参数 :return: """ - args = {"online-dsn": '', - "test-dsn": '', - "allow-online-as-test": "false", - "report-type": "markdown", - } - self.sys_config.set('soar', '/opt/archery/src/plugins/soar') + args = { + "online-dsn": "", + "test-dsn": "", + "allow-online-as-test": "false", + "report-type": "markdown", + } + self.sys_config.set("soar", "/opt/archery/src/plugins/soar") self.sys_config.get_all_config() soar = Soar() - soar.required_args = ['query'] + soar.required_args = ["query"] args_check_result = soar.check_args(args) - self.assertDictEqual(args_check_result, {'status': 1, 'msg': '必须指定query参数', 'data': {}}) - args['query'] = "" + self.assertDictEqual( + args_check_result, {"status": 1, "msg": "必须指定query参数", "data": {}} + ) + args["query"] = "" args_check_result = soar.check_args(args) - self.assertDictEqual(args_check_result, {'status': 1, 'msg': 'query参数值不能为空', 'data': {}}) + self.assertDictEqual( + args_check_result, {"status": 1, "msg": "query参数值不能为空", "data": {}} + ) def test_soar_generate_args2cmd(self): """ 测试SOAR参数转换 :return: """ - args = {"online-dsn": '', - "test-dsn": '', - "allow-online-as-test": "false", - "report-type": "markdown", - "query": "select 1;" - } - self.sys_config.set('soar', '/opt/archery/src/plugins/soar') + args = { + "online-dsn": "", + "test-dsn": "", + "allow-online-as-test": "false", + "report-type": "markdown", + "query": "select 1;", + } + self.sys_config.set("soar", "/opt/archery/src/plugins/soar") self.sys_config.get_all_config() soar = Soar() cmd_args = soar.generate_args2cmd(args, False) @@ -126,15 +139,16 @@ def test_sql_advisor_generate_args2cmd(self): 测试sql_advisor参数转换 :return: """ - args = {"h": 'mysql', - "P": 3306, - "u": 'root', - "p": '', - "d": 'archery', - "v": 1, - "q": 'select 1;' - } - self.sys_config.set('sqladvisor', '/opt/archery/src/plugins/SQLAdvisor') + args = { + "h": "mysql", + "P": 3306, + "u": "root", + "p": "", + "d": "archery", + "v": 1, + "q": "select 1;", + } + self.sys_config.set("sqladvisor", "/opt/archery/src/plugins/SQLAdvisor") self.sys_config.get_all_config() sql_advisor = SQLAdvisor() cmd_args = sql_advisor.generate_args2cmd(args, False) @@ -150,20 +164,16 @@ def test_schema_sync_generate_args2cmd(self): args = { "sync-auto-inc": True, "sync-comments": True, - "tag": 'tag_v', - "output-directory": '', - "source": r"mysql://{user}:{pwd}@{host}:{port}/{database}".format(user='root', - pwd='123456', - host='127.0.0.1', - port=3306, - database='*'), - "target": r"mysql://{user}:{pwd}@{host}:{port}/{database}".format(user='root', - pwd='123456', - host='127.0.0.1', - port=3306, - database='*') + "tag": "tag_v", + "output-directory": "", + "source": r"mysql://{user}:{pwd}@{host}:{port}/{database}".format( + user="root", pwd="123456", host="127.0.0.1", port=3306, database="*" + ), + "target": r"mysql://{user}:{pwd}@{host}:{port}/{database}".format( + user="root", pwd="123456", host="127.0.0.1", port=3306, database="*" + ), } - self.sys_config.set('schemasync', '/opt/venv4schemasync/bin/schemasync') + self.sys_config.set("schemasync", "/opt/venv4schemasync/bin/schemasync") self.sys_config.get_all_config() schema_sync = SchemaSync() cmd_args = schema_sync.generate_args2cmd(args, False) @@ -176,24 +186,26 @@ def test_my2sql_generate_args2cmd(self): 测试my2sql参数转换 :return: """ - args = {'conn_options': "-host mysql -user root -password '123456' -port 3306 ", - 'work-type': '2sql', - 'start-file': 'mysql-bin.000043', - 'start-pos': 111, - 'stop-file': '', - 'stop-pos': '', - 'start-datetime': '', - 'stop-datetime': '', - 'databases': 'account_center', - 'tables': 'ac_apps', - 'sql': 'update', - "threads": 1, - "add-extraInfo": "false", - "ignore-primaryKey-forInsert": "false", - "full-columns": "false", - "do-not-add-prifixDb": "false", - "file-per-table": "false"} - self.sys_config.set('my2sql', '/opt/archery/src/plugins/my2sql') + args = { + "conn_options": "-host mysql -user root -password '123456' -port 3306 ", + "work-type": "2sql", + "start-file": "mysql-bin.000043", + "start-pos": 111, + "stop-file": "", + "stop-pos": "", + "start-datetime": "", + "stop-datetime": "", + "databases": "account_center", + "tables": "ac_apps", + "sql": "update", + "threads": 1, + "add-extraInfo": "false", + "ignore-primaryKey-forInsert": "false", + "full-columns": "false", + "do-not-add-prifixDb": "false", + "file-per-table": "false", + } + self.sys_config.set("my2sql", "/opt/archery/src/plugins/my2sql") self.sys_config.get_all_config() my2sql = My2SQL() cmd_args = my2sql.generate_args2cmd(args, False) @@ -208,14 +220,14 @@ def test_pt_archiver_generate_args2cmd(self): """ args = { "no-version-check": True, - "source": '', - "where": '', + "source": "", + "where": "", "progress": 5000, "statistics": True, - "charset": 'UTF8', + "charset": "UTF8", "limit": 10000, "txn-size": 1000, - "sleep": 1 + "sleep": 1, } pt_archiver = PtArchiver() cmd_args = pt_archiver.generate_args2cmd(args, False) @@ -223,32 +235,32 @@ def test_pt_archiver_generate_args2cmd(self): cmd_args = pt_archiver.generate_args2cmd(args, True) self.assertIsInstance(cmd_args, str) - @patch('sql.plugins.plugin.subprocess') + @patch("sql.plugins.plugin.subprocess") def test_execute_cmd(self, mock_subprocess): - args = {"online-dsn": '', - "test-dsn": '', - "allow-online-as-test": "false", - "report-type": "markdown", - "query": "select 1;" - } - self.sys_config.set('soar', '/opt/archery/src/plugins/soar') + args = { + "online-dsn": "", + "test-dsn": "", + "allow-online-as-test": "false", + "report-type": "markdown", + "query": "select 1;", + } + self.sys_config.set("soar", "/opt/archery/src/plugins/soar") self.sys_config.get_all_config() soar = Soar() cmd_args = soar.generate_args2cmd(args, True) - mock_subprocess.Popen.return_value.communicate.return_value = ('some_stdout', 'some_stderr') + mock_subprocess.Popen.return_value.communicate.return_value = ( + "some_stdout", + "some_stderr", + ) stdout, stderr = soar.execute_cmd(cmd_args, True).communicate() mock_subprocess.Popen.assert_called_once_with( - cmd_args, - shell=True, - stdout=ANY, - stderr=ANY, - universal_newlines=ANY + cmd_args, shell=True, stdout=ANY, stderr=ANY, universal_newlines=ANY ) - self.assertIn('some_stdout', stdout) + self.assertIn("some_stdout", stdout) # 异常 - mock_subprocess.Popen.side_effect = Exception('Boom! some exception!') + mock_subprocess.Popen.side_effect = Exception("Boom! some exception!") with self.assertRaises(RuntimeError): soar.execute_cmd(cmd_args, False) @@ -260,13 +272,13 @@ class TestSoar(TestCase): @classmethod def setUpClass(cls): - soar_path = '/opt/archery/src/plugins/soar' # 修改为本机的soar路径 - cls.superuser = User(username='super', is_superuser=True) + soar_path = "/opt/archery/src/plugins/soar" # 修改为本机的soar路径 + cls.superuser = User(username="super", is_superuser=True) cls.superuser.save() cls.client = Client() cls.client.force_login(cls.superuser) cls.sys_config = SysConfig() - cls.sys_config.set('soar', soar_path) + cls.sys_config.set("soar", soar_path) cls.sys_config.get_all_config() cls.soar = Soar() @@ -327,4 +339,4 @@ def test_rewrite(self): self.soar.rewrite(sql) # 异常测试 with self.assertRaises(RuntimeError): - self.soar.rewrite(sql, 'unknown') + self.soar.rewrite(sql, "unknown") diff --git a/sql/query.py b/sql/query.py index 8cd67af541..83583d8c0e 100644 --- a/sql/query.py +++ b/sql/query.py @@ -19,64 +19,66 @@ from .models import QueryLog, Instance from sql.engines import get_engine -logger = logging.getLogger('default') +logger = logging.getLogger("default") -@permission_required('sql.query_submit', raise_exception=True) +@permission_required("sql.query_submit", raise_exception=True) def query(request): """ 获取SQL查询结果 :param request: :return: """ - instance_name = request.POST.get('instance_name') - sql_content = request.POST.get('sql_content') - db_name = request.POST.get('db_name') - tb_name = request.POST.get('tb_name') - limit_num = int(request.POST.get('limit_num', 0)) - schema_name = request.POST.get('schema_name', None) + instance_name = request.POST.get("instance_name") + sql_content = request.POST.get("sql_content") + db_name = request.POST.get("db_name") + tb_name = request.POST.get("tb_name") + limit_num = int(request.POST.get("limit_num", 0)) + schema_name = request.POST.get("schema_name", None) user = request.user - result = {'status': 0, 'msg': 'ok', 'data': {}} + result = {"status": 0, "msg": "ok", "data": {}} try: instance = user_instances(request.user).get(instance_name=instance_name) except Instance.DoesNotExist: - result['status'] = 1 - result['msg'] = '你所在组未关联该实例' - return HttpResponse(json.dumps(result), content_type='application/json') + result["status"] = 1 + result["msg"] = "你所在组未关联该实例" + return HttpResponse(json.dumps(result), content_type="application/json") # 服务器端参数验证 if None in [sql_content, db_name, instance_name, limit_num]: - result['status'] = 1 - result['msg'] = '页面提交参数可能为空' - return HttpResponse(json.dumps(result), content_type='application/json') + result["status"] = 1 + result["msg"] = "页面提交参数可能为空" + return HttpResponse(json.dumps(result), content_type="application/json") try: config = SysConfig() # 查询前的检查,禁用语句检查,语句切分 query_engine = get_engine(instance=instance) query_check_info = query_engine.query_check(db_name=db_name, sql=sql_content) - if query_check_info.get('bad_query'): + if query_check_info.get("bad_query"): # 引擎内部判断为 bad_query - result['status'] = 1 - result['msg'] = query_check_info.get('msg') - return HttpResponse(json.dumps(result), content_type='application/json') - if query_check_info.get('has_star') and config.get('disable_star') is True: + result["status"] = 1 + result["msg"] = query_check_info.get("msg") + return HttpResponse(json.dumps(result), content_type="application/json") + if query_check_info.get("has_star") and config.get("disable_star") is True: # 引擎内部判断为有 * 且禁止 * 选项打开 - result['status'] = 1 - result['msg'] = query_check_info.get('msg') - return HttpResponse(json.dumps(result), content_type='application/json') - sql_content = query_check_info['filtered_sql'] + result["status"] = 1 + result["msg"] = query_check_info.get("msg") + return HttpResponse(json.dumps(result), content_type="application/json") + sql_content = query_check_info["filtered_sql"] # 查询权限校验,并且获取limit_num - priv_check_info = query_priv_check(user, instance, db_name, sql_content, limit_num) - if priv_check_info['status'] == 0: - limit_num = priv_check_info['data']['limit_num'] - priv_check = priv_check_info['data']['priv_check'] + priv_check_info = query_priv_check( + user, instance, db_name, sql_content, limit_num + ) + if priv_check_info["status"] == 0: + limit_num = priv_check_info["data"]["limit_num"] + priv_check = priv_check_info["data"]["priv_check"] else: - result['status'] = priv_check_info['status'] - result['msg'] = priv_check_info['msg'] - return HttpResponse(json.dumps(result), content_type='application/json') + result["status"] = priv_check_info["status"] + result["msg"] = priv_check_info["msg"] + return HttpResponse(json.dumps(result), content_type="application/json") # explain的limit_num设置为0 limit_num = 0 if re.match(r"^explain", sql_content.lower()) else limit_num @@ -86,19 +88,25 @@ def query(request): # 先获取查询连接,用于后面查询复用连接以及终止会话 query_engine.get_connection(db_name=db_name) thread_id = query_engine.thread_id - max_execution_time = int(config.get('max_execution_time', 60)) + max_execution_time = int(config.get("max_execution_time", 60)) # 执行查询语句,并增加一个定时终止语句的schedule,timeout=max_execution_time if thread_id: - schedule_name = f'query-{time.time()}' - run_date = (datetime.datetime.now() + datetime.timedelta(seconds=max_execution_time)) + schedule_name = f"query-{time.time()}" + run_date = datetime.datetime.now() + datetime.timedelta( + seconds=max_execution_time + ) add_kill_conn_schedule(schedule_name, run_date, instance.id, thread_id) with FuncTimer() as t: # 获取主从延迟信息 seconds_behind_master = query_engine.seconds_behind_master - query_result = query_engine.query(db_name, sql_content, limit_num, - schema_name=schema_name, - tb_name=tb_name, - max_execution_time=max_execution_time * 1000) + query_result = query_engine.query( + db_name, + sql_content, + limit_num, + schema_name=schema_name, + tb_name=tb_name, + max_execution_time=max_execution_time * 1000, + ) query_result.query_time = t.cost # 返回查询结果后删除schedule if thread_id: @@ -106,46 +114,50 @@ def query(request): # 查询异常 if query_result.error: - result['status'] = 1 - result['msg'] = query_result.error + result["status"] = 1 + result["msg"] = query_result.error # 数据脱敏,仅对查询无错误的结果集进行脱敏,并且按照query_check配置是否返回 - elif config.get('data_masking'): + elif config.get("data_masking"): try: with FuncTimer() as t: - masking_result = query_engine.query_masking(db_name, sql_content, query_result) + masking_result = query_engine.query_masking( + db_name, sql_content, query_result + ) masking_result.mask_time = t.cost # 脱敏出错 if masking_result.error: # 开启query_check,直接返回异常,禁止执行 - if config.get('query_check'): - result['status'] = 1 - result['msg'] = f'数据脱敏异常:{masking_result.error}' + if config.get("query_check"): + result["status"] = 1 + result["msg"] = f"数据脱敏异常:{masking_result.error}" # 关闭query_check,忽略错误信息,返回未脱敏数据,权限校验标记为跳过 else: - logger.warning(f'数据脱敏异常,按照配置放行,查询语句:{sql_content},错误信息:{masking_result.error}') + logger.warning( + f"数据脱敏异常,按照配置放行,查询语句:{sql_content},错误信息:{masking_result.error}" + ) query_result.error = None - result['data'] = query_result.__dict__ + result["data"] = query_result.__dict__ # 正常脱敏 else: - result['data'] = masking_result.__dict__ + result["data"] = masking_result.__dict__ except Exception as msg: logger.error(traceback.format_exc()) # 抛出未定义异常,并且开启query_check,直接返回异常,禁止执行 - if config.get('query_check'): - result['status'] = 1 - result['msg'] = f'数据脱敏异常,请联系管理员,错误信息:{msg}' + if config.get("query_check"): + result["status"] = 1 + result["msg"] = f"数据脱敏异常,请联系管理员,错误信息:{msg}" # 关闭query_check,忽略错误信息,返回未脱敏数据,权限校验标记为跳过 else: - logger.warning(f'数据脱敏异常,按照配置放行,查询语句:{sql_content},错误信息:{msg}') + logger.warning(f"数据脱敏异常,按照配置放行,查询语句:{sql_content},错误信息:{msg}") query_result.error = None - result['data'] = query_result.__dict__ + result["data"] = query_result.__dict__ # 无需脱敏的语句 else: - result['data'] = query_result.__dict__ + result["data"] = query_result.__dict__ # 仅将成功的查询语句记录存入数据库 if not query_result.error: - result['data']['seconds_behind_master'] = seconds_behind_master + result["data"]["seconds_behind_master"] = seconds_behind_master if int(limit_num) == 0: limit_num = int(query_result.affected_rows) else: @@ -160,35 +172,46 @@ def query(request): cost_time=query_result.query_time, priv_check=priv_check, hit_rule=query_result.mask_rule_hit, - masking=query_result.is_masked + masking=query_result.is_masked, ) # 防止查询超时 if connection.connection and not connection.is_usable(): close_old_connections() query_log.save() except Exception as e: - logger.error(f'查询异常报错,查询语句:{sql_content}\n,错误信息:{traceback.format_exc()}') - result['status'] = 1 - result['msg'] = f'查询异常报错,错误信息:{e}' - return HttpResponse(json.dumps(result), content_type='application/json') + logger.error(f"查询异常报错,查询语句:{sql_content}\n,错误信息:{traceback.format_exc()}") + result["status"] = 1 + result["msg"] = f"查询异常报错,错误信息:{e}" + return HttpResponse(json.dumps(result), content_type="application/json") # 返回查询结果 try: - return HttpResponse(json.dumps(result, use_decimal=False, cls=ExtendJSONEncoderFTime, bigint_as_string=True), - content_type='application/json') + return HttpResponse( + json.dumps( + result, + use_decimal=False, + cls=ExtendJSONEncoderFTime, + bigint_as_string=True, + ), + content_type="application/json", + ) # 虽然能正常返回,但是依然会乱码 except UnicodeDecodeError: - return HttpResponse(json.dumps(result, default=str, bigint_as_string=True, encoding='latin1'), - content_type='application/json') + return HttpResponse( + json.dumps(result, default=str, bigint_as_string=True, encoding="latin1"), + content_type="application/json", + ) -@permission_required('sql.menu_sqlquery', raise_exception=True) +@permission_required("sql.menu_sqlquery", raise_exception=True) def querylog(request): return _querylog(request) -@permission_required('sql.audit_user', raise_exception=True) + +@permission_required("sql.audit_user", raise_exception=True) def querylog_audit(request): return _querylog(request) + def _querylog(request): """ 获取sql查询记录 @@ -198,67 +221,85 @@ def _querylog(request): # 获取用户信息 user = request.user - limit = int(request.GET.get('limit',0)) - offset = int(request.GET.get('offset',0)) + limit = int(request.GET.get("limit", 0)) + offset = int(request.GET.get("offset", 0)) limit = offset + limit limit = limit if limit else None - star = True if request.GET.get('star') == 'true' else False - query_log_id = request.GET.get('query_log_id') - search = request.GET.get('search', '') - start_date = request.GET.get('start_date','') - end_date = request.GET.get('end_date','') + star = True if request.GET.get("star") == "true" else False + query_log_id = request.GET.get("query_log_id") + search = request.GET.get("search", "") + start_date = request.GET.get("start_date", "") + end_date = request.GET.get("end_date", "") # 组合筛选项 filter_dict = dict() # 是否收藏 if star: - filter_dict['favorite'] = star + filter_dict["favorite"] = star # 语句别名 if query_log_id: - filter_dict['id'] = query_log_id + filter_dict["id"] = query_log_id # 管理员、审计员查看全部数据,普通用户查看自己的数据 - if not (user.is_superuser or user.has_perm('sql.audit_user')): - filter_dict['username'] = user.username - + if not (user.is_superuser or user.has_perm("sql.audit_user")): + filter_dict["username"] = user.username + if start_date and end_date: - end_date = datetime.datetime.strptime(end_date, '%Y-%m-%d') + datetime.timedelta(days=1) - filter_dict['create_time__range'] = (start_date, end_date) + end_date = datetime.datetime.strptime( + end_date, "%Y-%m-%d" + ) + datetime.timedelta(days=1) + filter_dict["create_time__range"] = (start_date, end_date) # 过滤组合筛选项 sql_log = QueryLog.objects.filter(**filter_dict) # 过滤搜索信息 - sql_log = sql_log.filter(Q(sqllog__icontains=search) | - Q(user_display__icontains=search) | - Q(alias__icontains=search)) + sql_log = sql_log.filter( + Q(sqllog__icontains=search) + | Q(user_display__icontains=search) + | Q(alias__icontains=search) + ) sql_log_count = sql_log.count() - sql_log_list = sql_log.order_by('-id')[offset:limit].values( - "id", "instance_name", "db_name", "sqllog", - "effect_row", "cost_time", "user_display", "favorite", "alias", - "create_time") + sql_log_list = sql_log.order_by("-id")[offset:limit].values( + "id", + "instance_name", + "db_name", + "sqllog", + "effect_row", + "cost_time", + "user_display", + "favorite", + "alias", + "create_time", + ) # QuerySet 序列化 rows = [row for row in sql_log_list] result = {"total": sql_log_count, "rows": rows} # 返回查询结果 - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) -@permission_required('sql.menu_sqlquery', raise_exception=True) +@permission_required("sql.menu_sqlquery", raise_exception=True) def favorite(request): """ 收藏查询记录,并且设置别名 :param request: :return: """ - query_log_id = request.POST.get('query_log_id') - star = True if request.POST.get('star') == 'true' else False - alias = request.POST.get('alias') - QueryLog(id=query_log_id, favorite=star, alias=alias).save(update_fields=['favorite', 'alias']) + query_log_id = request.POST.get("query_log_id") + star = True if request.POST.get("star") == "true" else False + alias = request.POST.get("alias") + QueryLog(id=query_log_id, favorite=star, alias=alias).save( + update_fields=["favorite", "alias"] + ) # 返回查询结果 - return HttpResponse(json.dumps({'status': 0, 'msg': 'ok'}), content_type='application/json') + return HttpResponse( + json.dumps({"status": 0, "msg": "ok"}), content_type="application/json" + ) def kill_query_conn(instance_id, thread_id): @@ -266,4 +307,3 @@ def kill_query_conn(instance_id, thread_id): instance = Instance.objects.get(pk=instance_id) query_engine = get_engine(instance) query_engine.kill_connection(thread_id) - diff --git a/sql/query_privileges.py b/sql/query_privileges.py index 9257e23071..4c88980574 100644 --- a/sql/query_privileges.py +++ b/sql/query_privileges.py @@ -29,9 +29,9 @@ from sql.utils.workflow_audit import Audit from sql.utils.sql_utils import extract_tables -logger = logging.getLogger('default') +logger = logging.getLogger("default") -__author__ = 'hhyo' +__author__ = "hhyo" # TODO 权限校验内的语法解析和判断独立到每个engine内 @@ -45,22 +45,26 @@ def query_priv_check(user, instance, db_name, sql_content, limit_num): :param limit_num: :return: """ - result = {'status': 0, 'msg': 'ok', 'data': {'priv_check': True, 'limit_num': 0}} + result = {"status": 0, "msg": "ok", "data": {"priv_check": True, "limit_num": 0}} # 如果有can_query_all_instance, 视为管理员, 仅获取limit值信息 # superuser 拥有全部权限, 不需做特别修改 - if user.has_perm('sql.query_all_instances'): - priv_limit = int(SysConfig().get('admin_query_limit', 5000)) - result['data']['limit_num'] = min(priv_limit, limit_num) if limit_num else priv_limit + if user.has_perm("sql.query_all_instances"): + priv_limit = int(SysConfig().get("admin_query_limit", 5000)) + result["data"]["limit_num"] = ( + min(priv_limit, limit_num) if limit_num else priv_limit + ) return result # 如果有can_query_resource_group_instance, 视为资源组管理员, 可查询资源组内所有实例数据 - if user.has_perm('sql.query_resource_group_instance'): - if user_instances(user, tag_codes=['can_read']).filter(pk=instance.pk).exists(): - priv_limit = int(SysConfig().get('admin_query_limit', 5000)) - result['data']['limit_num'] = min(priv_limit, limit_num) if limit_num else priv_limit + if user.has_perm("sql.query_resource_group_instance"): + if user_instances(user, tag_codes=["can_read"]).filter(pk=instance.pk).exists(): + priv_limit = int(SysConfig().get("admin_query_limit", 5000)) + result["data"]["limit_num"] = ( + min(priv_limit, limit_num) if limit_num else priv_limit + ) return result # 仅MySQL做表权限校验 - if instance.db_type == 'mysql': + if instance.db_type == "mysql": try: # explain和show create跳过权限校验 if re.match(r"^explain|^show\s+create", sql_content, re.I): @@ -70,28 +74,39 @@ def query_priv_check(user, instance, db_name, sql_content, limit_num): # 循环验证权限,可能存在性能问题,但一次查询涉及的库表数量有限 for table in table_ref: # 既无库权限也无表权限则鉴权失败 - if not _db_priv(user, instance, table['schema']) and \ - not _tb_priv(user, instance, table['schema'], table['name']): + if not _db_priv(user, instance, table["schema"]) and not _tb_priv( + user, instance, table["schema"], table["name"] + ): # 没有库表查询权限时的staus为2 - result['status'] = 2 - result['msg'] = f"你无{table['schema']}.{table['name']}表的查询权限!请先到查询权限管理进行申请" + result["status"] = 2 + result[ + "msg" + ] = f"你无{table['schema']}.{table['name']}表的查询权限!请先到查询权限管理进行申请" return result # 获取查询涉及库/表权限的最小limit限制,和前端传参作对比,取最小值 for table in table_ref: - priv_limit = _priv_limit(user, instance, db_name=table['schema'], tb_name=table['name']) + priv_limit = _priv_limit( + user, instance, db_name=table["schema"], tb_name=table["name"] + ) limit_num = min(priv_limit, limit_num) if limit_num else priv_limit - result['data']['limit_num'] = limit_num + result["data"]["limit_num"] = limit_num except Exception as msg: - logger.error(f"无法校验查询语句权限,{instance.instance_name},{sql_content},{traceback.format_exc()}") - result['status'] = 1 - result['msg'] = f"无法校验查询语句权限,请联系管理员,错误信息:{msg}" + logger.error( + f"无法校验查询语句权限,{instance.instance_name},{sql_content},{traceback.format_exc()}" + ) + result["status"] = 1 + result["msg"] = f"无法校验查询语句权限,请联系管理员,错误信息:{msg}" # 其他类型实例仅校验库权限 else: # 先获取查询语句涉及的库,redis、mssql特殊处理,仅校验当前选择的库 - if instance.db_type in ['redis', 'mssql']: + if instance.db_type in ["redis", "mssql"]: dbs = [db_name] else: - dbs = [i['schema'].strip('`') for i in extract_tables(sql_content) if i['schema'] is not None] + dbs = [ + i["schema"].strip("`") + for i in extract_tables(sql_content) + if i["schema"] is not None + ] dbs.append(db_name) # 库去重 dbs = list(set(dbs)) @@ -101,18 +116,18 @@ def query_priv_check(user, instance, db_name, sql_content, limit_num): for db_name in dbs: if not _db_priv(user, instance, db_name): # 没有库表查询权限时的staus为2 - result['status'] = 2 - result['msg'] = f"你无{db_name}数据库的查询权限!请先到查询权限管理进行申请" + result["status"] = 2 + result["msg"] = f"你无{db_name}数据库的查询权限!请先到查询权限管理进行申请" return result # 有所有库权限则获取最小limit值 for db_name in dbs: priv_limit = _priv_limit(user, instance, db_name=db_name) limit_num = min(priv_limit, limit_num) if limit_num else priv_limit - result['data']['limit_num'] = limit_num + result["data"]["limit_num"] = limit_num return result -@permission_required('sql.menu_queryapplylist', raise_exception=True) +@permission_required("sql.menu_queryapplylist", raise_exception=True) def query_priv_apply_list(request): """ 获取查询权限申请列表 @@ -120,20 +135,22 @@ def query_priv_apply_list(request): :return: """ user = request.user - limit = int(request.POST.get('limit', 0)) - offset = int(request.POST.get('offset', 0)) + limit = int(request.POST.get("limit", 0)) + offset = int(request.POST.get("offset", 0)) limit = offset + limit - search = request.POST.get('search', '') + search = request.POST.get("search", "") query_privs = QueryPrivilegesApply.objects.all() # 过滤搜索项,支持模糊搜索标题、用户 if search: - query_privs = query_privs.filter(Q(title__icontains=search) | Q(user_display__icontains=search)) + query_privs = query_privs.filter( + Q(title__icontains=search) | Q(user_display__icontains=search) + ) # 管理员可以看到全部数据 if user.is_superuser: query_privs = query_privs # 拥有审核权限、可以查看组内所有工单 - elif user.has_perm('sql.query_review'): + elif user.has_perm("sql.query_review"): # 先获取用户所在资源组列表 group_list = user_groups(user) group_ids = [group.group_id for group in group_list] @@ -143,9 +160,19 @@ def query_priv_apply_list(request): query_privs = query_privs.filter(user_name=user.username) count = query_privs.count() - lists = query_privs.order_by('-apply_id')[offset:limit].values( - 'apply_id', 'title', 'instance__instance_name', 'db_list', 'priv_type', 'table_list', 'limit_num', 'valid_date', - 'user_display', 'status', 'create_time', 'group_name' + lists = query_privs.order_by("-apply_id")[offset:limit].values( + "apply_id", + "title", + "instance__instance_name", + "db_list", + "priv_type", + "table_list", + "limit_num", + "valid_date", + "user_display", + "status", + "create_time", + "group_name", ) # QuerySet 序列化 @@ -153,49 +180,60 @@ def query_priv_apply_list(request): result = {"total": count, "rows": rows} # 返回查询结果 - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) -@permission_required('sql.query_applypriv', raise_exception=True) +@permission_required("sql.query_applypriv", raise_exception=True) def query_priv_apply(request): """ 申请查询权限 :param request: :return: """ - title = request.POST['title'] - instance_name = request.POST.get('instance_name') - group_name = request.POST.get('group_name') + title = request.POST["title"] + instance_name = request.POST.get("instance_name") + group_name = request.POST.get("group_name") group_id = ResourceGroup.objects.get(group_name=group_name).group_id - priv_type = request.POST.get('priv_type') - db_name = request.POST.get('db_name') - db_list = request.POST.getlist('db_list[]') - table_list = request.POST.getlist('table_list[]') - valid_date = request.POST.get('valid_date') - limit_num = request.POST.get('limit_num') + priv_type = request.POST.get("priv_type") + db_name = request.POST.get("db_name") + db_list = request.POST.getlist("db_list[]") + table_list = request.POST.getlist("table_list[]") + valid_date = request.POST.get("valid_date") + limit_num = request.POST.get("limit_num") # 获取用户信息 user = request.user # 服务端参数校验 - result = {'status': 0, 'msg': 'ok', 'data': []} + result = {"status": 0, "msg": "ok", "data": []} if int(priv_type) == 1: if not (title and instance_name and db_list and valid_date and limit_num): - result['status'] = 1 - result['msg'] = '请填写完整' - return HttpResponse(json.dumps(result), content_type='application/json') + result["status"] = 1 + result["msg"] = "请填写完整" + return HttpResponse(json.dumps(result), content_type="application/json") elif int(priv_type) == 2: - if not (title and instance_name and db_name and valid_date and table_list and limit_num): - result['status'] = 1 - result['msg'] = '请填写完整' - return HttpResponse(json.dumps(result), content_type='application/json') + if not ( + title + and instance_name + and db_name + and valid_date + and table_list + and limit_num + ): + result["status"] = 1 + result["msg"] = "请填写完整" + return HttpResponse(json.dumps(result), content_type="application/json") try: - user_instances(request.user, tag_codes=['can_read']).get(instance_name=instance_name) + user_instances(request.user, tag_codes=["can_read"]).get( + instance_name=instance_name + ) except Instance.DoesNotExist: - result['status'] = 1 - result['msg'] = '你所在组未关联该实例!' - return HttpResponse(json.dumps(result), content_type='application/json') + result["status"] = 1 + result["msg"] = "你所在组未关联该实例!" + return HttpResponse(json.dumps(result), content_type="application/json") # 库权限 ins = Instance.objects.get(instance_name=instance_name) @@ -203,23 +241,23 @@ def query_priv_apply(request): # 检查申请账号是否已拥库查询权限 for db_name in db_list: if _db_priv(user, ins, db_name): - result['status'] = 1 - result['msg'] = f'你已拥有{instance_name}实例{db_name}库权限,不能重复申请' - return HttpResponse(json.dumps(result), content_type='application/json') + result["status"] = 1 + result["msg"] = f"你已拥有{instance_name}实例{db_name}库权限,不能重复申请" + return HttpResponse(json.dumps(result), content_type="application/json") # 表权限 elif int(priv_type) == 2: # 先检查是否拥有库权限 if _db_priv(user, ins, db_name): - result['status'] = 1 - result['msg'] = f'你已拥有{instance_name}实例{db_name}库的全部权限,不能重复申请' - return HttpResponse(json.dumps(result), content_type='application/json') + result["status"] = 1 + result["msg"] = f"你已拥有{instance_name}实例{db_name}库的全部权限,不能重复申请" + return HttpResponse(json.dumps(result), content_type="application/json") # 检查申请账号是否已拥有该表的查询权限 for tb_name in table_list: if _tb_priv(user, ins, db_name, tb_name): - result['status'] = 1 - result['msg'] = f'你已拥有{instance_name}实例{db_name}.{tb_name}表的查询权限,不能重复申请' - return HttpResponse(json.dumps(result), content_type='application/json') + result["status"] = 1 + result["msg"] = f"你已拥有{instance_name}实例{db_name}.{tb_name}表的查询权限,不能重复申请" + return HttpResponse(json.dumps(result), content_type="application/json") # 使用事务保持数据一致性 try: @@ -229,43 +267,53 @@ def query_priv_apply(request): title=title, group_id=group_id, group_name=group_name, - audit_auth_groups=Audit.settings(group_id, WorkflowDict.workflow_type['query']), + audit_auth_groups=Audit.settings( + group_id, WorkflowDict.workflow_type["query"] + ), user_name=user.username, user_display=user.display, instance=ins, priv_type=int(priv_type), valid_date=valid_date, - status=WorkflowDict.workflow_status['audit_wait'], - limit_num=limit_num + status=WorkflowDict.workflow_status["audit_wait"], + limit_num=limit_num, ) if int(priv_type) == 1: - applyinfo.db_list = ','.join(db_list) - applyinfo.table_list = '' + applyinfo.db_list = ",".join(db_list) + applyinfo.table_list = "" elif int(priv_type) == 2: applyinfo.db_list = db_name - applyinfo.table_list = ','.join(table_list) + applyinfo.table_list = ",".join(table_list) applyinfo.save() apply_id = applyinfo.apply_id # 调用工作流插入审核信息,查询权限申请workflow_type=1 - audit_result = Audit.add(WorkflowDict.workflow_type['query'], apply_id) - if audit_result['status'] == 0: + audit_result = Audit.add(WorkflowDict.workflow_type["query"], apply_id) + if audit_result["status"] == 0: # 更新业务表审核状态,判断是否插入权限信息 - _query_apply_audit_call_back(apply_id, audit_result['data']['workflow_status']) + _query_apply_audit_call_back( + apply_id, audit_result["data"]["workflow_status"] + ) except Exception as msg: logger.error(traceback.format_exc()) - result['status'] = 1 - result['msg'] = str(msg) + result["status"] = 1 + result["msg"] = str(msg) else: result = audit_result # 消息通知 - audit_id = Audit.detail_by_workflow_id(workflow_id=apply_id, - workflow_type=WorkflowDict.workflow_type['query']).audit_id - async_task(notify_for_audit, audit_id=audit_id, timeout=60, task_name=f'query-priv-apply-{apply_id}') - return HttpResponse(json.dumps(result), content_type='application/json') - - -@permission_required('sql.menu_queryapplylist', raise_exception=True) + audit_id = Audit.detail_by_workflow_id( + workflow_id=apply_id, workflow_type=WorkflowDict.workflow_type["query"] + ).audit_id + async_task( + notify_for_audit, + audit_id=audit_id, + timeout=60, + task_name=f"query-priv-apply-{apply_id}", + ) + return HttpResponse(json.dumps(result), content_type="application/json") + + +@permission_required("sql.menu_queryapplylist", raise_exception=True) def user_query_priv(request): """ 用户的查询权限管理 @@ -273,38 +321,54 @@ def user_query_priv(request): :return: """ user = request.user - user_display = request.POST.get('user_display', 'all') - limit = int(request.POST.get('limit')) - offset = int(request.POST.get('offset')) + user_display = request.POST.get("user_display", "all") + limit = int(request.POST.get("limit")) + offset = int(request.POST.get("offset")) limit = offset + limit - search = request.POST.get('search', '') + search = request.POST.get("search", "") - user_query_privs = QueryPrivileges.objects.filter(is_deleted=0, valid_date__gte=datetime.datetime.now()) + user_query_privs = QueryPrivileges.objects.filter( + is_deleted=0, valid_date__gte=datetime.datetime.now() + ) # 过滤搜索项,支持模糊搜索用户、数据库、表 if search: - user_query_privs = user_query_privs.filter(Q(user_display__icontains=search) | - Q(db_name__icontains=search) | - Q(table_name__icontains=search)) + user_query_privs = user_query_privs.filter( + Q(user_display__icontains=search) + | Q(db_name__icontains=search) + | Q(table_name__icontains=search) + ) # 过滤用户 - if user_display != 'all': + if user_display != "all": user_query_privs = user_query_privs.filter(user_display=user_display) # 管理员可以看到全部数据 if user.is_superuser: user_query_privs = user_query_privs # 拥有管理权限、可以查看组内所有工单 - elif user.has_perm('sql.query_mgtpriv'): + elif user.has_perm("sql.query_mgtpriv"): # 先获取用户所在资源组列表 group_list = user_groups(user) group_ids = [group.group_id for group in group_list] - user_query_privs = user_query_privs.filter(instance__queryprivilegesapply__group_id__in=group_ids) + user_query_privs = user_query_privs.filter( + instance__queryprivilegesapply__group_id__in=group_ids + ) # 其他人只能看到自己提交的工单 else: user_query_privs = user_query_privs.filter(user_name=user.username) privileges_count = user_query_privs.distinct().count() - privileges_list = user_query_privs.distinct().order_by('-privilege_id')[offset:limit].values( - 'privilege_id', 'user_display', 'instance__instance_name', 'db_name', 'priv_type', - 'table_name', 'limit_num', 'valid_date' + privileges_list = ( + user_query_privs.distinct() + .order_by("-privilege_id")[offset:limit] + .values( + "privilege_id", + "user_display", + "instance__instance_name", + "db_name", + "priv_type", + "table_name", + "limit_num", + "valid_date", + ) ) # QuerySet 序列化 @@ -312,45 +376,47 @@ def user_query_priv(request): result = {"total": privileges_count, "rows": rows} # 返回查询结果 - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) -@permission_required('sql.query_mgtpriv', raise_exception=True) +@permission_required("sql.query_mgtpriv", raise_exception=True) def query_priv_modify(request): """ 变更权限信息 :param request: :return: """ - privilege_id = request.POST.get('privilege_id') - type = request.POST.get('type') - result = {'status': 0, 'msg': 'ok', 'data': []} + privilege_id = request.POST.get("privilege_id") + type = request.POST.get("type") + result = {"status": 0, "msg": "ok", "data": []} # type=1删除权限,type=2变更权限 try: privilege = QueryPrivileges.objects.get(privilege_id=int(privilege_id)) except QueryPrivileges.DoesNotExist: - result['msg'] = '待操作权限不存在' - result['status'] = 1 - return HttpResponse(json.dumps(result), content_type='application/json') + result["msg"] = "待操作权限不存在" + result["status"] = 1 + return HttpResponse(json.dumps(result), content_type="application/json") if int(type) == 1: # 删除权限 privilege.is_deleted = 1 - privilege.save(update_fields=['is_deleted']) - return HttpResponse(json.dumps(result), content_type='application/json') + privilege.save(update_fields=["is_deleted"]) + return HttpResponse(json.dumps(result), content_type="application/json") elif int(type) == 2: # 变更权限 - valid_date = request.POST.get('valid_date') - limit_num = request.POST.get('limit_num') + valid_date = request.POST.get("valid_date") + limit_num = request.POST.get("limit_num") privilege.valid_date = valid_date privilege.limit_num = limit_num - privilege.save(update_fields=['valid_date', 'limit_num']) - return HttpResponse(json.dumps(result), content_type='application/json') + privilege.save(update_fields=["valid_date", "limit_num"]) + return HttpResponse(json.dumps(result), content_type="application/json") -@permission_required('sql.query_review', raise_exception=True) +@permission_required("sql.query_review", raise_exception=True) def query_priv_audit(request): """ 查询权限审核 @@ -359,42 +425,52 @@ def query_priv_audit(request): """ # 获取用户信息 user = request.user - apply_id = int(request.POST['apply_id']) - audit_status = int(request.POST['audit_status']) - audit_remark = request.POST.get('audit_remark') + apply_id = int(request.POST["apply_id"]) + audit_status = int(request.POST["audit_status"]) + audit_remark = request.POST.get("audit_remark") if audit_remark is None: - audit_remark = '' + audit_remark = "" if Audit.can_review(request.user, apply_id, 1) is False: - context = {'errMsg': '你无权操作当前工单!'} - return render(request, 'error.html', context) + context = {"errMsg": "你无权操作当前工单!"} + return render(request, "error.html", context) # 使用事务保持数据一致性 try: with transaction.atomic(): - audit_id = Audit.detail_by_workflow_id(workflow_id=apply_id, - workflow_type=WorkflowDict.workflow_type['query']).audit_id + audit_id = Audit.detail_by_workflow_id( + workflow_id=apply_id, workflow_type=WorkflowDict.workflow_type["query"] + ).audit_id # 调用工作流接口审核 - audit_result = Audit.audit(audit_id, audit_status, user.username, audit_remark) + audit_result = Audit.audit( + audit_id, audit_status, user.username, audit_remark + ) # 按照审核结果更新业务表审核状态 audit_detail = Audit.detail(audit_id) - if audit_detail.workflow_type == WorkflowDict.workflow_type['query']: + if audit_detail.workflow_type == WorkflowDict.workflow_type["query"]: # 更新业务表审核状态,插入权限信息 - _query_apply_audit_call_back(audit_detail.workflow_id, audit_result['data']['workflow_status']) + _query_apply_audit_call_back( + audit_detail.workflow_id, audit_result["data"]["workflow_status"] + ) except Exception as msg: logger.error(traceback.format_exc()) - context = {'errMsg': msg} - return render(request, 'error.html', context) + context = {"errMsg": msg} + return render(request, "error.html", context) else: # 消息通知 - async_task(notify_for_audit, audit_id=audit_id, audit_remark=audit_remark, timeout=60, - task_name=f'query-priv-audit-{apply_id}') + async_task( + notify_for_audit, + audit_id=audit_id, + audit_remark=audit_remark, + timeout=60, + task_name=f"query-priv-audit-{apply_id}", + ) - return HttpResponseRedirect(reverse('sql:queryapplydetail', args=(apply_id,))) + return HttpResponseRedirect(reverse("sql:queryapplydetail", args=(apply_id,))) def _table_ref(sql_content, instance, db_name): @@ -406,7 +482,9 @@ def _table_ref(sql_content, instance, db_name): :return: """ engine = GoInceptionEngine() - query_tree = engine.query_print(instance=instance, db_name=db_name, sql=sql_content).get('query_tree') + query_tree = engine.query_print( + instance=instance, db_name=db_name, sql=sql_content + ).get("query_tree") return engine.get_table_ref(json.loads(query_tree), db_name=db_name) @@ -420,11 +498,16 @@ def _db_priv(user, instance, db_name): TODO 返回统一为 int 类型, 不存在返回0 (虽然其实在python中 0==False) """ # 获取用户库权限 - user_privileges = QueryPrivileges.objects.filter(user_name=user.username, instance=instance, db_name=str(db_name), - valid_date__gte=datetime.datetime.now(), is_deleted=0, - priv_type=1) + user_privileges = QueryPrivileges.objects.filter( + user_name=user.username, + instance=instance, + db_name=str(db_name), + valid_date__gte=datetime.datetime.now(), + is_deleted=0, + priv_type=1, + ) if user.is_superuser: - return int(SysConfig().get('admin_query_limit', 5000)) + return int(SysConfig().get("admin_query_limit", 5000)) else: if user_privileges.exists(): return user_privileges.first().limit_num @@ -441,11 +524,17 @@ def _tb_priv(user, instance, db_name, tb_name): :return: 权限存在则返回对应权限的limit_num,否则返回False """ # 获取用户表权限 - user_privileges = QueryPrivileges.objects.filter(user_name=user.username, instance=instance, db_name=str(db_name), - table_name=str(tb_name), valid_date__gte=datetime.datetime.now(), - is_deleted=0, priv_type=2) + user_privileges = QueryPrivileges.objects.filter( + user_name=user.username, + instance=instance, + db_name=str(db_name), + table_name=str(tb_name), + valid_date__gte=datetime.datetime.now(), + is_deleted=0, + priv_type=2, + ) if user.is_superuser: - return int(SysConfig().get('admin_query_limit', 5000)) + return int(SysConfig().get("admin_query_limit", 5000)) else: if user_privileges.exists(): return user_privileges.first().limit_num @@ -473,7 +562,7 @@ def _priv_limit(user, instance, db_name, tb_name=None): elif tb_limit_num: return tb_limit_num else: - raise RuntimeError('用户无任何有效权限!') + raise RuntimeError("用户无任何有效权限!") def _query_apply_audit_call_back(apply_id, workflow_status): @@ -488,27 +577,37 @@ def _query_apply_audit_call_back(apply_id, workflow_status): apply_info.status = workflow_status apply_info.save() # 审核通过插入权限信息,批量插入,减少性能消耗 - if workflow_status == WorkflowDict.workflow_status['audit_success']: + if workflow_status == WorkflowDict.workflow_status["audit_success"]: apply_queryset = QueryPrivilegesApply.objects.get(apply_id=apply_id) # 库权限 if apply_queryset.priv_type == 1: - insert_list = [QueryPrivileges( - user_name=apply_queryset.user_name, - user_display=apply_queryset.user_display, - instance=apply_queryset.instance, - db_name=db_name, - table_name=apply_queryset.table_list, valid_date=apply_queryset.valid_date, - limit_num=apply_queryset.limit_num, priv_type=apply_queryset.priv_type) for db_name in - apply_queryset.db_list.split(',')] + insert_list = [ + QueryPrivileges( + user_name=apply_queryset.user_name, + user_display=apply_queryset.user_display, + instance=apply_queryset.instance, + db_name=db_name, + table_name=apply_queryset.table_list, + valid_date=apply_queryset.valid_date, + limit_num=apply_queryset.limit_num, + priv_type=apply_queryset.priv_type, + ) + for db_name in apply_queryset.db_list.split(",") + ] # 表权限 elif apply_queryset.priv_type == 2: - insert_list = [QueryPrivileges( - user_name=apply_queryset.user_name, - user_display=apply_queryset.user_display, - instance=apply_queryset.instance, - db_name=apply_queryset.db_list, - table_name=table_name, valid_date=apply_queryset.valid_date, - limit_num=apply_queryset.limit_num, priv_type=apply_queryset.priv_type) for table_name in - apply_queryset.table_list.split(',')] + insert_list = [ + QueryPrivileges( + user_name=apply_queryset.user_name, + user_display=apply_queryset.user_display, + instance=apply_queryset.instance, + db_name=apply_queryset.db_list, + table_name=table_name, + valid_date=apply_queryset.valid_date, + limit_num=apply_queryset.limit_num, + priv_type=apply_queryset.priv_type, + ) + for table_name in apply_queryset.table_list.split(",") + ] QueryPrivileges.objects.bulk_create(insert_list) diff --git a/sql/resource_group.py b/sql/resource_group.py index e9875a14f5..6d9fa0226e 100644 --- a/sql/resource_group.py +++ b/sql/resource_group.py @@ -14,29 +14,33 @@ from sql.utils.resource_group import user_instances from sql.utils.workflow_audit import Audit -logger = logging.getLogger('default') +logger = logging.getLogger("default") @superuser_required def group(request): """获取资源组列表""" - limit = int(request.POST.get('limit')) - offset = int(request.POST.get('offset')) + limit = int(request.POST.get("limit")) + offset = int(request.POST.get("offset")) limit = offset + limit - search = request.POST.get('search', '') + search = request.POST.get("search", "") # 过滤搜索条件 group_obj = ResourceGroup.objects.filter(group_name__icontains=search, is_deleted=0) group_count = group_obj.count() - group_list = group_obj[offset:limit].values("group_id", "group_name", "ding_webhook") + group_list = group_obj[offset:limit].values( + "group_id", "group_name", "ding_webhook" + ) # QuerySet 序列化 rows = [row for row in group_list] result = {"total": group_count, "rows": rows} # 返回查询结果 - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) def associated_objects(request): @@ -44,12 +48,12 @@ def associated_objects(request): 获取资源组已关联对象信息 type:(0, '用户'), (1, '实例') """ - group_id = int(request.POST.get('group_id')) - object_type = request.POST.get('type') - limit = int(request.POST.get('limit')) - offset = int(request.POST.get('offset')) + group_id = int(request.POST.get("group_id")) + object_type = request.POST.get("type") + limit = int(request.POST.get("limit")) + offset = int(request.POST.get("offset")) limit = offset + limit - search = request.POST.get('search') + search = request.POST.get("search") # 获取关联数据 resource_group = ResourceGroup.objects.get(group_id=group_id) @@ -60,29 +64,25 @@ def associated_objects(request): rows_users = rows_users.filter(display__contains=search) rows_instances = rows_instances.filter(instance_name__contains=search) rows_users = rows_users.annotate( - object_id=F('id'), + object_id=F("id"), object_type=Value(0, output_field=IntegerField()), - object_name=F('display'), - group_id=F('resource_group__group_id'), - group_name=F('resource_group__group_name') - ).values( - 'object_type', 'object_id', 'object_name', - 'group_id', 'group_name') + object_name=F("display"), + group_id=F("resource_group__group_id"), + group_name=F("resource_group__group_name"), + ).values("object_type", "object_id", "object_name", "group_id", "group_name") rows_instances = rows_instances.annotate( - object_id=F('id'), + object_id=F("id"), object_type=Value(1, output_field=IntegerField()), - object_name=F('instance_name'), - group_id=F('resource_group__group_id'), - group_name=F('resource_group__group_name') - ).values( - 'object_type', 'object_id', 'object_name', - 'group_id', 'group_name') + object_name=F("instance_name"), + group_id=F("resource_group__group_id"), + group_name=F("resource_group__group_name"), + ).values("object_type", "object_id", "object_name", "group_id", "group_name") # 过滤对象类型 - if object_type == '0': + if object_type == "0": rows_obj = rows_users count = rows_obj.count() rows = [row for row in rows_obj][offset:limit] - elif object_type == '1': + elif object_type == "1": rows_obj = rows_instances count = rows_obj.count() rows = [row for row in rows_obj][offset:limit] @@ -90,8 +90,10 @@ def associated_objects(request): rows = list(chain(rows_users, rows_instances)) count = len(rows) rows = rows[offset:limit] - result = {'status': 0, 'msg': 'ok', "total": count, "rows": rows} - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder), content_type='application/json') + result = {"status": 0, "msg": "ok", "total": count, "rows": rows} + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder), content_type="application/json" + ) def unassociated_objects(request): @@ -99,33 +101,38 @@ def unassociated_objects(request): 获取资源组未关联对象信息 type:(0, '用户'), (1, '实例') """ - group_id = int(request.POST.get('group_id')) - object_type = int(request.POST.get('object_type')) + group_id = int(request.POST.get("group_id")) + object_type = int(request.POST.get("object_type")) # 获取关联数据 resource_group = ResourceGroup.objects.get(group_id=group_id) if object_type == 0: associated_user_ids = [user.id for user in resource_group.users_set.all()] - rows = Users.objects.exclude(pk__in=associated_user_ids).annotate( - object_id=F('pk'), object_name=F('display')).values('object_id', 'object_name') + rows = ( + Users.objects.exclude(pk__in=associated_user_ids) + .annotate(object_id=F("pk"), object_name=F("display")) + .values("object_id", "object_name") + ) elif object_type == 1: associated_instance_ids = [ins.id for ins in resource_group.instance_set.all()] - rows = Instance.objects.exclude(pk__in=associated_instance_ids).annotate( - object_id=F('pk'), object_name=F('instance_name') - ).values('object_id', 'object_name') + rows = ( + Instance.objects.exclude(pk__in=associated_instance_ids) + .annotate(object_id=F("pk"), object_name=F("instance_name")) + .values("object_id", "object_name") + ) else: - raise ValueError('关联对象类型不正确') + raise ValueError("关联对象类型不正确") rows = [row for row in rows] - result = {'status': 0, 'msg': 'ok', "rows": rows, "total": len(rows)} - return HttpResponse(json.dumps(result), content_type='application/json') + result = {"status": 0, "msg": "ok", "rows": rows, "total": len(rows)} + return HttpResponse(json.dumps(result), content_type="application/json") def instances(request): """获取资源组关联实例列表""" - group_name = request.POST.get('group_name') + group_name = request.POST.get("group_name") group_id = ResourceGroup.objects.get(group_name=group_name).group_id - tag_code = request.POST.get('tag_code') - db_type = request.POST.get('db_type') + tag_code = request.POST.get("tag_code") + db_type = request.POST.get("db_type") # 先获取资源组关联所有实例列表 ins = ResourceGroup.objects.get(group_id=group_id).instance_set.all() @@ -134,28 +141,34 @@ def instances(request): filter_dict = dict() # db_type if db_type: - filter_dict['db_type'] = db_type + filter_dict["db_type"] = db_type if tag_code: - filter_dict['instance_tag__tag_code'] = tag_code - filter_dict['instance_tag__active'] = True - ins = ins.filter(**filter_dict).order_by(Convert('instance_name', 'gbk').asc()).values( - 'id', 'type', 'db_type', 'instance_name') + filter_dict["instance_tag__tag_code"] = tag_code + filter_dict["instance_tag__active"] = True + ins = ( + ins.filter(**filter_dict) + .order_by(Convert("instance_name", "gbk").asc()) + .values("id", "type", "db_type", "instance_name") + ) rows = [row for row in ins] - result = {'status': 0, 'msg': 'ok', "data": rows} - return HttpResponse(json.dumps(result), content_type='application/json') + result = {"status": 0, "msg": "ok", "data": rows} + return HttpResponse(json.dumps(result), content_type="application/json") def user_all_instances(request): """获取用户所有实例列表(通过资源组间接关联)""" user = request.user - type = request.GET.get('type') - db_type = request.GET.getlist('db_type[]') - tag_codes = request.GET.getlist('tag_codes[]') - instances = user_instances(user, type, db_type, tag_codes).order_by( - Convert('instance_name', 'gbk').asc()).values('id', 'type', 'db_type', 'instance_name') + type = request.GET.get("type") + db_type = request.GET.getlist("db_type[]") + tag_codes = request.GET.getlist("tag_codes[]") + instances = ( + user_instances(user, type, db_type, tag_codes) + .order_by(Convert("instance_name", "gbk").asc()) + .values("id", "type", "db_type", "instance_name") + ) rows = [row for row in instances] - result = {'status': 0, 'msg': 'ok', "data": rows} - return HttpResponse(json.dumps(result), content_type='application/json') + result = {"status": 0, "msg": "ok", "data": rows} + return HttpResponse(json.dumps(result), content_type="application/json") @superuser_required @@ -164,71 +177,84 @@ def addrelation(request): 添加资源组关联对象 type:(0, '用户'), (1, '实例') """ - group_id = int(request.POST.get('group_id')) - object_type = request.POST.get('object_type') - object_list = json.loads(request.POST.get('object_info')) + group_id = int(request.POST.get("group_id")) + object_type = request.POST.get("object_type") + object_list = json.loads(request.POST.get("object_info")) try: resource_group = ResourceGroup.objects.get(group_id=group_id) - obj_ids = [int(obj.split(',')[0]) for obj in object_list] - if object_type == '0': # 用户 + obj_ids = [int(obj.split(",")[0]) for obj in object_list] + if object_type == "0": # 用户 resource_group.users_set.add(*Users.objects.filter(pk__in=obj_ids)) - elif object_type == '1': # 实例 + elif object_type == "1": # 实例 resource_group.instance_set.add(*Instance.objects.filter(pk__in=obj_ids)) - result = {'status': 0, 'msg': 'ok'} + result = {"status": 0, "msg": "ok"} except Exception as e: logger.error(traceback.format_exc()) - result = {'status': 1, 'msg': str(e)} - return HttpResponse(json.dumps(result), content_type='application/json') + result = {"status": 1, "msg": str(e)} + return HttpResponse(json.dumps(result), content_type="application/json") def auditors(request): """获取资源组的审批流程""" - group_name = request.POST.get('group_name') - workflow_type = request.POST['workflow_type'] - result = {'status': 0, 'msg': 'ok', 'data': {'auditors': '', 'auditors_display': ''}} + group_name = request.POST.get("group_name") + workflow_type = request.POST["workflow_type"] + result = { + "status": 0, + "msg": "ok", + "data": {"auditors": "", "auditors_display": ""}, + } if group_name: group_id = ResourceGroup.objects.get(group_name=group_name).group_id - audit_auth_groups = Audit.settings(group_id=group_id, workflow_type=workflow_type) + audit_auth_groups = Audit.settings( + group_id=group_id, workflow_type=workflow_type + ) else: - result['status'] = 1 - result['msg'] = '参数错误' - return HttpResponse(json.dumps(result), content_type='application/json') + result["status"] = 1 + result["msg"] = "参数错误" + return HttpResponse(json.dumps(result), content_type="application/json") # 获取权限组名称 if audit_auth_groups: # 校验配置 - for auth_group_id in audit_auth_groups.split(','): + for auth_group_id in audit_auth_groups.split(","): try: Group.objects.get(id=auth_group_id) except Exception: - result['status'] = 1 - result['msg'] = '审批流程权限组不存在,请重新配置!' - return HttpResponse(json.dumps(result), content_type='application/json') - audit_auth_groups_name = '->'.join( - [Group.objects.get(id=auth_group_id).name for auth_group_id in audit_auth_groups.split(',')]) - result['data']['auditors'] = audit_auth_groups - result['data']['auditors_display'] = audit_auth_groups_name + result["status"] = 1 + result["msg"] = "审批流程权限组不存在,请重新配置!" + return HttpResponse(json.dumps(result), content_type="application/json") + audit_auth_groups_name = "->".join( + [ + Group.objects.get(id=auth_group_id).name + for auth_group_id in audit_auth_groups.split(",") + ] + ) + result["data"]["auditors"] = audit_auth_groups + result["data"]["auditors_display"] = audit_auth_groups_name - return HttpResponse(json.dumps(result), content_type='application/json') + return HttpResponse(json.dumps(result), content_type="application/json") @superuser_required def changeauditors(request): """设置资源组的审批流程""" - auth_groups = request.POST.get('audit_auth_groups') - group_name = request.POST.get('group_name') - workflow_type = request.POST.get('workflow_type') - result = {'status': 0, 'msg': 'ok', 'data': []} + auth_groups = request.POST.get("audit_auth_groups") + group_name = request.POST.get("group_name") + workflow_type = request.POST.get("workflow_type") + result = {"status": 0, "msg": "ok", "data": []} # 调用工作流修改审核配置 group_id = ResourceGroup.objects.get(group_name=group_name).group_id - audit_auth_groups = [str(Group.objects.get(name=auth_group).id) for auth_group in auth_groups.split(',')] + audit_auth_groups = [ + str(Group.objects.get(name=auth_group).id) + for auth_group in auth_groups.split(",") + ] try: - Audit.change_settings(group_id, workflow_type, ','.join(audit_auth_groups)) + Audit.change_settings(group_id, workflow_type, ",".join(audit_auth_groups)) except Exception as msg: logger.error(traceback.format_exc()) - result['msg'] = str(msg) - result['status'] = 1 + result["msg"] = str(msg) + result["status"] = 1 # 返回结果 - return HttpResponse(json.dumps(result), content_type='application/json') + return HttpResponse(json.dumps(result), content_type="application/json") diff --git a/sql/slowlog.py b/sql/slowlog.py index f389f8e451..630e9d9446 100644 --- a/sql/slowlog.py +++ b/sql/slowlog.py @@ -14,24 +14,26 @@ from common.utils.extend_json_encoder import ExtendJSONEncoder from .models import Instance, SlowQuery, SlowQueryHistory, AliyunRdsConfig -from .aliyun_rds import slowquery_review as aliyun_rds_slowquery_review, \ - slowquery_review_history as aliyun_rds_slowquery_review_history +from .aliyun_rds import ( + slowquery_review as aliyun_rds_slowquery_review, + slowquery_review_history as aliyun_rds_slowquery_review_history, +) import logging -logger = logging.getLogger('default') +logger = logging.getLogger("default") # 获取SQL慢日志统计 -@permission_required('sql.menu_slowquery', raise_exception=True) +@permission_required("sql.menu_slowquery", raise_exception=True) def slowquery_review(request): - instance_name = request.POST.get('instance_name') + instance_name = request.POST.get("instance_name") # 服务端权限校验 try: - user_instances(request.user, db_type=['mysql']).get(instance_name=instance_name) + user_instances(request.user, db_type=["mysql"]).get(instance_name=instance_name) except Exception: - result = {'status': 1, 'msg': '你所在组未关联该实例', 'data': []} - return HttpResponse(json.dumps(result), content_type='application/json') + result = {"status": 1, "msg": "你所在组未关联该实例", "data": []} + return HttpResponse(json.dumps(result), content_type="application/json") # 判断是RDS还是其他实例 instance_info = Instance.objects.get(instance_name=instance_name) @@ -39,65 +41,85 @@ def slowquery_review(request): # 调用阿里云慢日志接口 result = aliyun_rds_slowquery_review(request) else: - start_time = request.POST.get('StartTime') - end_time = request.POST.get('EndTime') - db_name = request.POST.get('db_name') - limit = int(request.POST.get('limit')) - offset = int(request.POST.get('offset')) + start_time = request.POST.get("StartTime") + end_time = request.POST.get("EndTime") + db_name = request.POST.get("db_name") + limit = int(request.POST.get("limit")) + offset = int(request.POST.get("offset")) limit = offset + limit - search = request.POST.get('search') - sortName = str(request.POST.get('sortName')) - sortOrder = str(request.POST.get('sortOrder')).lower() + search = request.POST.get("search") + sortName = str(request.POST.get("sortName")) + sortOrder = str(request.POST.get("sortOrder")).lower() # 时间处理 - end_time = datetime.datetime.strptime(end_time, '%Y-%m-%d') + datetime.timedelta(days=1) + end_time = datetime.datetime.strptime( + end_time, "%Y-%m-%d" + ) + datetime.timedelta(days=1) filter_kwargs = {"slowqueryhistory__db_max": db_name} if db_name else {} # 获取慢查数据 - slowsql_obj = SlowQuery.objects.filter( - slowqueryhistory__hostname_max=(instance_info.host + ':' + str(instance_info.port)), - slowqueryhistory__ts_min__range=(start_time, end_time), - fingerprint__icontains=search, - **filter_kwargs - ).annotate(SQLText=F('fingerprint'), SQLId=F('checksum')).values('SQLText', 'SQLId').annotate( - CreateTime=Max('slowqueryhistory__ts_max'), - DBName=Max('slowqueryhistory__db_max'), # 数据库 - QueryTimeAvg=Sum('slowqueryhistory__query_time_sum') / Sum('slowqueryhistory__ts_cnt'), # 平均执行时长 - MySQLTotalExecutionCounts=Sum('slowqueryhistory__ts_cnt'), # 执行总次数 - MySQLTotalExecutionTimes=Sum('slowqueryhistory__query_time_sum'), # 执行总时长 - ParseTotalRowCounts=Sum('slowqueryhistory__rows_examined_sum'), # 扫描总行数 - ReturnTotalRowCounts=Sum('slowqueryhistory__rows_sent_sum'), # 返回总行数 - ParseRowAvg=Sum('slowqueryhistory__rows_examined_sum') / Sum('slowqueryhistory__ts_cnt'), # 平均扫描行数 - ReturnRowAvg=Sum('slowqueryhistory__rows_sent_sum') / Sum('slowqueryhistory__ts_cnt'), # 平均返回行数 + slowsql_obj = ( + SlowQuery.objects.filter( + slowqueryhistory__hostname_max=( + instance_info.host + ":" + str(instance_info.port) + ), + slowqueryhistory__ts_min__range=(start_time, end_time), + fingerprint__icontains=search, + **filter_kwargs + ) + .annotate(SQLText=F("fingerprint"), SQLId=F("checksum")) + .values("SQLText", "SQLId") + .annotate( + CreateTime=Max("slowqueryhistory__ts_max"), + DBName=Max("slowqueryhistory__db_max"), # 数据库 + QueryTimeAvg=Sum("slowqueryhistory__query_time_sum") + / Sum("slowqueryhistory__ts_cnt"), # 平均执行时长 + MySQLTotalExecutionCounts=Sum("slowqueryhistory__ts_cnt"), # 执行总次数 + MySQLTotalExecutionTimes=Sum( + "slowqueryhistory__query_time_sum" + ), # 执行总时长 + ParseTotalRowCounts=Sum("slowqueryhistory__rows_examined_sum"), # 扫描总行数 + ReturnTotalRowCounts=Sum("slowqueryhistory__rows_sent_sum"), # 返回总行数 + ParseRowAvg=Sum("slowqueryhistory__rows_examined_sum") + / Sum("slowqueryhistory__ts_cnt"), # 平均扫描行数 + ReturnRowAvg=Sum("slowqueryhistory__rows_sent_sum") + / Sum("slowqueryhistory__ts_cnt"), # 平均返回行数 + ) ) slow_sql_count = slowsql_obj.count() # 默认“执行总次数”倒序排列 - slow_sql_list = slowsql_obj.order_by('-' + sortName if 'desc'.__eq__(sortOrder) else sortName)[offset:limit] + slow_sql_list = slowsql_obj.order_by( + "-" + sortName if "desc".__eq__(sortOrder) else sortName + )[offset:limit] # QuerySet 序列化 sql_slow_log = [] for SlowLog in slow_sql_list: - SlowLog['QueryTimeAvg'] = round(SlowLog['QueryTimeAvg'], 6) - SlowLog['MySQLTotalExecutionTimes'] = round(SlowLog['MySQLTotalExecutionTimes'], 6) - SlowLog['ParseRowAvg'] = int(SlowLog['ParseRowAvg']) - SlowLog['ReturnRowAvg'] = int(SlowLog['ReturnRowAvg']) + SlowLog["QueryTimeAvg"] = round(SlowLog["QueryTimeAvg"], 6) + SlowLog["MySQLTotalExecutionTimes"] = round( + SlowLog["MySQLTotalExecutionTimes"], 6 + ) + SlowLog["ParseRowAvg"] = int(SlowLog["ParseRowAvg"]) + SlowLog["ReturnRowAvg"] = int(SlowLog["ReturnRowAvg"]) sql_slow_log.append(SlowLog) result = {"total": slow_sql_count, "rows": sql_slow_log} # 返回查询结果 - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) # 获取SQL慢日志明细 -@permission_required('sql.menu_slowquery', raise_exception=True) +@permission_required("sql.menu_slowquery", raise_exception=True) def slowquery_review_history(request): - instance_name = request.POST.get('instance_name') + instance_name = request.POST.get("instance_name") # 服务端权限校验 try: - user_instances(request.user, db_type=['mysql']).get(instance_name=instance_name) + user_instances(request.user, db_type=["mysql"]).get(instance_name=instance_name) except Exception: - result = {'status': 1, 'msg': '你所在组未关联该实例', 'data': []} - return HttpResponse(json.dumps(result), content_type='application/json') + result = {"status": 1, "msg": "你所在组未关联该实例", "data": []} + return HttpResponse(json.dumps(result), content_type="application/json") # 判断是RDS还是其他实例 instance_info = Instance.objects.get(instance_name=instance_name) @@ -105,87 +127,115 @@ def slowquery_review_history(request): # 调用阿里云慢日志接口 result = aliyun_rds_slowquery_review_history(request) else: - start_time = request.POST.get('StartTime') - end_time = request.POST.get('EndTime') - db_name = request.POST.get('db_name') - sql_id = request.POST.get('SQLId') - limit = int(request.POST.get('limit')) - offset = int(request.POST.get('offset')) - search = request.POST.get('search') - sortName = str(request.POST.get('sortName')) - sortOrder = str(request.POST.get('sortOrder')).lower() + start_time = request.POST.get("StartTime") + end_time = request.POST.get("EndTime") + db_name = request.POST.get("db_name") + sql_id = request.POST.get("SQLId") + limit = int(request.POST.get("limit")) + offset = int(request.POST.get("offset")) + search = request.POST.get("search") + sortName = str(request.POST.get("sortName")) + sortOrder = str(request.POST.get("sortOrder")).lower() # 时间处理 - end_time = datetime.datetime.strptime(end_time, '%Y-%m-%d') + datetime.timedelta(days=1) + end_time = datetime.datetime.strptime( + end_time, "%Y-%m-%d" + ) + datetime.timedelta(days=1) limit = offset + limit filter_kwargs = {} filter_kwargs.update({"checksum": sql_id}) if sql_id else None - filter_kwargs.update({'db_max': db_name}) if db_name else None + filter_kwargs.update({"db_max": db_name}) if db_name else None # SQLId、DBName非必传 # 获取慢查明细数据 slow_sql_record_obj = SlowQueryHistory.objects.filter( - hostname_max=(instance_info.host + ':' + str(instance_info.port)), + hostname_max=(instance_info.host + ":" + str(instance_info.port)), ts_min__range=(start_time, end_time), sample__icontains=search, **filter_kwargs - ).annotate(ExecutionStartTime=F('ts_min'), # 本次统计(每5分钟一次)该类型sql语句出现的最小时间 - DBName=F('db_max'), # 数据库名 - HostAddress=Concat(V('\''), 'user_max', V('\''), V('@'), V('\''), 'client_max', V('\'')), # 用户名 - SQLText=F('sample'), # SQL语句 - TotalExecutionCounts=F('ts_cnt'), # 本次统计该sql语句出现的次数 - QueryTimePct95=F('query_time_pct_95'), # 本次统计该sql语句95%耗时 - QueryTimes=F('query_time_sum'), # 本次统计该sql语句花费的总时间(秒) - LockTimes=F('lock_time_sum'), # 本次统计该sql语句锁定总时长(秒) - ParseRowCounts=F('rows_examined_sum'), # 本次统计该sql语句解析总行数 - ReturnRowCounts=F('rows_sent_sum') # 本次统计该sql语句返回总行数 - ) + ).annotate( + ExecutionStartTime=F("ts_min"), # 本次统计(每5分钟一次)该类型sql语句出现的最小时间 + DBName=F("db_max"), # 数据库名 + HostAddress=Concat( + V("'"), "user_max", V("'"), V("@"), V("'"), "client_max", V("'") + ), # 用户名 + SQLText=F("sample"), # SQL语句 + TotalExecutionCounts=F("ts_cnt"), # 本次统计该sql语句出现的次数 + QueryTimePct95=F("query_time_pct_95"), # 本次统计该sql语句95%耗时 + QueryTimes=F("query_time_sum"), # 本次统计该sql语句花费的总时间(秒) + LockTimes=F("lock_time_sum"), # 本次统计该sql语句锁定总时长(秒) + ParseRowCounts=F("rows_examined_sum"), # 本次统计该sql语句解析总行数 + ReturnRowCounts=F("rows_sent_sum"), # 本次统计该sql语句返回总行数 + ) slow_sql_record_count = slow_sql_record_obj.count() - slow_sql_record_list = slow_sql_record_obj.order_by('-' + sortName if 'desc'.__eq__(sortOrder) else sortName)[ - offset:limit].values('ExecutionStartTime', 'DBName', 'HostAddress', - 'SQLText', - 'TotalExecutionCounts', 'QueryTimePct95', - 'QueryTimes', 'LockTimes', 'ParseRowCounts', - 'ReturnRowCounts' - ) + slow_sql_record_list = slow_sql_record_obj.order_by( + "-" + sortName if "desc".__eq__(sortOrder) else sortName + )[offset:limit].values( + "ExecutionStartTime", + "DBName", + "HostAddress", + "SQLText", + "TotalExecutionCounts", + "QueryTimePct95", + "QueryTimes", + "LockTimes", + "ParseRowCounts", + "ReturnRowCounts", + ) # QuerySet 序列化 sql_slow_record = [] for SlowRecord in slow_sql_record_list: - SlowRecord['QueryTimePct95'] = round(SlowRecord['QueryTimePct95'], 6) - SlowRecord['QueryTimes'] = round(SlowRecord['QueryTimes'], 6) - SlowRecord['LockTimes'] = round(SlowRecord['LockTimes'], 6) + SlowRecord["QueryTimePct95"] = round(SlowRecord["QueryTimePct95"], 6) + SlowRecord["QueryTimes"] = round(SlowRecord["QueryTimes"], 6) + SlowRecord["LockTimes"] = round(SlowRecord["LockTimes"], 6) sql_slow_record.append(SlowRecord) result = {"total": slow_sql_record_count, "rows": sql_slow_record} # 返回查询结果 - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) @cache_page(60 * 10) def report(request): """返回慢SQL历史趋势""" - checksum = request.GET.get('checksum') + checksum = request.GET.get("checksum") cnt_data = ChartDao().slow_query_review_history_by_cnt(checksum) pct_data = ChartDao().slow_query_review_history_by_pct_95_time(checksum) - cnt_x_data = [row[1] for row in cnt_data['rows']] - cnt_y_data = [int(row[0]) for row in cnt_data['rows']] - pct_y_data = [str(row[0]) for row in pct_data['rows']] - line = Line(init_opts=opts.InitOpts(width='800', height='380px')) + cnt_x_data = [row[1] for row in cnt_data["rows"]] + cnt_y_data = [int(row[0]) for row in cnt_data["rows"]] + pct_y_data = [str(row[0]) for row in pct_data["rows"]] + line = Line(init_opts=opts.InitOpts(width="800", height="380px")) line.add_xaxis(cnt_x_data) - line.add_yaxis("慢查次数", cnt_y_data, is_smooth=True, - markline_opts=opts.MarkLineOpts(data=[opts.MarkLineItem(type_="max", name='最大值'), - opts.MarkLineItem(type_="average", name='平均值')])) + line.add_yaxis( + "慢查次数", + cnt_y_data, + is_smooth=True, + markline_opts=opts.MarkLineOpts( + data=[ + opts.MarkLineItem(type_="max", name="最大值"), + opts.MarkLineItem(type_="average", name="平均值"), + ] + ), + ) line.add_yaxis("慢查时长(95%)", pct_y_data, is_smooth=True, is_symbol_show=False) - line.set_series_opts(areastyle_opts=opts.AreaStyleOpts(opacity=0.5, )) - line.set_global_opts(title_opts=opts.TitleOpts(title='SQL历史趋势'), - legend_opts=opts.LegendOpts(selected_mode='single'), - xaxis_opts=opts.AxisOpts( - axistick_opts=opts.AxisTickOpts(is_align_with_label=True), - is_scale=False, - boundary_gap=False, - ), ) - - result = {"status": 0, "msg": '', "data": line.render_embed()} - return HttpResponse(json.dumps(result), content_type='application/json') + line.set_series_opts( + areastyle_opts=opts.AreaStyleOpts( + opacity=0.5, + ) + ) + line.set_global_opts( + title_opts=opts.TitleOpts(title="SQL历史趋势"), + legend_opts=opts.LegendOpts(selected_mode="single"), + xaxis_opts=opts.AxisOpts( + axistick_opts=opts.AxisTickOpts(is_align_with_label=True), + is_scale=False, + boundary_gap=False, + ), + ) + + result = {"status": 0, "msg": "", "data": line.render_embed()} + return HttpResponse(json.dumps(result), content_type="application/json") diff --git a/sql/sql_analyze.py b/sql/sql_analyze.py index c61fa9ece2..15b7bfdc1c 100644 --- a/sql/sql_analyze.py +++ b/sql/sql_analyze.py @@ -16,66 +16,76 @@ from common.utils.extend_json_encoder import ExtendJSONEncoder from .models import Instance -__author__ = 'hhyo' +__author__ = "hhyo" -@permission_required('sql.sql_analyze', raise_exception=True) +@permission_required("sql.sql_analyze", raise_exception=True) def generate(request): """ 解析上传文件为SQL列表 :param request: :return: """ - text = request.POST.get('text') + text = request.POST.get("text") if text is None: result = {"total": 0, "rows": []} else: rows = generate_sql(text) result = {"total": len(rows), "rows": rows} - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) -@permission_required('sql.sql_analyze', raise_exception=True) +@permission_required("sql.sql_analyze", raise_exception=True) def analyze(request): """ 利用soar分析SQL :param request: :return: """ - text = request.POST.get('text') - instance_name = request.POST.get('instance_name') - db_name = request.POST.get('db_name') + text = request.POST.get("text") + instance_name = request.POST.get("instance_name") + db_name = request.POST.get("db_name") if not text: result = {"total": 0, "rows": []} else: soar = Soar() - if instance_name != '' and db_name != '': + if instance_name != "" and db_name != "": try: - instance_info = user_instances(request.user, db_type=['mysql']).get(instance_name=instance_name) + instance_info = user_instances(request.user, db_type=["mysql"]).get( + instance_name=instance_name + ) except Instance.DoesNotExist: - return JsonResponse({'status': 1, 'msg': '你所在组未关联该实例!', 'data': []}) - soar_test_dsn = SysConfig().get('soar_test_dsn') + return JsonResponse({"status": 1, "msg": "你所在组未关联该实例!", "data": []}) + soar_test_dsn = SysConfig().get("soar_test_dsn") # 获取实例连接信息 - online_dsn = "{user}:{pwd}@{host}:{port}/{db}".format(user=instance_info.user, - pwd=instance_info.password, - host=instance_info.host, - port=instance_info.port, - db=db_name) + online_dsn = "{user}:{pwd}@{host}:{port}/{db}".format( + user=instance_info.user, + pwd=instance_info.password, + host=instance_info.host, + port=instance_info.port, + db=db_name, + ) else: - online_dsn = '' - soar_test_dsn = '' - args = {"report-type": "markdown", - "query": '', - "online-dsn": online_dsn, - "test-dsn": soar_test_dsn, - "allow-online-as-test": "false"} + online_dsn = "" + soar_test_dsn = "" + args = { + "report-type": "markdown", + "query": "", + "online-dsn": online_dsn, + "test-dsn": soar_test_dsn, + "allow-online-as-test": "false", + } rows = generate_sql(text) for row in rows: - args['query'] = row['sql'] + args["query"] = row["sql"] cmd_args = soar.generate_args2cmd(args=args, shell=True) stdout, stderr = soar.execute_cmd(cmd_args, shell=True).communicate() - row['report'] = stdout if stdout else stderr + row["report"] = stdout if stdout else stderr result = {"total": len(rows), "rows": rows} - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) diff --git a/sql/sql_optimize.py b/sql/sql_optimize.py index c3cd89404e..75488f2c43 100644 --- a/sql/sql_optimize.py +++ b/sql/sql_optimize.py @@ -21,166 +21,182 @@ from sql.sql_tuning import SqlTuning from sql.utils.resource_group import user_instances -__author__ = 'hhyo' +__author__ = "hhyo" -@permission_required('sql.optimize_sqladvisor', raise_exception=True) +@permission_required("sql.optimize_sqladvisor", raise_exception=True) def optimize_sqladvisor(request): - sql_content = request.POST.get('sql_content') - instance_name = request.POST.get('instance_name') - db_name = request.POST.get('db_name') - verbose = request.POST.get('verbose', 1) - result = {'status': 0, 'msg': 'ok', 'data': []} + sql_content = request.POST.get("sql_content") + instance_name = request.POST.get("instance_name") + db_name = request.POST.get("db_name") + verbose = request.POST.get("verbose", 1) + result = {"status": 0, "msg": "ok", "data": []} # 服务器端参数验证 if sql_content is None or instance_name is None: - result['status'] = 1 - result['msg'] = '页面提交参数可能为空' - return HttpResponse(json.dumps(result), content_type='application/json') + result["status"] = 1 + result["msg"] = "页面提交参数可能为空" + return HttpResponse(json.dumps(result), content_type="application/json") try: - instance_info = user_instances(request.user, db_type=['mysql']).get(instance_name=instance_name) + instance_info = user_instances(request.user, db_type=["mysql"]).get( + instance_name=instance_name + ) except Instance.DoesNotExist: - result['status'] = 1 - result['msg'] = '你所在组未关联该实例!' - return HttpResponse(json.dumps(result), content_type='application/json') + result["status"] = 1 + result["msg"] = "你所在组未关联该实例!" + return HttpResponse(json.dumps(result), content_type="application/json") # 检查sqladvisor程序路径 - sqladvisor_path = SysConfig().get('sqladvisor') + sqladvisor_path = SysConfig().get("sqladvisor") if sqladvisor_path is None: - result['status'] = 1 - result['msg'] = '请配置SQLAdvisor路径!' - return HttpResponse(json.dumps(result), content_type='application/json') + result["status"] = 1 + result["msg"] = "请配置SQLAdvisor路径!" + return HttpResponse(json.dumps(result), content_type="application/json") # 提交给sqladvisor获取分析报告 sqladvisor = SQLAdvisor() # 准备参数 - args = {"h": instance_info.host, - "P": instance_info.port, - "u": instance_info.user, - "p": instance_info.password, - "d": db_name, - "v": verbose, - "q": sql_content.strip() - } + args = { + "h": instance_info.host, + "P": instance_info.port, + "u": instance_info.user, + "p": instance_info.password, + "d": db_name, + "v": verbose, + "q": sql_content.strip(), + } # 参数检查 args_check_result = sqladvisor.check_args(args) - if args_check_result['status'] == 1: - return HttpResponse(json.dumps(args_check_result), content_type='application/json') + if args_check_result["status"] == 1: + return HttpResponse( + json.dumps(args_check_result), content_type="application/json" + ) # 参数转换 cmd_args = sqladvisor.generate_args2cmd(args, shell=True) # 执行命令 try: stdout, stderr = sqladvisor.execute_cmd(cmd_args, shell=True).communicate() - result['data'] = f'{stdout}{stderr}' + result["data"] = f"{stdout}{stderr}" except RuntimeError as e: - result['status'] = 1 - result['msg'] = str(e) - return HttpResponse(json.dumps(result), content_type='application/json') + result["status"] = 1 + result["msg"] = str(e) + return HttpResponse(json.dumps(result), content_type="application/json") -@permission_required('sql.optimize_soar', raise_exception=True) +@permission_required("sql.optimize_soar", raise_exception=True) def optimize_soar(request): - instance_name = request.POST.get('instance_name') - db_name = request.POST.get('db_name') - sql = request.POST.get('sql') - result = {'status': 0, 'msg': 'ok', 'data': []} + instance_name = request.POST.get("instance_name") + db_name = request.POST.get("db_name") + sql = request.POST.get("sql") + result = {"status": 0, "msg": "ok", "data": []} # 服务器端参数验证 if not (instance_name and db_name and sql): - result['status'] = 1 - result['msg'] = '页面提交参数可能为空' - return HttpResponse(json.dumps(result), content_type='application/json') + result["status"] = 1 + result["msg"] = "页面提交参数可能为空" + return HttpResponse(json.dumps(result), content_type="application/json") try: - instance_info = user_instances(request.user, db_type=['mysql']).get(instance_name=instance_name) + instance_info = user_instances(request.user, db_type=["mysql"]).get( + instance_name=instance_name + ) except Exception: - result['status'] = 1 - result['msg'] = '你所在组未关联该实例' - return HttpResponse(json.dumps(result), content_type='application/json') + result["status"] = 1 + result["msg"] = "你所在组未关联该实例" + return HttpResponse(json.dumps(result), content_type="application/json") # 检查测试实例的连接信息和soar程序路径 - soar_test_dsn = SysConfig().get('soar_test_dsn') - soar_path = SysConfig().get('soar') + soar_test_dsn = SysConfig().get("soar_test_dsn") + soar_path = SysConfig().get("soar") if not (soar_path and soar_test_dsn): - result['status'] = 1 - result['msg'] = '请配置soar_path和test_dsn!' - return HttpResponse(json.dumps(result), content_type='application/json') + result["status"] = 1 + result["msg"] = "请配置soar_path和test_dsn!" + return HttpResponse(json.dumps(result), content_type="application/json") # 目标实例的连接信息 - online_dsn = '{user}:"{pwd}"@{host}:{port}/{db}'.format(user=instance_info.user, - pwd=instance_info.password, - host=instance_info.host, - port=instance_info.port, - db=db_name) + online_dsn = '{user}:"{pwd}"@{host}:{port}/{db}'.format( + user=instance_info.user, + pwd=instance_info.password, + host=instance_info.host, + port=instance_info.port, + db=db_name, + ) # 提交给soar获取分析报告 soar = Soar() # 准备参数 - args = {"online-dsn": online_dsn, - "test-dsn": soar_test_dsn, - "allow-online-as-test": "false", - "report-type": "markdown", - "query": sql.strip() - } + args = { + "online-dsn": online_dsn, + "test-dsn": soar_test_dsn, + "allow-online-as-test": "false", + "report-type": "markdown", + "query": sql.strip(), + } # 参数检查 args_check_result = soar.check_args(args) - if args_check_result['status'] == 1: - return HttpResponse(json.dumps(args_check_result), content_type='application/json') + if args_check_result["status"] == 1: + return HttpResponse( + json.dumps(args_check_result), content_type="application/json" + ) # 参数转换 cmd_args = soar.generate_args2cmd(args, shell=True) # 执行命令 try: stdout, stderr = soar.execute_cmd(cmd_args, shell=True).communicate() - result['data'] = stdout if stdout else stderr + result["data"] = stdout if stdout else stderr except RuntimeError as e: - result['status'] = 1 - result['msg'] = str(e) - return HttpResponse(json.dumps(result), content_type='application/json') + result["status"] = 1 + result["msg"] = str(e) + return HttpResponse(json.dumps(result), content_type="application/json") -@permission_required('sql.optimize_sqltuning', raise_exception=True) +@permission_required("sql.optimize_sqltuning", raise_exception=True) def optimize_sqltuning(request): - instance_name = request.POST.get('instance_name') - db_name = request.POST.get('db_name') - sqltext = request.POST.get('sql_content') - option = request.POST.getlist('option[]') + instance_name = request.POST.get("instance_name") + db_name = request.POST.get("db_name") + sqltext = request.POST.get("sql_content") + option = request.POST.getlist("option[]") sqltext = sqlparse.format(sqltext, strip_comments=True) sqltext = sqlparse.split(sqltext)[0] if re.match(r"^select|^show|^explain", sqltext, re.I) is None: - result = {'status': 1, 'msg': '只支持查询SQL!', 'data': []} - return HttpResponse(json.dumps(result), content_type='application/json') + result = {"status": 1, "msg": "只支持查询SQL!", "data": []} + return HttpResponse(json.dumps(result), content_type="application/json") try: user_instances(request.user).get(instance_name=instance_name) except Instance.DoesNotExist: - result = {'status': 1, 'msg': '你所在组未关联该实例!', 'data': []} - return HttpResponse(json.dumps(result), content_type='application/json') + result = {"status": 1, "msg": "你所在组未关联该实例!", "data": []} + return HttpResponse(json.dumps(result), content_type="application/json") # escape - db_name = MySQLdb.escape_string(db_name).decode('utf-8') + db_name = MySQLdb.escape_string(db_name).decode("utf-8") - sql_tunning = SqlTuning(instance_name=instance_name, db_name=db_name, sqltext=sqltext) - result = {'status': 0, 'msg': 'ok', 'data': {}} - if 'sys_parm' in option: + sql_tunning = SqlTuning( + instance_name=instance_name, db_name=db_name, sqltext=sqltext + ) + result = {"status": 0, "msg": "ok", "data": {}} + if "sys_parm" in option: basic_information = sql_tunning.basic_information() sys_parameter = sql_tunning.sys_parameter() optimizer_switch = sql_tunning.optimizer_switch() - result['data']['basic_information'] = basic_information - result['data']['sys_parameter'] = sys_parameter - result['data']['optimizer_switch'] = optimizer_switch - if 'sql_plan' in option: + result["data"]["basic_information"] = basic_information + result["data"]["sys_parameter"] = sys_parameter + result["data"]["optimizer_switch"] = optimizer_switch + if "sql_plan" in option: plan, optimizer_rewrite_sql = sql_tunning.sqlplan() - result['data']['optimizer_rewrite_sql'] = optimizer_rewrite_sql - result['data']['plan'] = plan - if 'obj_stat' in option: - result['data']['object_statistics'] = sql_tunning.object_statistics() - if 'sql_profile' in option: + result["data"]["optimizer_rewrite_sql"] = optimizer_rewrite_sql + result["data"]["plan"] = plan + if "obj_stat" in option: + result["data"]["object_statistics"] = sql_tunning.object_statistics() + if "sql_profile" in option: session_status = sql_tunning.exec_sql() - result['data']['session_status'] = session_status + result["data"]["session_status"] = session_status # 关闭连接 sql_tunning.engine.close() - result['data']['sqltext'] = sqltext - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + result["data"]["sqltext"] = sqltext + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) def explain(request): @@ -189,46 +205,48 @@ def explain(request): :param request: :return: """ - sql_content = request.POST.get('sql_content') - instance_name = request.POST.get('instance_name') - db_name = request.POST.get('db_name') - result = {'status': 0, 'msg': 'ok', 'data': []} + sql_content = request.POST.get("sql_content") + instance_name = request.POST.get("instance_name") + db_name = request.POST.get("db_name") + result = {"status": 0, "msg": "ok", "data": []} # 服务器端参数验证 if sql_content is None or instance_name is None: - result['status'] = 1 - result['msg'] = '页面提交参数可能为空' - return HttpResponse(json.dumps(result), content_type='application/json') + result["status"] = 1 + result["msg"] = "页面提交参数可能为空" + return HttpResponse(json.dumps(result), content_type="application/json") try: instance = user_instances(request.user).get(instance_name=instance_name) except Instance.DoesNotExist: - result = {'status': 1, 'msg': '实例不存在', 'data': []} - return HttpResponse(json.dumps(result), content_type='application/json') + result = {"status": 1, "msg": "实例不存在", "data": []} + return HttpResponse(json.dumps(result), content_type="application/json") # 删除注释语句,进行语法判断,执行第一条有效sql sql_content = sqlparse.format(sql_content.strip(), strip_comments=True) try: sql_content = sqlparse.split(sql_content)[0] except IndexError: - result['status'] = 1 - result['msg'] = '没有有效的SQL语句' - return HttpResponse(json.dumps(result), content_type='application/json') + result["status"] = 1 + result["msg"] = "没有有效的SQL语句" + return HttpResponse(json.dumps(result), content_type="application/json") else: # 过滤非explain的语句 if not re.match(r"^explain", sql_content, re.I): - result['status'] = 1 - result['msg'] = '仅支持explain开头的语句,请检查' - return HttpResponse(json.dumps(result), content_type='application/json') + result["status"] = 1 + result["msg"] = "仅支持explain开头的语句,请检查" + return HttpResponse(json.dumps(result), content_type="application/json") # 执行获取执行计划语句 query_engine = get_engine(instance=instance) sql_result = query_engine.query(str(db_name), sql_content).to_sep_dict() - result['data'] = sql_result + result["data"] = sql_result # 返回查询结果 - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) def optimize_sqltuningadvisor(request): @@ -237,45 +255,47 @@ def optimize_sqltuningadvisor(request): :param request: :return: """ - sql_content = request.POST.get('sql_content') - instance_name = request.POST.get('instance_name') - db_name = request.POST.get('schema_name') - result = {'status': 0, 'msg': 'ok', 'data': []} + sql_content = request.POST.get("sql_content") + instance_name = request.POST.get("instance_name") + db_name = request.POST.get("schema_name") + result = {"status": 0, "msg": "ok", "data": []} # 服务器端参数验证 if sql_content is None or instance_name is None: - result['status'] = 1 - result['msg'] = '页面提交参数可能为空' - return HttpResponse(json.dumps(result), content_type='application/json') + result["status"] = 1 + result["msg"] = "页面提交参数可能为空" + return HttpResponse(json.dumps(result), content_type="application/json") try: instance = user_instances(request.user).get(instance_name=instance_name) except Instance.DoesNotExist: - result = {'status': 1, 'msg': '实例不存在', 'data': []} - return HttpResponse(json.dumps(result), content_type='application/json') + result = {"status": 1, "msg": "实例不存在", "data": []} + return HttpResponse(json.dumps(result), content_type="application/json") # 不删除注释语句,已获取加hints的SQL优化建议,进行语法判断,执行第一条有效sql sql_content = sqlparse.format(sql_content.strip(), strip_comments=False) # 对单引号加转义符,支持plsql语法 - sql_content = sql_content.replace("'", "''"); + sql_content = sql_content.replace("'", "''") try: sql_content = sqlparse.split(sql_content)[0] except IndexError: - result['status'] = 1 - result['msg'] = '没有有效的SQL语句' - return HttpResponse(json.dumps(result), content_type='application/json') + result["status"] = 1 + result["msg"] = "没有有效的SQL语句" + return HttpResponse(json.dumps(result), content_type="application/json") else: # 过滤非Oracle语句 - if not instance.db_type == 'oracle': - result['status'] = 1 - result['msg'] = 'SQLTuningAdvisor仅支持oracle数据库的检查' - return HttpResponse(json.dumps(result), content_type='application/json') + if not instance.db_type == "oracle": + result["status"] = 1 + result["msg"] = "SQLTuningAdvisor仅支持oracle数据库的检查" + return HttpResponse(json.dumps(result), content_type="application/json") # 执行获取优化报告 query_engine = get_engine(instance=instance) sql_result = query_engine.sqltuningadvisor(str(db_name), sql_content).to_sep_dict() - result['data'] = sql_result + result["data"] = sql_result # 返回查询结果 - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) diff --git a/sql/sql_tuning.py b/sql/sql_tuning.py index 81251dea42..973406cba9 100644 --- a/sql/sql_tuning.py +++ b/sql/sql_tuning.py @@ -15,19 +15,21 @@ def __init__(self, instance_name, db_name, sqltext): self.engine = query_engine self.db_name = db_name self.sqltext = sqltext - self.sql_variable = ''' + self.sql_variable = """ select lower(variable_name), variable_value from performance_schema.global_variables where upper(variable_name) in ('%s') - order by variable_name;''' % ('\',\''.join(SQLTuning.SYS_PARM_FILTER)) - self.sql_optimizer_switch = ''' + order by variable_name;""" % ( + "','".join(SQLTuning.SYS_PARM_FILTER) + ) + self.sql_optimizer_switch = """ select variable_value from performance_schema.global_variables where upper(variable_name) = 'OPTIMIZER_SWITCH'; - ''' - self.sql_table_info = ''' + """ + self.sql_table_info = """ select table_name, engine, @@ -39,8 +41,8 @@ def __init__(self, instance_name, db_name, sqltext): round((index_length) / 1024 / 1024, 2) as index_mb from information_schema.tables where table_schema = '%s' and table_name = '%s' - ''' - self.sql_table_index = ''' + """ + self.sql_table_index = """ select table_name, index_name, @@ -54,11 +56,11 @@ def __init__(self, instance_name, db_name, sqltext): from information_schema.statistics where table_schema = '%s' and table_name = '%s' order by 1, 3; - ''' + """ def __extract_tables(self): """获取sql语句中的表名""" - return [i['name'].strip('`') for i in extract_tables(self.sqltext)] + return [i["name"].strip("`") for i in extract_tables(self.sqltext)] def basic_information(self): return self.engine.query(sql="select @@version", close_conn=False).to_sep_dict() @@ -67,7 +69,7 @@ def sys_parameter(self): # 获取mysql版本信息 server_version = self.engine.server_version if server_version < (5, 7, 0): - sql = self.sql_variable.replace('performance_schema', 'information_schema') + sql = self.sql_variable.replace("performance_schema", "information_schema") else: sql = self.sql_variable return self.engine.query(sql=sql, close_conn=False).to_sep_dict() @@ -76,41 +78,57 @@ def optimizer_switch(self): # 获取mysql版本信息 server_version = self.engine.server_version if server_version < (5, 7, 0): - sql = self.sql_optimizer_switch.replace('performance_schema', 'information_schema') + sql = self.sql_optimizer_switch.replace( + "performance_schema", "information_schema" + ) else: sql = self.sql_optimizer_switch return self.engine.query(sql=sql, close_conn=False).to_sep_dict() def sqlplan(self): - plan = self.engine.query(self.db_name, "explain " + self.sqltext, close_conn=False).to_sep_dict() - optimizer_rewrite_sql = self.engine.query(sql="show warnings", close_conn=False).to_sep_dict() + plan = self.engine.query( + self.db_name, "explain " + self.sqltext, close_conn=False + ).to_sep_dict() + optimizer_rewrite_sql = self.engine.query( + sql="show warnings", close_conn=False + ).to_sep_dict() return plan, optimizer_rewrite_sql # 获取关联表信息存在缺陷,只能获取到一张表 def object_statistics(self): object_statistics = [] for index, table_name in enumerate(self.__extract_tables()): - object_statistics.append({ - "structure": self.engine.query( - db_name=self.db_name, sql=f"show create table `{table_name}`;", - close_conn=False).to_sep_dict(), - "table_info": self.engine.query( - sql=self.sql_table_info % (self.db_name, table_name), - close_conn=False).to_sep_dict(), - "index_info": self.engine.query( - sql=self.sql_table_index % (self.db_name, table_name), - close_conn=False).to_sep_dict() - }) + object_statistics.append( + { + "structure": self.engine.query( + db_name=self.db_name, + sql=f"show create table `{table_name}`;", + close_conn=False, + ).to_sep_dict(), + "table_info": self.engine.query( + sql=self.sql_table_info % (self.db_name, table_name), + close_conn=False, + ).to_sep_dict(), + "index_info": self.engine.query( + sql=self.sql_table_index % (self.db_name, table_name), + close_conn=False, + ).to_sep_dict(), + } + ) return object_statistics def exec_sql(self): - result = {"EXECUTE_TIME": 0, - "BEFORE_STATUS": {'column_list': [], 'rows': []}, - "AFTER_STATUS": {'column_list': [], 'rows': []}, - "SESSION_STATUS(DIFFERENT)": {'column_list': ['status_name', 'before', 'after', 'diff'], 'rows': []}, - "PROFILING_DETAIL": {'column_list': [], 'rows': []}, - "PROFILING_SUMMARY": {'column_list': [], 'rows': []} - } + result = { + "EXECUTE_TIME": 0, + "BEFORE_STATUS": {"column_list": [], "rows": []}, + "AFTER_STATUS": {"column_list": [], "rows": []}, + "SESSION_STATUS(DIFFERENT)": { + "column_list": ["status_name", "before", "after", "diff"], + "rows": [], + }, + "PROFILING_DETAIL": {"column_list": [], "rows": []}, + "PROFILING_SUMMARY": {"column_list": [], "rows": []}, + } sql_profiling = """select concat(upper(left(variable_name,1)), substring(lower(variable_name), 2, @@ -121,43 +139,60 @@ def exec_sql(self): # 获取mysql版本信息 server_version = self.engine.server_version if server_version < (5, 7, 0): - sql = sql_profiling.replace('performance_schema', 'information_schema') + sql = sql_profiling.replace("performance_schema", "information_schema") else: sql = sql_profiling self.engine.query(sql="set profiling=1", close_conn=False).to_sep_dict() - records = self.engine.query(sql="select ifnull(max(query_id),0) from INFORMATION_SCHEMA.PROFILING", - close_conn=False).to_sep_dict() - query_id = records['rows'][0][0] + 3 # skip next sql + records = self.engine.query( + sql="select ifnull(max(query_id),0) from INFORMATION_SCHEMA.PROFILING", + close_conn=False, + ).to_sep_dict() + query_id = records["rows"][0][0] + 3 # skip next sql # 获取执行前信息 - result['BEFORE_STATUS'] = self.engine.query(sql=sql, close_conn=False).to_sep_dict() + result["BEFORE_STATUS"] = self.engine.query( + sql=sql, close_conn=False + ).to_sep_dict() # 执行查询语句,统计执行时间 t_start = time.time() self.engine.query(sql=self.sqltext, close_conn=False).to_sep_dict() t_end = time.time() cost_time = "%5s" % "{:.4f}".format(t_end - t_start) - result['EXECUTE_TIME'] = cost_time + result["EXECUTE_TIME"] = cost_time # 获取执行后信息 - result['AFTER_STATUS'] = self.engine.query(sql=sql, close_conn=False).to_sep_dict() + result["AFTER_STATUS"] = self.engine.query( + sql=sql, close_conn=False + ).to_sep_dict() # 获取PROFILING_DETAIL信息 - result['PROFILING_DETAIL'] = self.engine.query( - sql="select STATE,DURATION,CPU_USER,CPU_SYSTEM,BLOCK_OPS_IN,BLOCK_OPS_OUT ,MESSAGES_SENT ,MESSAGES_RECEIVED ,PAGE_FAULTS_MAJOR ,PAGE_FAULTS_MINOR ,SWAPS from INFORMATION_SCHEMA.PROFILING where query_id=" + str( - query_id) + " order by seq", close_conn=False).to_sep_dict() - result['PROFILING_SUMMARY'] = self.engine.query( - sql="SELECT STATE,SUM(DURATION) AS Total_R,ROUND(100*SUM(DURATION)/(SELECT SUM(DURATION) FROM INFORMATION_SCHEMA.PROFILING WHERE QUERY_ID=" + str( - query_id) + "),2) AS Pct_R,COUNT(*) AS Calls,SUM(DURATION)/COUNT(*) AS R_Call FROM INFORMATION_SCHEMA.PROFILING WHERE QUERY_ID=" + str( - query_id) + " GROUP BY STATE ORDER BY Total_R DESC", close_conn=False).to_sep_dict() + result["PROFILING_DETAIL"] = self.engine.query( + sql="select STATE,DURATION,CPU_USER,CPU_SYSTEM,BLOCK_OPS_IN,BLOCK_OPS_OUT ,MESSAGES_SENT ,MESSAGES_RECEIVED ,PAGE_FAULTS_MAJOR ,PAGE_FAULTS_MINOR ,SWAPS from INFORMATION_SCHEMA.PROFILING where query_id=" + + str(query_id) + + " order by seq", + close_conn=False, + ).to_sep_dict() + result["PROFILING_SUMMARY"] = self.engine.query( + sql="SELECT STATE,SUM(DURATION) AS Total_R,ROUND(100*SUM(DURATION)/(SELECT SUM(DURATION) FROM INFORMATION_SCHEMA.PROFILING WHERE QUERY_ID=" + + str(query_id) + + "),2) AS Pct_R,COUNT(*) AS Calls,SUM(DURATION)/COUNT(*) AS R_Call FROM INFORMATION_SCHEMA.PROFILING WHERE QUERY_ID=" + + str(query_id) + + " GROUP BY STATE ORDER BY Total_R DESC", + close_conn=False, + ).to_sep_dict() # 处理执行前后对比信息 - before_status_rows = [list(item) for item in result['BEFORE_STATUS']['rows']] - after_status_rows = [list(item) for item in result['AFTER_STATUS']['rows']] + before_status_rows = [list(item) for item in result["BEFORE_STATUS"]["rows"]] + after_status_rows = [list(item) for item in result["AFTER_STATUS"]["rows"]] for index, item in enumerate(before_status_rows): if before_status_rows[index][1] != after_status_rows[index][1]: before_status_rows[index].append(after_status_rows[index][1]) before_status_rows[index].append( - str(float(after_status_rows[index][1]) - float(before_status_rows[index][1]))) + str( + float(after_status_rows[index][1]) + - float(before_status_rows[index][1]) + ) + ) diff_rows = [item for item in before_status_rows if len(item) == 4] - result['SESSION_STATUS(DIFFERENT)']['rows'] = diff_rows + result["SESSION_STATUS(DIFFERENT)"]["rows"] = diff_rows return result diff --git a/sql/sql_workflow.py b/sql/sql_workflow.py index 4a38f38545..322c4d67c1 100644 --- a/sql/sql_workflow.py +++ b/sql/sql_workflow.py @@ -21,172 +21,204 @@ from sql.models import ResourceGroup from sql.utils.resource_group import user_groups, user_instances from sql.utils.tasks import add_sql_schedule, del_schedule -from sql.utils.sql_review import can_timingtask, can_cancel, can_execute, on_correct_time_period, can_view, can_rollback +from sql.utils.sql_review import ( + can_timingtask, + can_cancel, + can_execute, + on_correct_time_period, + can_view, + can_rollback, +) from sql.utils.workflow_audit import Audit from .models import SqlWorkflow, SqlWorkflowContent, Instance from django_q.tasks import async_task from sql.engines import get_engine -logger = logging.getLogger('default') +logger = logging.getLogger("default") -@permission_required('sql.menu_sqlworkflow', raise_exception=True) +@permission_required("sql.menu_sqlworkflow", raise_exception=True) def sql_workflow_list(request): return _sql_workflow_list(request) -@permission_required('sql.audit_user', raise_exception=True) + +@permission_required("sql.audit_user", raise_exception=True) def sql_workflow_list_audit(request): return _sql_workflow_list(request) + def _sql_workflow_list(request): """ 获取审核列表 :param request: :return: """ - nav_status = request.POST.get('navStatus') - instance_id = request.POST.get('instance_id') - resource_group_id = request.POST.get('group_id') - start_date = request.POST.get('start_date') - end_date = request.POST.get('end_date') - limit = int(request.POST.get('limit',0)) - offset = int(request.POST.get('offset',0)) + nav_status = request.POST.get("navStatus") + instance_id = request.POST.get("instance_id") + resource_group_id = request.POST.get("group_id") + start_date = request.POST.get("start_date") + end_date = request.POST.get("end_date") + limit = int(request.POST.get("limit", 0)) + offset = int(request.POST.get("offset", 0)) limit = offset + limit limit = limit if limit else None - search = request.POST.get('search') + search = request.POST.get("search") user = request.user # 组合筛选项 filter_dict = dict() # 工单状态 if nav_status: - filter_dict['status'] = nav_status + filter_dict["status"] = nav_status # 实例 if instance_id: - filter_dict['instance_id'] = instance_id + filter_dict["instance_id"] = instance_id # 资源组 if resource_group_id: - filter_dict['group_id'] = resource_group_id + filter_dict["group_id"] = resource_group_id # 时间 if start_date and end_date: - end_date = datetime.datetime.strptime(end_date, '%Y-%m-%d') + datetime.timedelta(days=1) - filter_dict['create_time__range'] = (start_date, end_date) + end_date = datetime.datetime.strptime( + end_date, "%Y-%m-%d" + ) + datetime.timedelta(days=1) + filter_dict["create_time__range"] = (start_date, end_date) # 管理员,审计员,可查看所有工单 - if user.is_superuser or user.has_perm('sql.audit_user'): + if user.is_superuser or user.has_perm("sql.audit_user"): pass # 非管理员,拥有审核权限、资源组粒度执行权限的,可以查看组内所有工单 - elif user.has_perm('sql.sql_review') or user.has_perm('sql.sql_execute_for_resource_group'): + elif user.has_perm("sql.sql_review") or user.has_perm( + "sql.sql_execute_for_resource_group" + ): # 先获取用户所在资源组列表 group_list = user_groups(user) group_ids = [group.group_id for group in group_list] - filter_dict['group_id__in'] = group_ids + filter_dict["group_id__in"] = group_ids # 其他人只能查看自己提交的工单 else: - filter_dict['engineer'] = user.username + filter_dict["engineer"] = user.username # 过滤组合筛选项 workflow = SqlWorkflow.objects.filter(**filter_dict) # 过滤搜索项,模糊检索项包括提交人名称、工单名 if search: - workflow = workflow.filter(Q(engineer_display__icontains=search) | Q(workflow_name__icontains=search)) + workflow = workflow.filter( + Q(engineer_display__icontains=search) | Q(workflow_name__icontains=search) + ) count = workflow.count() - workflow_list = workflow.order_by('-create_time')[offset:limit].values( - "id", "workflow_name", "engineer_display", - "status", "is_backup", "create_time", - "instance__instance_name", "db_name", - "group_name", "syntax_type") + workflow_list = workflow.order_by("-create_time")[offset:limit].values( + "id", + "workflow_name", + "engineer_display", + "status", + "is_backup", + "create_time", + "instance__instance_name", + "db_name", + "group_name", + "syntax_type", + ) # QuerySet 序列化 rows = [row for row in workflow_list] result = {"total": count, "rows": rows} # 返回查询结果 - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) -@permission_required('sql.sql_submit', raise_exception=True) +@permission_required("sql.sql_submit", raise_exception=True) def check(request): """SQL检测按钮, 此处没有产生工单""" - sql_content = request.POST.get('sql_content') - instance_name = request.POST.get('instance_name') + sql_content = request.POST.get("sql_content") + instance_name = request.POST.get("instance_name") instance = Instance.objects.get(instance_name=instance_name) - db_name = request.POST.get('db_name') + db_name = request.POST.get("db_name") - result = {'status': 0, 'msg': 'ok', 'data': {}} + result = {"status": 0, "msg": "ok", "data": {}} # 服务器端参数验证 if sql_content is None or instance_name is None or db_name is None: - result['status'] = 1 - result['msg'] = '页面提交参数可能为空' - return HttpResponse(json.dumps(result), content_type='application/json') + result["status"] = 1 + result["msg"] = "页面提交参数可能为空" + return HttpResponse(json.dumps(result), content_type="application/json") # 交给engine进行检测 try: check_engine = get_engine(instance=instance) - check_result = check_engine.execute_check(db_name=db_name, sql=sql_content.strip()) + check_result = check_engine.execute_check( + db_name=db_name, sql=sql_content.strip() + ) except Exception as e: - result['status'] = 1 - result['msg'] = str(e) - return HttpResponse(json.dumps(result), content_type='application/json') + result["status"] = 1 + result["msg"] = str(e) + return HttpResponse(json.dumps(result), content_type="application/json") # 处理检测结果 - result['data']['rows'] = check_result.to_dict() - result['data']['CheckWarningCount'] = check_result.warning_count - result['data']['CheckErrorCount'] = check_result.error_count - return HttpResponse(json.dumps(result), content_type='application/json') + result["data"]["rows"] = check_result.to_dict() + result["data"]["CheckWarningCount"] = check_result.warning_count + result["data"]["CheckErrorCount"] = check_result.error_count + return HttpResponse(json.dumps(result), content_type="application/json") -@permission_required('sql.sql_submit', raise_exception=True) +@permission_required("sql.sql_submit", raise_exception=True) def submit(request): """正式提交SQL, 此处生成工单""" - sql_content = request.POST.get('sql_content').strip() - workflow_title = request.POST.get('workflow_name') - demand_url = request.POST.get('demand_url', '') + sql_content = request.POST.get("sql_content").strip() + workflow_title = request.POST.get("workflow_name") + demand_url = request.POST.get("demand_url", "") # 检查用户是否有权限涉及到资源组等, 比较复杂, 可以把检查权限改成一个独立的方法 - group_name = request.POST.get('group_name') + group_name = request.POST.get("group_name") group_id = ResourceGroup.objects.get(group_name=group_name).group_id - instance_name = request.POST.get('instance_name') + instance_name = request.POST.get("instance_name") instance = Instance.objects.get(instance_name=instance_name) - db_name = request.POST.get('db_name') - is_backup = True if request.POST.get('is_backup') == 'True' else False - cc_users = request.POST.getlist('cc_users') - run_date_start = request.POST.get('run_date_start') - run_date_end = request.POST.get('run_date_end') + db_name = request.POST.get("db_name") + is_backup = True if request.POST.get("is_backup") == "True" else False + cc_users = request.POST.getlist("cc_users") + run_date_start = request.POST.get("run_date_start") + run_date_end = request.POST.get("run_date_end") # 服务器端参数验证 if None in [sql_content, db_name, instance_name, db_name, is_backup, demand_url]: - context = {'errMsg': '页面提交参数可能为空'} - return render(request, 'error.html', context) + context = {"errMsg": "页面提交参数可能为空"} + return render(request, "error.html", context) # 验证组权限(用户是否在该组、该组是否有指定实例) try: - user_instances(request.user, tag_codes=['can_write']).get(instance_name=instance_name) + user_instances(request.user, tag_codes=["can_write"]).get( + instance_name=instance_name + ) except instance.DoesNotExist: - context = {'errMsg': '你所在组未关联该实例!'} - return render(request, 'error.html', context) + context = {"errMsg": "你所在组未关联该实例!"} + return render(request, "error.html", context) # 再次交给engine进行检测,防止绕过 try: check_engine = get_engine(instance=instance) - check_result = check_engine.execute_check(db_name=db_name, sql=sql_content.strip()) + check_result = check_engine.execute_check( + db_name=db_name, sql=sql_content.strip() + ) except Exception as e: - context = {'errMsg': str(e)} - return render(request, 'error.html', context) + context = {"errMsg": str(e)} + return render(request, "error.html", context) # 未开启备份选项,并且engine支持备份,强制设置备份 sys_config = SysConfig() - if not sys_config.get('enable_backup_switch') and check_engine.auto_backup: + if not sys_config.get("enable_backup_switch") and check_engine.auto_backup: is_backup = True # 按照系统配置确定是自动驳回还是放行 - auto_review_wrong = sys_config.get('auto_review_wrong', '') # 1表示出现警告就驳回,2和空表示出现错误才驳回 - workflow_status = 'workflow_manreviewing' - if check_result.warning_count > 0 and auto_review_wrong == '1': - workflow_status = 'workflow_autoreviewwrong' - elif check_result.error_count > 0 and auto_review_wrong in ('', '1', '2'): - workflow_status = 'workflow_autoreviewwrong' + auto_review_wrong = sys_config.get( + "auto_review_wrong", "" + ) # 1表示出现警告就驳回,2和空表示出现错误才驳回 + workflow_status = "workflow_manreviewing" + if check_result.warning_count > 0 and auto_review_wrong == "1": + workflow_status = "workflow_autoreviewwrong" + elif check_result.error_count > 0 and auto_review_wrong in ("", "1", "2"): + workflow_status = "workflow_autoreviewwrong" # 调用工作流生成工单 # 使用事务保持数据一致性 @@ -200,7 +232,9 @@ def submit(request): group_name=group_name, engineer=request.user.username, engineer_display=request.user.display, - audit_auth_groups=Audit.settings(group_id, WorkflowDict.workflow_type['sqlreview']), + audit_auth_groups=Audit.settings( + group_id, WorkflowDict.workflow_type["sqlreview"] + ), status=workflow_status, is_backup=is_backup, instance=instance, @@ -209,44 +243,55 @@ def submit(request): syntax_type=check_result.syntax_type, create_time=timezone.now(), run_date_start=run_date_start or None, - run_date_end=run_date_end or None + run_date_end=run_date_end or None, + ) + SqlWorkflowContent.objects.create( + workflow=sql_workflow, + sql_content=sql_content, + review_content=check_result.json(), + execute_result="", ) - SqlWorkflowContent.objects.create(workflow=sql_workflow, - sql_content=sql_content, - review_content=check_result.json(), - execute_result='' - ) workflow_id = sql_workflow.id # 自动审核通过了,才调用工作流 - if workflow_status == 'workflow_manreviewing': + if workflow_status == "workflow_manreviewing": # 调用工作流插入审核信息, SQL上线权限申请workflow_type=2 - Audit.add(WorkflowDict.workflow_type['sqlreview'], workflow_id) + Audit.add(WorkflowDict.workflow_type["sqlreview"], workflow_id) except Exception as msg: logger.error(f"提交工单报错,错误信息:{traceback.format_exc()}") - context = {'errMsg': msg} + context = {"errMsg": msg} logger.error(traceback.format_exc()) - return render(request, 'error.html', context) + return render(request, "error.html", context) else: # 自动审核通过且开启了Apply阶段通知参数才发送消息通知 - is_notified = 'Apply' in sys_config.get('notify_phase_control').split(',') \ - if sys_config.get('notify_phase_control') else True - if workflow_status == 'workflow_manreviewing' and is_notified: + is_notified = ( + "Apply" in sys_config.get("notify_phase_control").split(",") + if sys_config.get("notify_phase_control") + else True + ) + if workflow_status == "workflow_manreviewing" and is_notified: # 获取审核信息 - audit_id = Audit.detail_by_workflow_id(workflow_id=workflow_id, - workflow_type=WorkflowDict.workflow_type['sqlreview']).audit_id - async_task(notify_for_audit, audit_id=audit_id, cc_users=cc_users, timeout=60, - task_name=f'sqlreview-submit-{workflow_id}') + audit_id = Audit.detail_by_workflow_id( + workflow_id=workflow_id, + workflow_type=WorkflowDict.workflow_type["sqlreview"], + ).audit_id + async_task( + notify_for_audit, + audit_id=audit_id, + cc_users=cc_users, + timeout=60, + task_name=f"sqlreview-submit-{workflow_id}", + ) - return HttpResponseRedirect(reverse('sql:detail', args=(workflow_id,))) + return HttpResponseRedirect(reverse("sql:detail", args=(workflow_id,))) def detail_content(request): """获取工单内容""" - workflow_id = request.GET.get('workflow_id') + workflow_id = request.GET.get("workflow_id") workflow_detail = get_object_or_404(SqlWorkflow, pk=workflow_id) if not can_view(request.user, workflow_id): raise PermissionDenied - if workflow_detail.status in ['workflow_finish', 'workflow_exception']: + if workflow_detail.status in ["workflow_finish", "workflow_exception"]: rows = workflow_detail.sqlworkflowcontent.execute_result else: rows = workflow_detail.sqlworkflowcontent.review_content @@ -262,32 +307,34 @@ def detail_content(request): review_result.rows += [ReviewResult(inception_result=r)] rows = review_result.json() except IndexError: - review_result.rows += [ReviewResult( - id=1, - sql=workflow_detail.sqlworkflowcontent.sql_content, - errormessage="Json decode failed." - "执行结果Json解析失败, 请联系管理员" - )] + review_result.rows += [ + ReviewResult( + id=1, + sql=workflow_detail.sqlworkflowcontent.sql_content, + errormessage="Json decode failed." "执行结果Json解析失败, 请联系管理员", + ) + ] rows = review_result.json() except json.decoder.JSONDecodeError: - review_result.rows += [ReviewResult( - id=1, - sql=workflow_detail.sqlworkflowcontent.sql_content, - # 迫于无法单元测试这里加上英文报错信息 - errormessage="Json decode failed." - "执行结果Json解析失败, 请联系管理员" - )] + review_result.rows += [ + ReviewResult( + id=1, + sql=workflow_detail.sqlworkflowcontent.sql_content, + # 迫于无法单元测试这里加上英文报错信息 + errormessage="Json decode failed." "执行结果Json解析失败, 请联系管理员", + ) + ] rows = review_result.json() else: rows = workflow_detail.sqlworkflowcontent.review_content result = {"rows": json.loads(rows)} - return HttpResponse(json.dumps(result), content_type='application/json') + return HttpResponse(json.dumps(result), content_type="application/json") def backup_sql(request): """获取回滚语句""" - workflow_id = request.GET.get('workflow_id') + workflow_id = request.GET.get("workflow_id") if not can_rollback(request.user, workflow_id): raise PermissionDenied workflow = get_object_or_404(SqlWorkflow, pk=workflow_id) @@ -297,89 +344,109 @@ def backup_sql(request): list_backup_sql = query_engine.get_rollback(workflow=workflow) except Exception as msg: logger.error(traceback.format_exc()) - return JsonResponse({'status': 1, 'msg': f'{msg}', 'rows': []}) + return JsonResponse({"status": 1, "msg": f"{msg}", "rows": []}) - result = {'status': 0, 'msg': '', 'rows': list_backup_sql} - return HttpResponse(json.dumps(result), content_type='application/json') + result = {"status": 0, "msg": "", "rows": list_backup_sql} + return HttpResponse(json.dumps(result), content_type="application/json") -@permission_required('sql.sql_review', raise_exception=True) +@permission_required("sql.sql_review", raise_exception=True) def alter_run_date(request): """ 审核人修改可执行时间 :param request: :return: """ - workflow_id = int(request.POST.get('workflow_id', 0)) - run_date_start = request.POST.get('run_date_start') - run_date_end = request.POST.get('run_date_end') + workflow_id = int(request.POST.get("workflow_id", 0)) + run_date_start = request.POST.get("run_date_start") + run_date_end = request.POST.get("run_date_end") if workflow_id == 0: - context = {'errMsg': 'workflow_id参数为空.'} - return render(request, 'error.html', context) + context = {"errMsg": "workflow_id参数为空."} + return render(request, "error.html", context) user = request.user if Audit.can_review(user, workflow_id, 2) is False: - context = {'errMsg': '你无权操作当前工单!'} - return render(request, 'error.html', context) + context = {"errMsg": "你无权操作当前工单!"} + return render(request, "error.html", context) try: # 存进数据库里 - SqlWorkflow(id=workflow_id, - run_date_start=run_date_start or None, - run_date_end=run_date_end or None - ).save(update_fields=['run_date_start', 'run_date_end']) + SqlWorkflow( + id=workflow_id, + run_date_start=run_date_start or None, + run_date_end=run_date_end or None, + ).save(update_fields=["run_date_start", "run_date_end"]) except Exception as msg: - context = {'errMsg': msg} - return render(request, 'error.html', context) + context = {"errMsg": msg} + return render(request, "error.html", context) - return HttpResponseRedirect(reverse('sql:detail', args=(workflow_id,))) + return HttpResponseRedirect(reverse("sql:detail", args=(workflow_id,))) -@permission_required('sql.sql_review', raise_exception=True) +@permission_required("sql.sql_review", raise_exception=True) def passed(request): """ 审核通过,不执行 :param request: :return: """ - workflow_id = int(request.POST.get('workflow_id', 0)) - audit_remark = request.POST.get('audit_remark', '') + workflow_id = int(request.POST.get("workflow_id", 0)) + audit_remark = request.POST.get("audit_remark", "") if workflow_id == 0: - context = {'errMsg': 'workflow_id参数为空.'} - return render(request, 'error.html', context) + context = {"errMsg": "workflow_id参数为空."} + return render(request, "error.html", context) user = request.user if Audit.can_review(user, workflow_id, 2) is False: - context = {'errMsg': '你无权操作当前工单!'} - return render(request, 'error.html', context) + context = {"errMsg": "你无权操作当前工单!"} + return render(request, "error.html", context) # 使用事务保持数据一致性 try: with transaction.atomic(): # 调用工作流接口审核 - audit_id = Audit.detail_by_workflow_id(workflow_id=workflow_id, - workflow_type=WorkflowDict.workflow_type['sqlreview']).audit_id - audit_result = Audit.audit(audit_id, WorkflowDict.workflow_status['audit_success'], - user.username, audit_remark) + audit_id = Audit.detail_by_workflow_id( + workflow_id=workflow_id, + workflow_type=WorkflowDict.workflow_type["sqlreview"], + ).audit_id + audit_result = Audit.audit( + audit_id, + WorkflowDict.workflow_status["audit_success"], + user.username, + audit_remark, + ) # 按照审核结果更新业务表审核状态 - if audit_result['data']['workflow_status'] == WorkflowDict.workflow_status['audit_success']: + if ( + audit_result["data"]["workflow_status"] + == WorkflowDict.workflow_status["audit_success"] + ): # 将流程状态修改为审核通过 - SqlWorkflow(id=workflow_id, status='workflow_review_pass').save(update_fields=['status']) + SqlWorkflow(id=workflow_id, status="workflow_review_pass").save( + update_fields=["status"] + ) except Exception as msg: logger.error(f"审核工单报错,错误信息:{traceback.format_exc()}") - context = {'errMsg': msg} - return render(request, 'error.html', context) + context = {"errMsg": msg} + return render(request, "error.html", context) else: # 开启了Pass阶段通知参数才发送消息通知 sys_config = SysConfig() - is_notified = 'Pass' in sys_config.get('notify_phase_control').split(',') \ - if sys_config.get('notify_phase_control') else True + is_notified = ( + "Pass" in sys_config.get("notify_phase_control").split(",") + if sys_config.get("notify_phase_control") + else True + ) if is_notified: - async_task(notify_for_audit, audit_id=audit_id, audit_remark=audit_remark, timeout=60, - task_name=f'sqlreview-pass-{workflow_id}') + async_task( + notify_for_audit, + audit_id=audit_id, + audit_remark=audit_remark, + timeout=60, + task_name=f"sqlreview-pass-{workflow_id}", + ) - return HttpResponseRedirect(reverse('sql:detail', args=(workflow_id,))) + return HttpResponseRedirect(reverse("sql:detail", args=(workflow_id,))) def execute(request): @@ -389,63 +456,84 @@ def execute(request): :return: """ # 校验多个权限 - if not (request.user.has_perm('sql.sql_execute') or request.user.has_perm('sql.sql_execute_for_resource_group')): + if not ( + request.user.has_perm("sql.sql_execute") + or request.user.has_perm("sql.sql_execute_for_resource_group") + ): raise PermissionDenied - workflow_id = int(request.POST.get('workflow_id', 0)) + workflow_id = int(request.POST.get("workflow_id", 0)) if workflow_id == 0: - context = {'errMsg': 'workflow_id参数为空.'} - return render(request, 'error.html', context) + context = {"errMsg": "workflow_id参数为空."} + return render(request, "error.html", context) if can_execute(request.user, workflow_id) is False: - context = {'errMsg': '你无权操作当前工单!'} - return render(request, 'error.html', context) + context = {"errMsg": "你无权操作当前工单!"} + return render(request, "error.html", context) if on_correct_time_period(workflow_id) is False: - context = {'errMsg': '不在可执行时间范围内,如果需要修改执行时间请重新提交工单!'} - return render(request, 'error.html', context) + context = {"errMsg": "不在可执行时间范围内,如果需要修改执行时间请重新提交工单!"} + return render(request, "error.html", context) # 获取审核信息 - audit_id = Audit.detail_by_workflow_id(workflow_id=workflow_id, - workflow_type=WorkflowDict.workflow_type['sqlreview']).audit_id + audit_id = Audit.detail_by_workflow_id( + workflow_id=workflow_id, workflow_type=WorkflowDict.workflow_type["sqlreview"] + ).audit_id # 根据执行模式进行对应修改 - mode = request.POST.get('mode') + mode = request.POST.get("mode") # 交由系统执行 if mode == "auto": # 修改工单状态为排队中 - SqlWorkflow(id=workflow_id, status="workflow_queuing").save(update_fields=['status']) + SqlWorkflow(id=workflow_id, status="workflow_queuing").save( + update_fields=["status"] + ) # 删除定时执行任务 schedule_name = f"sqlreview-timing-{workflow_id}" del_schedule(schedule_name) # 加入执行队列 - async_task('sql.utils.execute_sql.execute', workflow_id, request.user, - hook='sql.utils.execute_sql.execute_callback', - timeout=-1, task_name=f'sqlreview-execute-{workflow_id}') + async_task( + "sql.utils.execute_sql.execute", + workflow_id, + request.user, + hook="sql.utils.execute_sql.execute_callback", + timeout=-1, + task_name=f"sqlreview-execute-{workflow_id}", + ) # 增加工单日志 - Audit.add_log(audit_id=audit_id, - operation_type=5, - operation_type_desc='执行工单', - operation_info='工单执行排队中', - operator=request.user.username, - operator_display=request.user.display) + Audit.add_log( + audit_id=audit_id, + operation_type=5, + operation_type_desc="执行工单", + operation_info="工单执行排队中", + operator=request.user.username, + operator_display=request.user.display, + ) # 线下手工执行 elif mode == "manual": # 将流程状态修改为执行结束 - SqlWorkflow(id=workflow_id, status="workflow_finish", finish_time=datetime.datetime.now() - ).save(update_fields=['status', 'finish_time']) + SqlWorkflow( + id=workflow_id, + status="workflow_finish", + finish_time=datetime.datetime.now(), + ).save(update_fields=["status", "finish_time"]) # 增加工单日志 - Audit.add_log(audit_id=audit_id, - operation_type=6, - operation_type_desc='手工工单', - operation_info='确认手工执行结束', - operator=request.user.username, - operator_display=request.user.display) + Audit.add_log( + audit_id=audit_id, + operation_type=6, + operation_type_desc="手工工单", + operation_info="确认手工执行结束", + operator=request.user.username, + operator_display=request.user.display, + ) # 开启了Execute阶段通知参数才发送消息通知 sys_config = SysConfig() - is_notified = 'Execute' in sys_config.get('notify_phase_control').split(',') \ - if sys_config.get('notify_phase_control') else True + is_notified = ( + "Execute" in sys_config.get("notify_phase_control").split(",") + if sys_config.get("notify_phase_control") + else True + ) if is_notified: notify_for_execute(SqlWorkflow.objects.get(id=workflow_id)) - return HttpResponseRedirect(reverse('sql:detail', args=(workflow_id,))) + return HttpResponseRedirect(reverse("sql:detail", args=(workflow_id,))) def timing_task(request): @@ -455,53 +543,58 @@ def timing_task(request): :return: """ # 校验多个权限 - if not (request.user.has_perm('sql.sql_execute') or request.user.has_perm('sql.sql_execute_for_resource_group')): + if not ( + request.user.has_perm("sql.sql_execute") + or request.user.has_perm("sql.sql_execute_for_resource_group") + ): raise PermissionDenied - workflow_id = request.POST.get('workflow_id') - run_date = request.POST.get('run_date') + workflow_id = request.POST.get("workflow_id") + run_date = request.POST.get("run_date") if run_date is None or workflow_id is None: - context = {'errMsg': '时间不能为空'} - return render(request, 'error.html', context) - elif run_date < datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'): - context = {'errMsg': '时间不能小于当前时间'} - return render(request, 'error.html', context) + context = {"errMsg": "时间不能为空"} + return render(request, "error.html", context) + elif run_date < datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"): + context = {"errMsg": "时间不能小于当前时间"} + return render(request, "error.html", context) workflow_detail = SqlWorkflow.objects.get(id=workflow_id) if can_timingtask(request.user, workflow_id) is False: - context = {'errMsg': '你无权操作当前工单!'} - return render(request, 'error.html', context) + context = {"errMsg": "你无权操作当前工单!"} + return render(request, "error.html", context) run_date = datetime.datetime.strptime(run_date, "%Y-%m-%d %H:%M") schedule_name = f"sqlreview-timing-{workflow_id}" if on_correct_time_period(workflow_id, run_date) is False: - context = {'errMsg': '不在可执行时间范围内,如果需要修改执 行时间请重新提交工单!'} - return render(request, 'error.html', context) + context = {"errMsg": "不在可执行时间范围内,如果需要修改执 行时间请重新提交工单!"} + return render(request, "error.html", context) # 使用事务保持数据一致性 try: with transaction.atomic(): # 将流程状态修改为定时执行 - workflow_detail.status = 'workflow_timingtask' + workflow_detail.status = "workflow_timingtask" workflow_detail.save() # 调用添加定时任务 add_sql_schedule(schedule_name, run_date, workflow_id) # 增加工单日志 - audit_id = Audit.detail_by_workflow_id(workflow_id=workflow_id, - workflow_type=WorkflowDict.workflow_type[ - 'sqlreview']).audit_id - Audit.add_log(audit_id=audit_id, - operation_type=4, - operation_type_desc='定时执行', - operation_info="定时执行时间:{}".format(run_date), - operator=request.user.username, - operator_display=request.user.display - ) + audit_id = Audit.detail_by_workflow_id( + workflow_id=workflow_id, + workflow_type=WorkflowDict.workflow_type["sqlreview"], + ).audit_id + Audit.add_log( + audit_id=audit_id, + operation_type=4, + operation_type_desc="定时执行", + operation_info="定时执行时间:{}".format(run_date), + operator=request.user.username, + operator_display=request.user.display, + ) except Exception as msg: logger.error(f"定时执行工单报错,错误信息:{traceback.format_exc()}") - context = {'errMsg': msg} - return render(request, 'error.html', context) - return HttpResponseRedirect(reverse('sql:detail', args=(workflow_id,))) + context = {"errMsg": msg} + return render(request, "error.html", context) + return HttpResponseRedirect(reverse("sql:detail", args=(workflow_id,))) def cancel(request): @@ -510,94 +603,115 @@ def cancel(request): :param request: :return: """ - workflow_id = int(request.POST.get('workflow_id', 0)) + workflow_id = int(request.POST.get("workflow_id", 0)) if workflow_id == 0: - context = {'errMsg': 'workflow_id参数为空.'} - return render(request, 'error.html', context) + context = {"errMsg": "workflow_id参数为空."} + return render(request, "error.html", context) workflow_detail = SqlWorkflow.objects.get(id=workflow_id) - audit_remark = request.POST.get('cancel_remark') + audit_remark = request.POST.get("cancel_remark") if audit_remark is None: - context = {'errMsg': '终止原因不能为空'} - return render(request, 'error.html', context) + context = {"errMsg": "终止原因不能为空"} + return render(request, "error.html", context) user = request.user if can_cancel(request.user, workflow_id) is False: - context = {'errMsg': '你无权操作当前工单!'} - return render(request, 'error.html', context) + context = {"errMsg": "你无权操作当前工单!"} + return render(request, "error.html", context) # 使用事务保持数据一致性 try: with transaction.atomic(): # 调用工作流接口取消或者驳回 - audit_id = Audit.detail_by_workflow_id(workflow_id=workflow_id, - workflow_type=WorkflowDict.workflow_type[ - 'sqlreview']).audit_id + audit_id = Audit.detail_by_workflow_id( + workflow_id=workflow_id, + workflow_type=WorkflowDict.workflow_type["sqlreview"], + ).audit_id # 仅待审核的需要调用工作流,审核通过的不需要 - if workflow_detail.status != 'workflow_manreviewing': + if workflow_detail.status != "workflow_manreviewing": # 增加工单日志 if user.username == workflow_detail.engineer: - Audit.add_log(audit_id=audit_id, - operation_type=3, - operation_type_desc='取消执行', - operation_info="取消原因:{}".format(audit_remark), - operator=request.user.username, - operator_display=request.user.display - ) + Audit.add_log( + audit_id=audit_id, + operation_type=3, + operation_type_desc="取消执行", + operation_info="取消原因:{}".format(audit_remark), + operator=request.user.username, + operator_display=request.user.display, + ) else: - Audit.add_log(audit_id=audit_id, - operation_type=2, - operation_type_desc='审批不通过', - operation_info="审批备注:{}".format(audit_remark), - operator=request.user.username, - operator_display=request.user.display - ) + Audit.add_log( + audit_id=audit_id, + operation_type=2, + operation_type_desc="审批不通过", + operation_info="审批备注:{}".format(audit_remark), + operator=request.user.username, + operator_display=request.user.display, + ) else: if user.username == workflow_detail.engineer: - Audit.audit(audit_id, - WorkflowDict.workflow_status['audit_abort'], - user.username, audit_remark) + Audit.audit( + audit_id, + WorkflowDict.workflow_status["audit_abort"], + user.username, + audit_remark, + ) # 非提交人需要校验审核权限 - elif user.has_perm('sql.sql_review'): - Audit.audit(audit_id, - WorkflowDict.workflow_status['audit_reject'], - user.username, audit_remark) + elif user.has_perm("sql.sql_review"): + Audit.audit( + audit_id, + WorkflowDict.workflow_status["audit_reject"], + user.username, + audit_remark, + ) else: raise PermissionDenied # 删除定时执行task - if workflow_detail.status == 'workflow_timingtask': + if workflow_detail.status == "workflow_timingtask": schedule_name = f"sqlreview-timing-{workflow_id}" del_schedule(schedule_name) # 将流程状态修改为人工终止流程 - workflow_detail.status = 'workflow_abort' + workflow_detail.status = "workflow_abort" workflow_detail.save() except Exception as msg: logger.error(f"取消工单报错,错误信息:{traceback.format_exc()}") - context = {'errMsg': msg} - return render(request, 'error.html', context) + context = {"errMsg": msg} + return render(request, "error.html", context) else: # 发送取消、驳回通知,开启了Cancel阶段通知参数才发送消息通知 sys_config = SysConfig() - is_notified = 'Cancel' in sys_config.get('notify_phase_control').split(',') \ - if sys_config.get('notify_phase_control') else True + is_notified = ( + "Cancel" in sys_config.get("notify_phase_control").split(",") + if sys_config.get("notify_phase_control") + else True + ) if is_notified: - audit_detail = Audit.detail_by_workflow_id(workflow_id=workflow_id, - workflow_type=WorkflowDict.workflow_type['sqlreview']) + audit_detail = Audit.detail_by_workflow_id( + workflow_id=workflow_id, + workflow_type=WorkflowDict.workflow_type["sqlreview"], + ) if audit_detail.current_status in ( - WorkflowDict.workflow_status['audit_abort'], WorkflowDict.workflow_status['audit_reject']): - async_task(notify_for_audit, audit_id=audit_detail.audit_id, audit_remark=audit_remark, timeout=60, - task_name=f'sqlreview-cancel-{workflow_id}') - return HttpResponseRedirect(reverse('sql:detail', args=(workflow_id,))) + WorkflowDict.workflow_status["audit_abort"], + WorkflowDict.workflow_status["audit_reject"], + ): + async_task( + notify_for_audit, + audit_id=audit_detail.audit_id, + audit_remark=audit_remark, + timeout=60, + task_name=f"sqlreview-cancel-{workflow_id}", + ) + return HttpResponseRedirect(reverse("sql:detail", args=(workflow_id,))) def get_workflow_status(request): """ 获取某个工单的当前状态 """ - workflow_id = request.POST['workflow_id'] - if workflow_id == '' or workflow_id is None: - context = {"status": -1, 'msg': 'workflow_id参数为空.', "data": ""} - return HttpResponse(json.dumps(context), content_type='application/json') + workflow_id = request.POST["workflow_id"] + if workflow_id == "" or workflow_id is None: + context = {"status": -1, "msg": "workflow_id参数为空.", "data": ""} + return HttpResponse(json.dumps(context), content_type="application/json") workflow_id = int(workflow_id) workflow_detail = get_object_or_404(SqlWorkflow, pk=workflow_id) @@ -607,9 +721,9 @@ def get_workflow_status(request): def osc_control(request): """用于mysql控制osc执行""" - workflow_id = request.POST.get('workflow_id') - sqlsha1 = request.POST.get('sqlsha1') - command = request.POST.get('command') + workflow_id = request.POST.get("workflow_id") + sqlsha1 = request.POST.get("sqlsha1") + command = request.POST.get("command") workflow = SqlWorkflow.objects.get(id=workflow_id) execute_engine = get_engine(workflow.instance) try: @@ -620,5 +734,7 @@ def osc_control(request): rows = [] error = str(e) result = {"total": len(rows), "rows": rows, "msg": error} - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) diff --git a/sql/templatetags/format_tags.py b/sql/templatetags/format_tags.py index 64991ac17d..902eb52371 100644 --- a/sql/templatetags/format_tags.py +++ b/sql/templatetags/format_tags.py @@ -10,7 +10,7 @@ @register.simple_tag def format_str(string): # 换行 - return mark_safe(string.replace(',', '
').replace('\n', '
')) + return mark_safe(string.replace(",", "
").replace("\n", "
")) # split @@ -26,7 +26,7 @@ def split(string, sep): # in @register.filter def is_in(var, args): - return True if str(var) in args.split(',') else False + return True if str(var) in args.split(",") else False @register.filter @@ -34,4 +34,4 @@ def key_value(data, key): try: return data[key] except KeyError: - return '' + return "" diff --git a/sql/tests.py b/sql/tests.py index 5a6700a0b0..c2fb567cc0 100644 --- a/sql/tests.py +++ b/sql/tests.py @@ -16,9 +16,21 @@ from sql.notify import notify_for_audit, notify_for_execute, notify_for_my2sql from sql.utils.execute_sql import execute_callback from sql.query import kill_query_conn -from sql.models import Users, Instance, QueryPrivilegesApply, QueryPrivileges, SqlWorkflow, SqlWorkflowContent, \ - ResourceGroup, ParamTemplate, WorkflowAudit, QueryLog, WorkflowLog, WorkflowAuditSetting, \ - ArchiveConfig +from sql.models import ( + Users, + Instance, + QueryPrivilegesApply, + QueryPrivileges, + SqlWorkflow, + SqlWorkflowContent, + ResourceGroup, + ParamTemplate, + WorkflowAudit, + QueryLog, + WorkflowLog, + WorkflowAuditSetting, + ArchiveConfig, +) User = Users @@ -32,53 +44,63 @@ def setUp(self): """ self.sys_config = SysConfig() self.client = Client() - self.superuser = User.objects.create(username='super', is_superuser=True) + self.superuser = User.objects.create(username="super", is_superuser=True) self.client.force_login(self.superuser) - self.ins = Instance.objects.create(instance_name='some_ins', type='slave', db_type='mysql', - host='some_host', - port=3306, user='ins_user', password='some_str') - self.res_group = ResourceGroup.objects.create(group_id=1, group_name='group_name') + self.ins = Instance.objects.create( + instance_name="some_ins", + type="slave", + db_type="mysql", + host="some_host", + port=3306, + user="ins_user", + password="some_str", + ) + self.res_group = ResourceGroup.objects.create( + group_id=1, group_name="group_name" + ) self.wf = SqlWorkflow.objects.create( - workflow_name='some_name', + workflow_name="some_name", group_id=1, - group_name='g1', - engineer_display='', - audit_auth_groups='some_audit_group', - status='workflow_finish', + group_name="g1", + engineer_display="", + audit_auth_groups="some_audit_group", + status="workflow_finish", is_backup=True, instance=self.ins, - db_name='some_db', - syntax_type=1 + db_name="some_db", + syntax_type=1, + ) + SqlWorkflowContent.objects.create( + workflow=self.wf, sql_content="some_sql", execute_result="" ) - SqlWorkflowContent.objects.create(workflow=self.wf, - sql_content='some_sql', - execute_result='') self.query_apply = QueryPrivilegesApply.objects.create( group_id=1, - group_name='some_name', - title='some_title1', - user_name='some_user', + group_name="some_name", + title="some_title1", + user_name="some_user", instance=self.ins, - db_list='some_db,some_db2', + db_list="some_db,some_db2", limit_num=100, - valid_date='2020-01-1', + valid_date="2020-01-1", priv_type=1, status=0, - audit_auth_groups='some_audit_group' + audit_auth_groups="some_audit_group", ) self.audit = WorkflowAudit.objects.create( group_id=1, - group_name='some_group', + group_name="some_group", workflow_id=1, workflow_type=1, - workflow_title='申请标题', - workflow_remark='申请备注', - audit_auth_groups='1,2,3', - current_audit='1', - next_audit='2', - current_status=0) - self.wl = WorkflowLog.objects.create(audit_id=self.audit.audit_id, - operation_type=1) + workflow_title="申请标题", + workflow_remark="申请备注", + audit_auth_groups="1,2,3", + current_audit="1", + next_audit="2", + current_status=0, + ) + self.wl = WorkflowLog.objects.create( + audit_id=self.audit.audit_id, operation_type=1 + ) def tearDown(self): self.sys_config.purge() @@ -93,170 +115,170 @@ def tearDown(self): def test_index(self): """测试index页面""" data = {} - r = self.client.get('/index/', data=data) - self.assertRedirects(r, f'/sqlworkflow/', fetch_redirect_response=False) + r = self.client.get("/index/", data=data) + self.assertRedirects(r, f"/sqlworkflow/", fetch_redirect_response=False) def test_dashboard(self): """测试dashboard页面""" data = {} - r = self.client.get('/dashboard/', data=data) + r = self.client.get("/dashboard/", data=data) self.assertEqual(r.status_code, 200) - self.assertContains(r, 'SQL上线工单') + self.assertContains(r, "SQL上线工单") def test_sqlworkflow(self): """测试sqlworkflow页面""" data = {} - r = self.client.get('/sqlworkflow/', data=data) + r = self.client.get("/sqlworkflow/", data=data) self.assertEqual(r.status_code, 200) def test_submitsql(self): """测试submitsql页面""" data = {} - r = self.client.get('/submitsql/', data=data) + r = self.client.get("/submitsql/", data=data) self.assertEqual(r.status_code, 200) def test_rollback(self): """测试rollback页面""" data = {"workflow_id": self.wf.id} - r = self.client.get('/rollback/', data=data) + r = self.client.get("/rollback/", data=data) self.assertEqual(r.status_code, 200) def test_sqlanalyze(self): """测试sqlanalyze页面""" data = {} - r = self.client.get('/sqlanalyze/', data=data) + r = self.client.get("/sqlanalyze/", data=data) self.assertEqual(r.status_code, 200) def test_sqlquery(self): """测试sqlquery页面""" data = {} - r = self.client.get('/sqlquery/', data=data) + r = self.client.get("/sqlquery/", data=data) self.assertEqual(r.status_code, 200) def test_queryapplylist(self): """测试queryapplylist页面""" data = {} - r = self.client.get('/queryapplylist/', data=data) + r = self.client.get("/queryapplylist/", data=data) self.assertEqual(r.status_code, 200) def test_queryuserprivileges(self): """测试queryuserprivileges页面""" data = {} - r = self.client.get(f'/queryuserprivileges/', data=data) + r = self.client.get(f"/queryuserprivileges/", data=data) self.assertEqual(r.status_code, 200) def test_sqladvisor(self): """测试sqladvisor页面""" data = {} - r = self.client.get(f'/sqladvisor/', data=data) + r = self.client.get(f"/sqladvisor/", data=data) self.assertEqual(r.status_code, 200) def test_slowquery(self): """测试slowquery页面""" data = {} - r = self.client.get(f'/slowquery/', data=data) + r = self.client.get(f"/slowquery/", data=data) self.assertEqual(r.status_code, 200) def test_instance(self): """测试instance页面""" data = {} - r = self.client.get(f'/instance/', data=data) + r = self.client.get(f"/instance/", data=data) self.assertEqual(r.status_code, 200) def test_instanceaccount(self): """测试instanceaccount页面""" data = {} - r = self.client.get(f'/instanceaccount/', data=data) + r = self.client.get(f"/instanceaccount/", data=data) self.assertEqual(r.status_code, 200) def test_database(self): """测试database页面""" data = {} - r = self.client.get(f'/database/', data=data) + r = self.client.get(f"/database/", data=data) self.assertEqual(r.status_code, 200) def test_dbdiagnostic(self): """测试dbdiagnostic页面""" data = {} - r = self.client.get(f'/dbdiagnostic/', data=data) + r = self.client.get(f"/dbdiagnostic/", data=data) self.assertEqual(r.status_code, 200) def test_instanceparam(self): """测试instance_param页面""" data = {} - r = self.client.get(f'/instanceparam/', data=data) + r = self.client.get(f"/instanceparam/", data=data) self.assertEqual(r.status_code, 200) def test_my2sql(self): """测试my2sql页面""" data = {} - r = self.client.get(f'/my2sql/', data=data) + r = self.client.get(f"/my2sql/", data=data) self.assertEqual(r.status_code, 200) def test_schemasync(self): """测试schemasync页面""" data = {} - r = self.client.get(f'/schemasync/', data=data) + r = self.client.get(f"/schemasync/", data=data) self.assertEqual(r.status_code, 200) def test_archive(self): """测试archive页面""" data = {} - r = self.client.get(f'/archive/', data=data) + r = self.client.get(f"/archive/", data=data) self.assertEqual(r.status_code, 200) def test_config(self): """测试config页面""" data = {} - r = self.client.get(f'/config/', data=data) + r = self.client.get(f"/config/", data=data) self.assertEqual(r.status_code, 200) def test_group(self): """测试group页面""" data = {} - r = self.client.get(f'/group/', data=data) + r = self.client.get(f"/group/", data=data) self.assertEqual(r.status_code, 200) def test_audit(self): """测试audit页面""" data = {} - r = self.client.get(f'/audit/', data=data) + r = self.client.get(f"/audit/", data=data) self.assertEqual(r.status_code, 200) def test_audit_sqlquery(self): """测试audit_sqlquery页面""" data = {} - r = self.client.get(f'/audit_sqlquery/', data=data) + r = self.client.get(f"/audit_sqlquery/", data=data) self.assertEqual(r.status_code, 200) def test_audit_sqlworkflow(self): """测试audit_sqlworkflow页面""" data = {} - r = self.client.get(f'/audit_sqlworkflow/', data=data) + r = self.client.get(f"/audit_sqlworkflow/", data=data) self.assertEqual(r.status_code, 200) def test_groupmgmt(self): """测试groupmgmt页面""" data = {} - r = self.client.get(f'/grouprelations/{self.res_group.group_id}/', data=data) + r = self.client.get(f"/grouprelations/{self.res_group.group_id}/", data=data) self.assertEqual(r.status_code, 200) def test_workflows(self): """测试workflows页面""" data = {} - r = self.client.get(f'/workflow/', data=data) + r = self.client.get(f"/workflow/", data=data) self.assertEqual(r.status_code, 200) def test_workflowsdetail(self): """测试workflows页面""" data = {} - r = self.client.get(f'/workflow/{self.audit.audit_id}/', data=data) - self.assertRedirects(r, f'/queryapplydetail/1/', fetch_redirect_response=False) + r = self.client.get(f"/workflow/{self.audit.audit_id}/", data=data) + self.assertRedirects(r, f"/queryapplydetail/1/", fetch_redirect_response=False) def test_dbaprinciples(self): """测试workflows页面""" data = {} - r = self.client.get(f'/dbaprinciples/', data=data) + r = self.client.get(f"/dbaprinciples/", data=data) self.assertEqual(r.status_code, 200) @@ -268,10 +290,10 @@ def setUp(self): 创建默认组给注册关联用户, 打开注册 """ archer_config = SysConfig() - archer_config.set('sign_up_enabled', 'true') + archer_config.set("sign_up_enabled", "true") archer_config.get_all_config() self.client = Client() - Group.objects.create(id=1, name='默认组') + Group.objects.create(id=1, name="默认组") def tearDown(self): SysConfig().purge() @@ -282,100 +304,133 @@ def test_sing_up_not_username(self): """ 用户名不能为空 """ - response = self.client.post('/signup/', data={}) + response = self.client.post("/signup/", data={}) data = json.loads(response.content) - content = {'status': 1, 'msg': '用户名和密码不能为空', 'data': None} + content = {"status": 1, "msg": "用户名和密码不能为空", "data": None} self.assertEqual(data, content) def test_sing_up_not_password(self): """ 密码不能为空 """ - response = self.client.post('/signup/', data={'username': 'test'}) + response = self.client.post("/signup/", data={"username": "test"}) data = json.loads(response.content) - content = {'status': 1, 'msg': '用户名和密码不能为空', 'data': None} + content = {"status": 1, "msg": "用户名和密码不能为空", "data": None} self.assertEqual(data, content) def test_sing_up_not_display(self): """ 中文名不能为空 """ - response = self.client.post('/signup/', data={'username': 'test', 'password': '123456test', - 'password2': '123456test', 'display': '', - 'email': '123@123.com'}) + response = self.client.post( + "/signup/", + data={ + "username": "test", + "password": "123456test", + "password2": "123456test", + "display": "", + "email": "123@123.com", + }, + ) data = json.loads(response.content) - content = {'status': 1, 'msg': '请填写中文名', 'data': None} + content = {"status": 1, "msg": "请填写中文名", "data": None} self.assertEqual(data, content) def test_sing_up_2password(self): """ 两次输入密码不一致 """ - response = self.client.post('/signup/', data={'username': 'test', 'password': '123456', 'password2': '12345'}) + response = self.client.post( + "/signup/", + data={"username": "test", "password": "123456", "password2": "12345"}, + ) data = json.loads(response.content) - content = {'status': 1, 'msg': '两次输入密码不一致', 'data': None} + content = {"status": 1, "msg": "两次输入密码不一致", "data": None} self.assertEqual(data, content) def test_sing_up_duplicate_uesrname(self): """ 用户名已存在 """ - User.objects.create(username='test', password='123456') - response = self.client.post('/signup/', - data={'username': 'test', 'password': '123456', 'password2': '123456'}) + User.objects.create(username="test", password="123456") + response = self.client.post( + "/signup/", + data={"username": "test", "password": "123456", "password2": "123456"}, + ) data = json.loads(response.content) - content = {'status': 1, 'msg': '用户名已存在', 'data': None} + content = {"status": 1, "msg": "用户名已存在", "data": None} self.assertEqual(data, content) def test_sing_up_invalid(self): """ 密码无效 """ - self.client.post('/signup/', - data={'username': 'test', 'password': '123456', - 'password2': '123456test', 'display': 'test', 'email': '123@123.com'}) + self.client.post( + "/signup/", + data={ + "username": "test", + "password": "123456", + "password2": "123456test", + "display": "test", + "email": "123@123.com", + }, + ) with self.assertRaises(User.DoesNotExist): - User.objects.get(username='test') + User.objects.get(username="test") - @patch('common.auth.init_user') + @patch("common.auth.init_user") def test_sing_up_valid(self, mock_init): """ 正常注册 """ - self.client.post('/signup/', - data={'username': 'test', 'password': '123456test', - 'password2': '123456test', 'display': 'test', 'email': '123@123.com'}) - user = User.objects.get(username='test') + self.client.post( + "/signup/", + data={ + "username": "test", + "password": "123456test", + "password2": "123456test", + "display": "test", + "email": "123@123.com", + }, + ) + user = User.objects.get(username="test") self.assertTrue(user) # 注册后登录 - r = self.client.post('/authenticate/', data={'username': 'test', 'password': '123456test'}, follow=False) + r = self.client.post( + "/authenticate/", + data={"username": "test", "password": "123456test"}, + follow=False, + ) r_json = r.json() - self.assertEqual(0, r_json['status']) + self.assertEqual(0, r_json["status"]) # 只允许初始化用户一次 mock_init.assert_called_once() class TestUser(TestCase): def setUp(self): - self.u1 = User(username='test_user', display='中文显示', is_active=True) - self.u1.set_password('test_password') + self.u1 = User(username="test_user", display="中文显示", is_active=True) + self.u1.set_password("test_password") self.u1.save() def tearDown(self): self.u1.delete() - @patch('common.auth.init_user') + @patch("common.auth.init_user") def testLogin(self, mock_init): """login 页面测试""" - r = self.client.get('/login/') + r = self.client.get("/login/") self.assertEqual(r.status_code, 200) - self.assertTemplateUsed(r, 'login.html') - r = self.client.post('/authenticate/', data={'username': 'test_user', 'password': 'test_password'}) + self.assertTemplateUsed(r, "login.html") + r = self.client.post( + "/authenticate/", + data={"username": "test_user", "password": "test_password"}, + ) r_json = r.json() - self.assertEqual(0, r_json['status']) + self.assertEqual(0, r_json["status"]) # 登录后直接跳首页 - r = self.client.get('/login/', follow=True) - self.assertRedirects(r, '/sqlworkflow/') + r = self.client.get("/login/", follow=True) + self.assertRedirects(r, "/sqlworkflow/") # init 只调用一次 mock_init.assert_called_once() @@ -401,18 +456,22 @@ class TestQueryPrivilegesCheck(TestCase): """测试权限校验""" def setUp(self): - self.superuser = User.objects.create(username='super', is_superuser=True) - self.user_can_query_all = User.objects.create(username='normaluser') - query_all_instance_perm = Permission.objects.get(codename='query_all_instances') + self.superuser = User.objects.create(username="super", is_superuser=True) + self.user_can_query_all = User.objects.create(username="normaluser") + query_all_instance_perm = Permission.objects.get(codename="query_all_instances") self.user_can_query_all.user_permissions.add(query_all_instance_perm) - self.user = User.objects.create(username='user') + self.user = User.objects.create(username="user") # 使用 travis.ci 时实例和测试service保持一致 - self.slave = Instance.objects.create(instance_name='test_instance', type='slave', db_type='mysql', - host=settings.DATABASES['default']['HOST'], - port=settings.DATABASES['default']['PORT'], - user=settings.DATABASES['default']['USER'], - password=settings.DATABASES['default']['PASSWORD']) - self.db_name = settings.DATABASES['default']['TEST']['NAME'] + self.slave = Instance.objects.create( + instance_name="test_instance", + type="slave", + db_type="mysql", + host=settings.DATABASES["default"]["HOST"], + port=settings.DATABASES["default"]["PORT"], + user=settings.DATABASES["default"]["USER"], + password=settings.DATABASES["default"]["PASSWORD"], + ) + self.db_name = settings.DATABASES["default"]["TEST"]["NAME"] self.sys_config = SysConfig() self.client = Client() @@ -428,9 +487,11 @@ def test_db_priv_super(self): 测试超级管理员验证数据库权限 :return: """ - self.sys_config.set('admin_query_limit', '50') + self.sys_config.set("admin_query_limit", "50") self.sys_config.get_all_config() - r = sql.query_privileges._db_priv(user=self.superuser, instance=self.slave, db_name=self.db_name) + r = sql.query_privileges._db_priv( + user=self.superuser, instance=self.slave, db_name=self.db_name + ) self.assertEqual(r, 50) def test_db_priv_user_priv_not_exist(self): @@ -438,7 +499,9 @@ def test_db_priv_user_priv_not_exist(self): 测试普通用户验证数据库权限,用户无权限 :return: """ - r = sql.query_privileges._db_priv(user=self.user, instance=self.slave, db_name=self.db_name) + r = sql.query_privileges._db_priv( + user=self.user, instance=self.slave, db_name=self.db_name + ) self.assertFalse(r) def test_db_priv_user_priv_exist(self): @@ -446,13 +509,17 @@ def test_db_priv_user_priv_exist(self): 测试普通用户验证数据库权限,用户有权限 :return: """ - QueryPrivileges.objects.create(user_name=self.user.username, - instance=self.slave, - db_name=self.db_name, - valid_date=date.today() + timedelta(days=1), - limit_num=10, - priv_type=1) - r = sql.query_privileges._db_priv(user=self.user, instance=self.slave, db_name=self.db_name) + QueryPrivileges.objects.create( + user_name=self.user.username, + instance=self.slave, + db_name=self.db_name, + valid_date=date.today() + timedelta(days=1), + limit_num=10, + priv_type=1, + ) + r = sql.query_privileges._db_priv( + user=self.user, instance=self.slave, db_name=self.db_name + ) self.assertTrue(r) def test_tb_priv_super(self): @@ -460,10 +527,14 @@ def test_tb_priv_super(self): 测试超级管理员验证表权限 :return: """ - self.sys_config.set('admin_query_limit', '50') + self.sys_config.set("admin_query_limit", "50") self.sys_config.get_all_config() - r = sql.query_privileges._tb_priv(user=self.superuser, instance=self.slave, db_name=self.db_name, - tb_name='table_name') + r = sql.query_privileges._tb_priv( + user=self.superuser, + instance=self.slave, + db_name=self.db_name, + tb_name="table_name", + ) self.assertEqual(r, 50) def test_tb_priv_user_priv_not_exist(self): @@ -471,8 +542,12 @@ def test_tb_priv_user_priv_not_exist(self): 测试普通用户验证表权限,用户无权限 :return: """ - r = sql.query_privileges._tb_priv(user=self.user, instance=self.slave, db_name=self.db_name, - tb_name='table_name') + r = sql.query_privileges._tb_priv( + user=self.user, + instance=self.slave, + db_name=self.db_name, + tb_name="table_name", + ) self.assertFalse(r) def test_tb_priv_user_priv_exist(self): @@ -480,29 +555,37 @@ def test_tb_priv_user_priv_exist(self): 测试普通用户验证表权限,用户有权限 :return: """ - QueryPrivileges.objects.create(user_name=self.user.username, - instance=self.slave, - db_name=self.db_name, - table_name='table_name', - valid_date=date.today() + timedelta(days=1), - limit_num=10, - priv_type=2) - r = sql.query_privileges._tb_priv(user=self.user, instance=self.slave, db_name=self.db_name, - tb_name='table_name') + QueryPrivileges.objects.create( + user_name=self.user.username, + instance=self.slave, + db_name=self.db_name, + table_name="table_name", + valid_date=date.today() + timedelta(days=1), + limit_num=10, + priv_type=2, + ) + r = sql.query_privileges._tb_priv( + user=self.user, + instance=self.slave, + db_name=self.db_name, + tb_name="table_name", + ) self.assertTrue(r) - @patch('sql.query_privileges._db_priv') + @patch("sql.query_privileges._db_priv") def test_priv_limit_from_db(self, __db_priv): """ 测试用户获取查询数量限制,通过库名获取 :return: """ __db_priv.return_value = 10 - r = sql.query_privileges._priv_limit(user=self.user, instance=self.slave, db_name=self.db_name) + r = sql.query_privileges._priv_limit( + user=self.user, instance=self.slave, db_name=self.db_name + ) self.assertEqual(r, 10) - @patch('sql.query_privileges._tb_priv') - @patch('sql.query_privileges._db_priv') + @patch("sql.query_privileges._tb_priv") + @patch("sql.query_privileges._db_priv") def test_priv_limit_from_tb(self, __db_priv, __tb_priv): """ 测试用户获取查询数量限制,通过表名获取 @@ -510,190 +593,286 @@ def test_priv_limit_from_tb(self, __db_priv, __tb_priv): """ __db_priv.return_value = 10 __tb_priv.return_value = 1 - r = sql.query_privileges._priv_limit(user=self.user, instance=self.slave, db_name=self.db_name, tb_name='test') + r = sql.query_privileges._priv_limit( + user=self.user, instance=self.slave, db_name=self.db_name, tb_name="test" + ) self.assertEqual(r, 1) - @patch('sql.engines.goinception.GoInceptionEngine.query_print') + @patch("sql.engines.goinception.GoInceptionEngine.query_print") def test_table_ref(self, _query_print): """ 测试通过goInception获取查询语句的table_ref :return: """ - _query_print.return_value = {'id': 2, 'statement': 'select * from sql_users limit 100', 'errlevel': 0, - 'query_tree': '{"text":"select * from sql_users limit 100","resultFields":null,"SQLCache":true,"CalcFoundRows":false,"StraightJoin":false,"Priority":0,"Distinct":false,"From":{"text":"","TableRefs":{"text":"","resultFields":null,"Left":{"text":"","Source":{"text":"","resultFields":null,"Schema":{"O":"","L":""},"Name":{"O":"sql_users","L":"sql_users"},"DBInfo":null,"TableInfo":null,"IndexHints":null},"AsName":{"O":"","L":""}},"Right":null,"Tp":0,"On":null,"Using":null,"NaturalJoin":false,"StraightJoin":false}},"Where":null,"Fields":{"text":"","Fields":[{"text":"","Offset":33,"WildCard":{"text":"","Table":{"O":"","L":""},"Schema":{"O":"","L":""}},"Expr":null,"AsName":{"O":"","L":""},"Auxiliary":false}]},"GroupBy":null,"Having":null,"OrderBy":null,"Limit":{"text":"","Count":{"text":"","k":2,"collation":0,"decimal":0,"length":0,"i":100,"b":null,"x":null,"Type":{"Tp":8,"Flag":160,"Flen":3,"Decimal":0,"Charset":"binary","Collate":"binary","Elems":null},"flag":0,"projectionOffset":-1},"Offset":null},"LockTp":0,"TableHints":null,"IsAfterUnionDistinct":false,"IsInBraces":false}', - 'errmsg': None} - r = sql.query_privileges._table_ref('select * from sql_users limit 100;', self.slave, self.db_name) - self.assertListEqual(r, [{'schema': 'test_archery', 'name': 'sql_users'}]) + _query_print.return_value = { + "id": 2, + "statement": "select * from sql_users limit 100", + "errlevel": 0, + "query_tree": '{"text":"select * from sql_users limit 100","resultFields":null,"SQLCache":true,"CalcFoundRows":false,"StraightJoin":false,"Priority":0,"Distinct":false,"From":{"text":"","TableRefs":{"text":"","resultFields":null,"Left":{"text":"","Source":{"text":"","resultFields":null,"Schema":{"O":"","L":""},"Name":{"O":"sql_users","L":"sql_users"},"DBInfo":null,"TableInfo":null,"IndexHints":null},"AsName":{"O":"","L":""}},"Right":null,"Tp":0,"On":null,"Using":null,"NaturalJoin":false,"StraightJoin":false}},"Where":null,"Fields":{"text":"","Fields":[{"text":"","Offset":33,"WildCard":{"text":"","Table":{"O":"","L":""},"Schema":{"O":"","L":""}},"Expr":null,"AsName":{"O":"","L":""},"Auxiliary":false}]},"GroupBy":null,"Having":null,"OrderBy":null,"Limit":{"text":"","Count":{"text":"","k":2,"collation":0,"decimal":0,"length":0,"i":100,"b":null,"x":null,"Type":{"Tp":8,"Flag":160,"Flen":3,"Decimal":0,"Charset":"binary","Collate":"binary","Elems":null},"flag":0,"projectionOffset":-1},"Offset":null},"LockTp":0,"TableHints":null,"IsAfterUnionDistinct":false,"IsInBraces":false}', + "errmsg": None, + } + r = sql.query_privileges._table_ref( + "select * from sql_users limit 100;", self.slave, self.db_name + ) + self.assertListEqual(r, [{"schema": "test_archery", "name": "sql_users"}]) - @patch('sql.engines.goinception.GoInceptionEngine.query_print') + @patch("sql.engines.goinception.GoInceptionEngine.query_print") def test_table_ref_wrong(self, _query_print): """ 测试通过goInception获取查询语句的table_ref :return: """ - _query_print.side_effect = RuntimeError('语法错误') + _query_print.side_effect = RuntimeError("语法错误") with self.assertRaises(RuntimeError): - sql.query_privileges._table_ref('select * from archery.sql_users;', self.slave, self.db_name) + sql.query_privileges._table_ref( + "select * from archery.sql_users;", self.slave, self.db_name + ) def test_query_priv_check_super(self): """ 测试用户权限校验,超级管理员不做校验,直接返回系统配置的limit :return: """ - r = sql.query_privileges.query_priv_check(user=self.superuser, - instance=self.slave, db_name=self.db_name, - sql_content="select * from archery.sql_users;", - limit_num=100) - self.assertDictEqual(r, {'status': 0, 'msg': 'ok', 'data': {'priv_check': True, 'limit_num': 100}}) - r = sql.query_privileges.query_priv_check(user=self.user_can_query_all, - instance=self.slave, db_name=self.db_name, - sql_content="select * from archery.sql_users;", - limit_num=100) - self.assertDictEqual(r, {'status': 0, 'msg': 'ok', 'data': {'priv_check': True, 'limit_num': 100}}) + r = sql.query_privileges.query_priv_check( + user=self.superuser, + instance=self.slave, + db_name=self.db_name, + sql_content="select * from archery.sql_users;", + limit_num=100, + ) + self.assertDictEqual( + r, + {"status": 0, "msg": "ok", "data": {"priv_check": True, "limit_num": 100}}, + ) + r = sql.query_privileges.query_priv_check( + user=self.user_can_query_all, + instance=self.slave, + db_name=self.db_name, + sql_content="select * from archery.sql_users;", + limit_num=100, + ) + self.assertDictEqual( + r, + {"status": 0, "msg": "ok", "data": {"priv_check": True, "limit_num": 100}}, + ) def test_query_priv_check_explain_or_show_create(self): """测试用户权限校验,explain和show create不做校验""" - r = sql.query_privileges.query_priv_check(user=self.user, - instance=self.slave, db_name=self.db_name, - sql_content="show create table archery.sql_users;", - limit_num=100) + r = sql.query_privileges.query_priv_check( + user=self.user, + instance=self.slave, + db_name=self.db_name, + sql_content="show create table archery.sql_users;", + limit_num=100, + ) self.assertTrue(r) - @patch('sql.query_privileges._table_ref', return_value=[{'schema': 'archery', 'name': 'sql_users'}]) - @patch('sql.query_privileges._tb_priv', return_value=False) - @patch('sql.query_privileges._db_priv', return_value=False) + @patch( + "sql.query_privileges._table_ref", + return_value=[{"schema": "archery", "name": "sql_users"}], + ) + @patch("sql.query_privileges._tb_priv", return_value=False) + @patch("sql.query_privileges._db_priv", return_value=False) def test_query_priv_check_no_priv(self, __db_priv, __tb_priv, __table_ref): """ 测试用户权限校验,mysql实例、普通用户 无库表权限,inception语法树正常打印 :return: """ - r = sql.query_privileges.query_priv_check(user=self.user, - instance=self.slave, db_name=self.db_name, - sql_content="select * from archery.sql_users;", - limit_num=100) - self.assertDictEqual(r, {'status': 2, 'msg': '你无archery.sql_users表的查询权限!请先到查询权限管理进行申请', - 'data': {'priv_check': True, 'limit_num': 0}}) + r = sql.query_privileges.query_priv_check( + user=self.user, + instance=self.slave, + db_name=self.db_name, + sql_content="select * from archery.sql_users;", + limit_num=100, + ) + self.assertDictEqual( + r, + { + "status": 2, + "msg": "你无archery.sql_users表的查询权限!请先到查询权限管理进行申请", + "data": {"priv_check": True, "limit_num": 0}, + }, + ) - @patch('sql.query_privileges._table_ref', return_value=[{'schema': 'archery', 'name': 'sql_users'}]) - @patch('sql.query_privileges._tb_priv', return_value=False) - @patch('sql.query_privileges._db_priv', return_value=1000) + @patch( + "sql.query_privileges._table_ref", + return_value=[{"schema": "archery", "name": "sql_users"}], + ) + @patch("sql.query_privileges._tb_priv", return_value=False) + @patch("sql.query_privileges._db_priv", return_value=1000) def test_query_priv_check_db_priv_exist(self, __db_priv, __tb_priv, __table_ref): """ 测试用户权限校验,mysql实例、普通用户 有库权限,inception语法树正常打印 :return: """ - r = sql.query_privileges.query_priv_check(user=self.user, - instance=self.slave, db_name=self.db_name, - sql_content="select * from archery.sql_users;", - limit_num=100) - self.assertDictEqual(r, {'data': {'limit_num': 100, 'priv_check': True}, 'msg': 'ok', 'status': 0}) + r = sql.query_privileges.query_priv_check( + user=self.user, + instance=self.slave, + db_name=self.db_name, + sql_content="select * from archery.sql_users;", + limit_num=100, + ) + self.assertDictEqual( + r, + {"data": {"limit_num": 100, "priv_check": True}, "msg": "ok", "status": 0}, + ) - @patch('sql.query_privileges._table_ref', return_value=[{'schema': 'archery', 'name': 'sql_users'}]) - @patch('sql.query_privileges._tb_priv', return_value=10) - @patch('sql.query_privileges._db_priv', return_value=False) + @patch( + "sql.query_privileges._table_ref", + return_value=[{"schema": "archery", "name": "sql_users"}], + ) + @patch("sql.query_privileges._tb_priv", return_value=10) + @patch("sql.query_privileges._db_priv", return_value=False) def test_query_priv_check_tb_priv_exist(self, __db_priv, __tb_priv, __table_ref): """ 测试用户权限校验,mysql实例、普通用户 ,有表权限,inception语法树正常打印 :return: """ - r = sql.query_privileges.query_priv_check(user=self.user, - instance=self.slave, db_name=self.db_name, - sql_content="select * from archery.sql_users;", - limit_num=100) - self.assertDictEqual(r, {'data': {'limit_num': 10, 'priv_check': True}, 'msg': 'ok', 'status': 0}) + r = sql.query_privileges.query_priv_check( + user=self.user, + instance=self.slave, + db_name=self.db_name, + sql_content="select * from archery.sql_users;", + limit_num=100, + ) + self.assertDictEqual( + r, {"data": {"limit_num": 10, "priv_check": True}, "msg": "ok", "status": 0} + ) - @patch('sql.query_privileges._table_ref') - @patch('sql.query_privileges._tb_priv', return_value=False) - @patch('sql.query_privileges._db_priv', return_value=False) - def test_query_priv_check_table_ref_Exception_and_no_db_priv(self, __db_priv, __tb_priv, __table_ref): + @patch("sql.query_privileges._table_ref") + @patch("sql.query_privileges._tb_priv", return_value=False) + @patch("sql.query_privileges._db_priv", return_value=False) + def test_query_priv_check_table_ref_Exception_and_no_db_priv( + self, __db_priv, __tb_priv, __table_ref + ): """ 测试用户权限校验,mysql实例、普通用户 ,inception语法树抛出异常 :return: """ - __table_ref.side_effect = RuntimeError('语法错误') + __table_ref.side_effect = RuntimeError("语法错误") self.sys_config.get_all_config() - r = sql.query_privileges.query_priv_check(user=self.user, - instance=self.slave, db_name=self.db_name, - sql_content="select * from archery.sql_users;", - limit_num=100) - self.assertDictEqual(r, {'status': 1, - 'msg': "无法校验查询语句权限,请联系管理员,错误信息:语法错误", - 'data': {'priv_check': True, 'limit_num': 0}}) - - @patch('sql.query_privileges._db_priv', return_value=1000) + r = sql.query_privileges.query_priv_check( + user=self.user, + instance=self.slave, + db_name=self.db_name, + sql_content="select * from archery.sql_users;", + limit_num=100, + ) + self.assertDictEqual( + r, + { + "status": 1, + "msg": "无法校验查询语句权限,请联系管理员,错误信息:语法错误", + "data": {"priv_check": True, "limit_num": 0}, + }, + ) + + @patch("sql.query_privileges._db_priv", return_value=1000) def test_query_priv_check_not_mysql_db_priv_exist(self, __db_priv): """ 测试用户权限校验,非mysql实例、普通用户 有库权限 :return: """ - mssql_instance = Instance(instance_name='mssql', type='slave', db_type='mssql', - host='some_host', port=3306, user='some_user', password='some_str') - r = sql.query_privileges.query_priv_check(user=self.user, - instance=mssql_instance, db_name=self.db_name, - sql_content="select * from archery.sql_users;", - limit_num=100) - self.assertDictEqual(r, {'data': {'limit_num': 100, 'priv_check': True}, 'msg': 'ok', 'status': 0}) + mssql_instance = Instance( + instance_name="mssql", + type="slave", + db_type="mssql", + host="some_host", + port=3306, + user="some_user", + password="some_str", + ) + r = sql.query_privileges.query_priv_check( + user=self.user, + instance=mssql_instance, + db_name=self.db_name, + sql_content="select * from archery.sql_users;", + limit_num=100, + ) + self.assertDictEqual( + r, + {"data": {"limit_num": 100, "priv_check": True}, "msg": "ok", "status": 0}, + ) - @patch('sql.query_privileges._db_priv', return_value=False) + @patch("sql.query_privileges._db_priv", return_value=False) def test_query_priv_check_not_mysql_db_priv_not_exist(self, __db_priv): """ 测试用户权限校验,非mysql实例、普通用户 无库权限 :return: """ - mssql_instance = Instance(instance_name='mssql', type='slave', db_type='oracle', - host='some_host', port=3306, user='some_user', password='some_str') - r = sql.query_privileges.query_priv_check(user=self.user, - instance=mssql_instance, db_name=self.db_name, - sql_content="select * from archery.sql_users;", - limit_num=100) - self.assertDictEqual(r, {'data': {'limit_num': 0, 'priv_check': True}, - 'msg': '你无archery数据库的查询权限!请先到查询权限管理进行申请', - 'status': 2}) + mssql_instance = Instance( + instance_name="mssql", + type="slave", + db_type="oracle", + host="some_host", + port=3306, + user="some_user", + password="some_str", + ) + r = sql.query_privileges.query_priv_check( + user=self.user, + instance=mssql_instance, + db_name=self.db_name, + sql_content="select * from archery.sql_users;", + limit_num=100, + ) + self.assertDictEqual( + r, + { + "data": {"limit_num": 0, "priv_check": True}, + "msg": "你无archery数据库的查询权限!请先到查询权限管理进行申请", + "status": 2, + }, + ) class TestQueryPrivilegesApply(TestCase): """测试权限列表、权限管理""" def setUp(self): - self.superuser = User.objects.create(username='super', is_superuser=True) - self.user = User.objects.create(username='user') + self.superuser = User.objects.create(username="super", is_superuser=True) + self.user = User.objects.create(username="user") # 使用 travis.ci 时实例和测试service保持一致 - self.slave = Instance.objects.create(instance_name='test_instance', type='slave', db_type='mysql', - host=settings.DATABASES['default']['HOST'], - port=settings.DATABASES['default']['PORT'], - user=settings.DATABASES['default']['USER'], - password=settings.DATABASES['default']['PASSWORD']) - self.db_name = settings.DATABASES['default']['TEST']['NAME'] + self.slave = Instance.objects.create( + instance_name="test_instance", + type="slave", + db_type="mysql", + host=settings.DATABASES["default"]["HOST"], + port=settings.DATABASES["default"]["PORT"], + user=settings.DATABASES["default"]["USER"], + password=settings.DATABASES["default"]["PASSWORD"], + ) + self.db_name = settings.DATABASES["default"]["TEST"]["NAME"] self.sys_config = SysConfig() self.client = Client() tomorrow = datetime.today() + timedelta(days=1) - self.group = ResourceGroup.objects.create(group_id=1, group_name='group_name') + self.group = ResourceGroup.objects.create(group_id=1, group_name="group_name") self.query_apply_1 = QueryPrivilegesApply.objects.create( group_id=self.group.group_id, group_name=self.group.group_name, - title='some_title1', - user_name='some_user', + title="some_title1", + user_name="some_user", instance=self.slave, - db_list='some_db,some_db2', + db_list="some_db,some_db2", limit_num=100, valid_date=tomorrow, priv_type=1, status=0, - audit_auth_groups='some_audit_group' + audit_auth_groups="some_audit_group", ) self.query_apply_2 = QueryPrivilegesApply.objects.create( group_id=2, - group_name='some_group2', - title='some_title2', - user_name='some_user', + group_name="some_group2", + title="some_title2", + user_name="some_user", instance=self.slave, - db_list='some_db', - table_list='some_table,some_tb2', + db_list="some_db", + table_list="some_table,some_tb2", limit_num=100, valid_date=tomorrow, priv_type=2, status=0, - audit_auth_groups='some_audit_group' + audit_auth_groups="some_audit_group", ) def tearDown(self): @@ -708,189 +887,263 @@ def tearDown(self): def test_query_audit_call_back(self): """测试权限申请工单回调""" # 工单状态改为审核失败, 验证工单状态 - sql.query_privileges._query_apply_audit_call_back(self.query_apply_1.apply_id, 2) + sql.query_privileges._query_apply_audit_call_back( + self.query_apply_1.apply_id, 2 + ) self.query_apply_1.refresh_from_db() self.assertEqual(self.query_apply_1.status, 2) - for db in self.query_apply_1.db_list.split(','): - self.assertEqual(len(QueryPrivileges.objects.filter( - user_name=self.query_apply_1.user_name, - db_name=db, - limit_num=100)), 0) + for db in self.query_apply_1.db_list.split(","): + self.assertEqual( + len( + QueryPrivileges.objects.filter( + user_name=self.query_apply_1.user_name, + db_name=db, + limit_num=100, + ) + ), + 0, + ) # 工单改为审核成功, 验证工单状态和权限状态 - sql.query_privileges._query_apply_audit_call_back(self.query_apply_1.apply_id, 1) + sql.query_privileges._query_apply_audit_call_back( + self.query_apply_1.apply_id, 1 + ) self.query_apply_1.refresh_from_db() self.assertEqual(self.query_apply_1.status, 1) - for db in self.query_apply_1.db_list.split(','): - self.assertEqual(len(QueryPrivileges.objects.filter( - user_name=self.query_apply_1.user_name, - db_name=db, - limit_num=100)), 1) + for db in self.query_apply_1.db_list.split(","): + self.assertEqual( + len( + QueryPrivileges.objects.filter( + user_name=self.query_apply_1.user_name, + db_name=db, + limit_num=100, + ) + ), + 1, + ) # 表权限申请测试, 只测试审核成功 - sql.query_privileges._query_apply_audit_call_back(self.query_apply_2.apply_id, 1) + sql.query_privileges._query_apply_audit_call_back( + self.query_apply_2.apply_id, 1 + ) self.query_apply_2.refresh_from_db() self.assertEqual(self.query_apply_2.status, 1) - for tb in self.query_apply_2.table_list.split(','): - self.assertEqual(len(QueryPrivileges.objects.filter( - user_name=self.query_apply_2.user_name, - db_name=self.query_apply_2.db_list, - table_name=tb, - limit_num=self.query_apply_2.limit_num)), 1) + for tb in self.query_apply_2.table_list.split(","): + self.assertEqual( + len( + QueryPrivileges.objects.filter( + user_name=self.query_apply_2.user_name, + db_name=self.query_apply_2.db_list, + table_name=tb, + limit_num=self.query_apply_2.limit_num, + ) + ), + 1, + ) def test_query_priv_apply_list_super_with_search(self): """ 测试权限申请列表,管理员查看所有用户,并且搜索 """ - data = { - "limit": 14, - "offset": 0, - "search": 'some_title1' - } + data = {"limit": 14, "offset": 0, "search": "some_title1"} self.client.force_login(self.superuser) - r = self.client.post(path='/query/applylist/', data=data) - self.assertEqual(json.loads(r.content)['total'], 1) - keys = list(json.loads(r.content)['rows'][0].keys()) - self.assertListEqual(keys, - ['apply_id', 'title', 'instance__instance_name', 'db_list', 'priv_type', 'table_list', - 'limit_num', 'valid_date', 'user_display', 'status', 'create_time', 'group_name']) + r = self.client.post(path="/query/applylist/", data=data) + self.assertEqual(json.loads(r.content)["total"], 1) + keys = list(json.loads(r.content)["rows"][0].keys()) + self.assertListEqual( + keys, + [ + "apply_id", + "title", + "instance__instance_name", + "db_list", + "priv_type", + "table_list", + "limit_num", + "valid_date", + "user_display", + "status", + "create_time", + "group_name", + ], + ) def test_query_priv_apply_list_with_query_review_perm(self): """ 测试权限申请列表,普通用户,拥有sql.query_review权限,在组内 """ - data = { - "limit": 14, - "offset": 0, - "search": '' - } + data = {"limit": 14, "offset": 0, "search": ""} - menu_queryapplylist = Permission.objects.get(codename='menu_queryapplylist') + menu_queryapplylist = Permission.objects.get(codename="menu_queryapplylist") self.user.user_permissions.add(menu_queryapplylist) - query_review = Permission.objects.get(codename='query_review') + query_review = Permission.objects.get(codename="query_review") self.user.user_permissions.add(query_review) self.user.resource_group.add(self.group) self.client.force_login(self.user) - r = self.client.post(path='/query/applylist/', data=data) - self.assertEqual(json.loads(r.content)['total'], 1) - keys = list(json.loads(r.content)['rows'][0].keys()) - self.assertListEqual(keys, - ['apply_id', 'title', 'instance__instance_name', 'db_list', 'priv_type', 'table_list', - 'limit_num', 'valid_date', 'user_display', 'status', 'create_time', 'group_name']) + r = self.client.post(path="/query/applylist/", data=data) + self.assertEqual(json.loads(r.content)["total"], 1) + keys = list(json.loads(r.content)["rows"][0].keys()) + self.assertListEqual( + keys, + [ + "apply_id", + "title", + "instance__instance_name", + "db_list", + "priv_type", + "table_list", + "limit_num", + "valid_date", + "user_display", + "status", + "create_time", + "group_name", + ], + ) def test_query_priv_apply_list_no_query_review_perm(self): """ 测试权限申请列表,普通用户,无sql.query_review权限,在组内 """ - data = { - "limit": 14, - "offset": 0, - "search": '' - } + data = {"limit": 14, "offset": 0, "search": ""} - menu_queryapplylist = Permission.objects.get(codename='menu_queryapplylist') + menu_queryapplylist = Permission.objects.get(codename="menu_queryapplylist") self.user.user_permissions.add(menu_queryapplylist) self.user.resource_group.add(self.group) self.client.force_login(self.user) - r = self.client.post(path='/query/applylist/', data=data) + r = self.client.post(path="/query/applylist/", data=data) self.assertEqual(json.loads(r.content), {"total": 0, "rows": []}) def test_user_query_priv_with_search(self): """ 测试权限申请列表,管理员查看所有用户,并且搜索 """ - data = { - "limit": 14, - "offset": 0, - "search": 'user' - } - QueryPrivileges.objects.create(user_name=self.user.username, - user_display='user2', - instance=self.slave, - db_name=self.db_name, - table_name='table_name', - valid_date=date.today() + timedelta(days=1), - limit_num=10, - priv_type=2) + data = {"limit": 14, "offset": 0, "search": "user"} + QueryPrivileges.objects.create( + user_name=self.user.username, + user_display="user2", + instance=self.slave, + db_name=self.db_name, + table_name="table_name", + valid_date=date.today() + timedelta(days=1), + limit_num=10, + priv_type=2, + ) self.client.force_login(self.superuser) - r = self.client.post(path='/query/userprivileges/', data=data) - self.assertEqual(json.loads(r.content)['total'], 1) - keys = list(json.loads(r.content)['rows'][0].keys()) - self.assertListEqual(keys, - ['privilege_id', 'user_display', 'instance__instance_name', 'db_name', 'priv_type', - 'table_name', 'limit_num', 'valid_date']) + r = self.client.post(path="/query/userprivileges/", data=data) + self.assertEqual(json.loads(r.content)["total"], 1) + keys = list(json.loads(r.content)["rows"][0].keys()) + self.assertListEqual( + keys, + [ + "privilege_id", + "user_display", + "instance__instance_name", + "db_name", + "priv_type", + "table_name", + "limit_num", + "valid_date", + ], + ) def test_user_query_priv_with_query_mgtpriv(self): """ 测试权限申请列表,普通用户,拥有sql.query_mgtpriv权限,在组内 """ - data = { - "limit": 14, - "offset": 0, - "search": 'user' - } - QueryPrivileges.objects.create(user_name='some_name', - user_display='user2', - instance=self.slave, - db_name=self.db_name, - table_name='table_name', - valid_date=date.today() + timedelta(days=1), - limit_num=10, - priv_type=2) - menu_queryapplylist = Permission.objects.get(codename='menu_queryapplylist') + data = {"limit": 14, "offset": 0, "search": "user"} + QueryPrivileges.objects.create( + user_name="some_name", + user_display="user2", + instance=self.slave, + db_name=self.db_name, + table_name="table_name", + valid_date=date.today() + timedelta(days=1), + limit_num=10, + priv_type=2, + ) + menu_queryapplylist = Permission.objects.get(codename="menu_queryapplylist") self.user.user_permissions.add(menu_queryapplylist) - query_mgtpriv = Permission.objects.get(codename='query_mgtpriv') + query_mgtpriv = Permission.objects.get(codename="query_mgtpriv") self.user.user_permissions.add(query_mgtpriv) self.user.resource_group.add(self.group) self.client.force_login(self.user) - r = self.client.post(path='/query/userprivileges/', data=data) - self.assertEqual(json.loads(r.content)['total'], 1) - keys = list(json.loads(r.content)['rows'][0].keys()) - self.assertListEqual(keys, - ['privilege_id', 'user_display', 'instance__instance_name', 'db_name', 'priv_type', - 'table_name', 'limit_num', 'valid_date']) + r = self.client.post(path="/query/userprivileges/", data=data) + self.assertEqual(json.loads(r.content)["total"], 1) + keys = list(json.loads(r.content)["rows"][0].keys()) + self.assertListEqual( + keys, + [ + "privilege_id", + "user_display", + "instance__instance_name", + "db_name", + "priv_type", + "table_name", + "limit_num", + "valid_date", + ], + ) def test_user_query_priv_no_query_mgtpriv(self): """ 测试权限申请列表,普通用户,没有sql.query_mgtpriv权限,在组内 """ - data = { - "limit": 14, - "offset": 0, - "search": 'user' - } - QueryPrivileges.objects.create(user_name='some_name', - user_display='user2', - instance=self.slave, - db_name=self.db_name, - table_name='table_name', - valid_date=date.today() + timedelta(days=1), - limit_num=10, - priv_type=2) - menu_queryapplylist = Permission.objects.get(codename='menu_queryapplylist') + data = {"limit": 14, "offset": 0, "search": "user"} + QueryPrivileges.objects.create( + user_name="some_name", + user_display="user2", + instance=self.slave, + db_name=self.db_name, + table_name="table_name", + valid_date=date.today() + timedelta(days=1), + limit_num=10, + priv_type=2, + ) + menu_queryapplylist = Permission.objects.get(codename="menu_queryapplylist") self.user.user_permissions.add(menu_queryapplylist) self.user.resource_group.add(self.group) self.client.force_login(self.user) - r = self.client.post(path='/query/userprivileges/', data=data) + r = self.client.post(path="/query/userprivileges/", data=data) self.assertEqual(json.loads(r.content), {"total": 0, "rows": []}) class TestQuery(TransactionTestCase): def setUp(self): - self.slave1 = Instance(instance_name='test_slave_instance', type='slave', db_type='mysql', - host='testhost', port=3306, user='mysql_user', password='mysql_password') - self.slave2 = Instance(instance_name='test_instance_non_mysql', type='slave', db_type='mssql', - host='some_host2', port=3306, user='some_user', password='some_str') + self.slave1 = Instance( + instance_name="test_slave_instance", + type="slave", + db_type="mysql", + host="testhost", + port=3306, + user="mysql_user", + password="mysql_password", + ) + self.slave2 = Instance( + instance_name="test_instance_non_mysql", + type="slave", + db_type="mssql", + host="some_host2", + port=3306, + user="some_user", + password="some_str", + ) self.slave1.save() self.slave2.save() - self.superuser1 = User.objects.create(username='super1', is_superuser=True) - self.u1 = User.objects.create(username='test_user', display='中文显示', is_active=True) - self.u2 = User.objects.create(username='test_user2', display='中文显示', is_active=True) - self.query_log = QueryLog.objects.create(instance_name=self.slave1.instance_name, - db_name='some_db', - sqllog='select 1;', - effect_row=10, - cost_time=1, - username=self.superuser1.username) - sql_query_perm = Permission.objects.get(codename='query_submit') + self.superuser1 = User.objects.create(username="super1", is_superuser=True) + self.u1 = User.objects.create( + username="test_user", display="中文显示", is_active=True + ) + self.u2 = User.objects.create( + username="test_user2", display="中文显示", is_active=True + ) + self.query_log = QueryLog.objects.create( + instance_name=self.slave1.instance_name, + db_name="some_db", + sqllog="select 1;", + effect_row=10, + cost_time=1, + username=self.superuser1.username, + ) + sql_query_perm = Permission.objects.get(codename="query_submit") self.u2.user_permissions.add(sql_query_perm) def tearDown(self): @@ -902,90 +1155,139 @@ def tearDown(self): self.slave1.delete() self.slave2.delete() archer_config = SysConfig() - archer_config.set('disable_star', False) + archer_config.set("disable_star", False) - @patch('sql.query.user_instances') - @patch('sql.query.get_engine') - @patch('sql.query.query_priv_check') + @patch("sql.query.user_instances") + @patch("sql.query.get_engine") + @patch("sql.query.query_priv_check") def testCorrectSQL(self, _priv_check, _get_engine, _user_instances): c = Client() - some_sql = 'select some from some_table limit 100;' - some_db = 'some_db' + some_sql = "select some from some_table limit 100;" + some_db = "some_db" some_limit = 100 c.force_login(self.u1) - r = c.post('/query/', data={'instance_name': self.slave1.instance_name, - 'sql_content': some_sql, - 'db_name': some_db, - 'limit_num': some_limit}) + r = c.post( + "/query/", + data={ + "instance_name": self.slave1.instance_name, + "sql_content": some_sql, + "db_name": some_db, + "limit_num": some_limit, + }, + ) self.assertEqual(r.status_code, 403) c.force_login(self.u2) - q_result = ResultSet(full_sql=some_sql, rows=['value']) - q_result.column_list = ['some'] + q_result = ResultSet(full_sql=some_sql, rows=["value"]) + q_result.column_list = ["some"] _get_engine.return_value.query_check.return_value = { - 'msg': '', 'bad_query': False, 'filtered_sql': some_sql, 'has_star': False} + "msg": "", + "bad_query": False, + "filtered_sql": some_sql, + "has_star": False, + } _get_engine.return_value.filter_sql.return_value = some_sql _get_engine.return_value.query.return_value = q_result _get_engine.return_value.seconds_behind_master = 100 - _priv_check.return_value = {'status': 0, 'data': {'limit_num': 100, 'priv_check': True}} + _priv_check.return_value = { + "status": 0, + "data": {"limit_num": 100, "priv_check": True}, + } _user_instances.return_value.get.return_value = self.slave1 - r = c.post('/query/', data={'instance_name': self.slave1.instance_name, - 'sql_content': some_sql, - 'db_name': some_db, - 'limit_num': some_limit}) + r = c.post( + "/query/", + data={ + "instance_name": self.slave1.instance_name, + "sql_content": some_sql, + "db_name": some_db, + "limit_num": some_limit, + }, + ) _get_engine.return_value.query.assert_called_once_with( - some_db, some_sql, some_limit, schema_name=None, tb_name=None, max_execution_time=60000) + some_db, + some_sql, + some_limit, + schema_name=None, + tb_name=None, + max_execution_time=60000, + ) r_json = r.json() - self.assertEqual(r_json['data']['rows'], ['value']) - self.assertEqual(r_json['data']['column_list'], ['some']) - self.assertEqual(r_json['data']['seconds_behind_master'], 100) + self.assertEqual(r_json["data"]["rows"], ["value"]) + self.assertEqual(r_json["data"]["column_list"], ["some"]) + self.assertEqual(r_json["data"]["seconds_behind_master"], 100) - @patch('sql.query.user_instances') - @patch('sql.query.get_engine') - @patch('sql.query.query_priv_check') + @patch("sql.query.user_instances") + @patch("sql.query.get_engine") + @patch("sql.query.query_priv_check") def testSQLWithoutLimit(self, _priv_check, _get_engine, _user_instances): c = Client() some_limit = 100 - sql_without_limit = 'select some from some_table' - sql_with_limit = 'select some from some_table limit {0};'.format(some_limit) - some_db = 'some_db' + sql_without_limit = "select some from some_table" + sql_with_limit = "select some from some_table limit {0};".format(some_limit) + some_db = "some_db" c.force_login(self.u2) - q_result = ResultSet(full_sql=sql_without_limit, rows=['value']) - q_result.column_list = ['some'] + q_result = ResultSet(full_sql=sql_without_limit, rows=["value"]) + q_result.column_list = ["some"] _get_engine.return_value.query_check.return_value = { - 'msg': '', 'bad_query': False, 'filtered_sql': sql_without_limit, 'has_star': False} + "msg": "", + "bad_query": False, + "filtered_sql": sql_without_limit, + "has_star": False, + } _get_engine.return_value.filter_sql.return_value = sql_with_limit _get_engine.return_value.query.return_value = q_result - _priv_check.return_value = {'status': 0, 'data': {'limit_num': 100, 'priv_check': True}} + _priv_check.return_value = { + "status": 0, + "data": {"limit_num": 100, "priv_check": True}, + } _user_instances.return_value.get.return_value = self.slave1 - r = c.post('/query/', data={'instance_name': self.slave1.instance_name, - 'sql_content': sql_without_limit, - 'db_name': some_db, - 'limit_num': some_limit}) + r = c.post( + "/query/", + data={ + "instance_name": self.slave1.instance_name, + "sql_content": sql_without_limit, + "db_name": some_db, + "limit_num": some_limit, + }, + ) _get_engine.return_value.query.assert_called_once_with( - some_db, sql_with_limit, some_limit, schema_name=None, tb_name=None, max_execution_time=60000) + some_db, + sql_with_limit, + some_limit, + schema_name=None, + tb_name=None, + max_execution_time=60000, + ) r_json = r.json() - self.assertEqual(r_json['data']['rows'], ['value']) - self.assertEqual(r_json['data']['column_list'], ['some']) + self.assertEqual(r_json["data"]["rows"], ["value"]) + self.assertEqual(r_json["data"]["column_list"], ["some"]) - @patch('sql.query.query_priv_check') + @patch("sql.query.query_priv_check") def testStarOptionOn(self, _priv_check): c = Client() c.force_login(self.u2) some_limit = 100 - sql_with_star = 'select * from some_table' - some_db = 'some_db' - _priv_check.return_value = {'status': 0, 'data': {'limit_num': 100, 'priv_check': True}} + sql_with_star = "select * from some_table" + some_db = "some_db" + _priv_check.return_value = { + "status": 0, + "data": {"limit_num": 100, "priv_check": True}, + } archer_config = SysConfig() - archer_config.set('disable_star', True) - r = c.post('/query/', data={'instance_name': self.slave1.instance_name, - 'sql_content': sql_with_star, - 'db_name': some_db, - 'limit_num': some_limit}) - archer_config.set('disable_star', False) + archer_config.set("disable_star", True) + r = c.post( + "/query/", + data={ + "instance_name": self.slave1.instance_name, + "sql_content": sql_with_star, + "db_name": some_db, + "limit_num": some_limit, + }, + ) + archer_config.set("disable_star", False) r_json = r.json() - self.assertEqual(1, r_json['status']) + self.assertEqual(1, r_json["status"]) - @patch('sql.query.get_engine') + @patch("sql.query.get_engine") def test_kill_query_conn(self, _get_engine): kill_query_conn(self.slave1.id, 10) _get_engine.return_value.kill_connection.return_value = ResultSet() @@ -994,110 +1296,121 @@ def test_query_log(self): """测试获取查询历史""" c = Client() c.force_login(self.superuser1) - QueryLog(id=self.query_log.id, favorite=True, alias='test_a').save(update_fields=['favorite', 'alias']) - data = {"star": "true", - "query_log_id": self.query_log.id, - "limit": 14, - "offset": 0, } - r = c.get('/query/querylog/', data=data) - self.assertEqual(r.json()['total'], 1) + QueryLog(id=self.query_log.id, favorite=True, alias="test_a").save( + update_fields=["favorite", "alias"] + ) + data = { + "star": "true", + "query_log_id": self.query_log.id, + "limit": 14, + "offset": 0, + } + r = c.get("/query/querylog/", data=data) + self.assertEqual(r.json()["total"], 1) def test_star(self): """测试查询语句收藏""" c = Client() c.force_login(self.superuser1) - r = c.post('/query/favorite/', data={'query_log_id': self.query_log.id, - 'star': 'true', - 'alias': 'test_alias'}) + r = c.post( + "/query/favorite/", + data={ + "query_log_id": self.query_log.id, + "star": "true", + "alias": "test_alias", + }, + ) query_log = QueryLog.objects.get(id=self.query_log.id) self.assertTrue(query_log.favorite) - self.assertEqual(query_log.alias, 'test_alias') + self.assertEqual(query_log.alias, "test_alias") def test_un_star(self): """测试查询语句取消收藏""" c = Client() c.force_login(self.superuser1) - r = c.post('/query/favorite/', data={'query_log_id': self.query_log.id, - 'star': 'false', - 'alias': ''}) + r = c.post( + "/query/favorite/", + data={"query_log_id": self.query_log.id, "star": "false", "alias": ""}, + ) r_json = r.json() query_log = QueryLog.objects.get(id=self.query_log.id) self.assertFalse(query_log.favorite) - self.assertEqual(query_log.alias, '') + self.assertEqual(query_log.alias, "") class TestWorkflowView(TransactionTestCase): - def setUp(self): self.now = datetime.now() - can_view_permission = Permission.objects.get(codename='menu_sqlworkflow') - can_execute_permission = Permission.objects.get(codename='sql_execute') - can_execute_resource_permission = Permission.objects.get(codename='sql_execute_for_resource_group') - self.u1 = User(username='some_user', display='用户1') + can_view_permission = Permission.objects.get(codename="menu_sqlworkflow") + can_execute_permission = Permission.objects.get(codename="sql_execute") + can_execute_resource_permission = Permission.objects.get( + codename="sql_execute_for_resource_group" + ) + self.u1 = User(username="some_user", display="用户1") self.u1.save() self.u1.user_permissions.add(can_view_permission) - self.u2 = User(username='some_user2', display='用户2') + self.u2 = User(username="some_user2", display="用户2") self.u2.save() self.u2.user_permissions.add(can_view_permission) - self.u3 = User(username='some_user3', display='用户3') + self.u3 = User(username="some_user3", display="用户3") self.u3.save() self.u3.user_permissions.add(can_view_permission) - self.executor1 = User(username='some_executor', display='执行者') + self.executor1 = User(username="some_executor", display="执行者") self.executor1.save() - self.executor1.user_permissions.add(can_view_permission, can_execute_permission, - can_execute_resource_permission) - self.superuser1 = User(username='super1', is_superuser=True) + self.executor1.user_permissions.add( + can_view_permission, can_execute_permission, can_execute_resource_permission + ) + self.superuser1 = User(username="super1", is_superuser=True) self.superuser1.save() - self.master1 = Instance(instance_name='test_master_instance', type='master', db_type='mysql', - host='testhost', port=3306, user='mysql_user', password='mysql_password') + self.master1 = Instance( + instance_name="test_master_instance", + type="master", + db_type="mysql", + host="testhost", + port=3306, + user="mysql_user", + password="mysql_password", + ) self.master1.save() self.wf1 = SqlWorkflow.objects.create( - workflow_name='some_name', + workflow_name="some_name", group_id=1, - group_name='g1', + group_name="g1", engineer=self.u1.username, engineer_display=self.u1.display, - audit_auth_groups='some_group', + audit_auth_groups="some_group", create_time=self.now - timedelta(days=1), - status='workflow_finish', + status="workflow_finish", is_backup=True, instance=self.master1, - db_name='some_db', + db_name="some_db", syntax_type=1, ) self.wfc1 = SqlWorkflowContent.objects.create( workflow=self.wf1, - sql_content='some_sql', - execute_result=json.dumps([{ - 'id': 1, - 'sql': 'some_content' - }]) + sql_content="some_sql", + execute_result=json.dumps([{"id": 1, "sql": "some_content"}]), ) self.wf2 = SqlWorkflow.objects.create( - workflow_name='some_name2', + workflow_name="some_name2", group_id=1, - group_name='g1', + group_name="g1", engineer=self.u2.username, engineer_display=self.u2.display, - audit_auth_groups='some_group', + audit_auth_groups="some_group", create_time=self.now - timedelta(days=1), - status='workflow_manreviewing', + status="workflow_manreviewing", is_backup=True, instance=self.master1, - db_name='some_db', - syntax_type=1 + db_name="some_db", + syntax_type=1, ) self.wfc2 = SqlWorkflowContent.objects.create( workflow=self.wf2, - sql_content='some_sql', - execute_result=json.dumps([{ - 'id': 1, - 'sql': 'some_content' - }]) - ) - self.resource_group1 = ResourceGroup( - group_name='some_group' + sql_content="some_sql", + execute_result=json.dumps([{"id": 1, "sql": "some_content"}]), ) + self.resource_group1 = ResourceGroup(group_name="some_group") self.resource_group1.save() def tearDown(self): @@ -1113,20 +1426,21 @@ def testWorkflowStatus(self): """测试获取工单状态""" c = Client(header={}) c.force_login(self.u1) - r = c.post('/getWorkflowStatus/', {'workflow_id': self.wf1.id}) + r = c.post("/getWorkflowStatus/", {"workflow_id": self.wf1.id}) r_json = r.json() - self.assertEqual(r_json['status'], 'workflow_finish') + self.assertEqual(r_json["status"], "workflow_finish") def test_check_param_is_None(self): """测试工单检测,参数内容为空""" c = Client() c.force_login(self.superuser1) data = {"instance_name": self.master1.instance_name} - r = c.post('/simplecheck/', data=data) - self.assertDictEqual(json.loads(r.content), - {'status': 1, 'msg': '页面提交参数可能为空', 'data': {}}) + r = c.post("/simplecheck/", data=data) + self.assertDictEqual( + json.loads(r.content), {"status": 1, "msg": "页面提交参数可能为空", "data": {}} + ) - @patch('sql.sql_workflow.get_engine') + @patch("sql.sql_workflow.get_engine") def test_check_inception_Exception(self, _get_engine): """测试工单检测,inception报错""" c = Client() @@ -1136,14 +1450,13 @@ def test_check_inception_Exception(self, _get_engine): "instance_name": self.master1.instance_name, "db_name": "archery", } - _get_engine.side_effect = RuntimeError('RuntimeError') - r = c.post('/simplecheck/', data=data) - self.assertDictEqual(json.loads(r.content), - {'status': 1, - 'msg': "RuntimeError", - 'data': {}}) - - @patch('sql.sql_workflow.get_engine') + _get_engine.side_effect = RuntimeError("RuntimeError") + r = c.post("/simplecheck/", data=data) + self.assertDictEqual( + json.loads(r.content), {"status": 1, "msg": "RuntimeError", "data": {}} + ) + + @patch("sql.sql_workflow.get_engine") def test_check(self, _get_engine): """测试工单检测,正常返回""" c = Client() @@ -1154,378 +1467,459 @@ def test_check(self, _get_engine): "db_name": "archery", } column_list = [ - 'id', 'stage', 'errlevel', 'stagestatus', 'errormessage', 'sql', 'affected_rows', 'sequence', - 'backup_dbname', 'execute_time', 'sqlsha1', 'backup_time', 'actual_affected_rows'] - - rows = [ReviewResult(id=1, - stage='CHECKED', - errlevel=0, - stagestatus='Audit Completed', - errormessage='', - sql='use `archer`', - affected_rows=0, - actual_affected_rows=0, - sequence='0_0_00000000', - backup_dbname='', - execute_time='0', - sqlsha1='')] + "id", + "stage", + "errlevel", + "stagestatus", + "errormessage", + "sql", + "affected_rows", + "sequence", + "backup_dbname", + "execute_time", + "sqlsha1", + "backup_time", + "actual_affected_rows", + ] + + rows = [ + ReviewResult( + id=1, + stage="CHECKED", + errlevel=0, + stagestatus="Audit Completed", + errormessage="", + sql="use `archer`", + affected_rows=0, + actual_affected_rows=0, + sequence="0_0_00000000", + backup_dbname="", + execute_time="0", + sqlsha1="", + ) + ] _get_engine.return_value.execute_check.return_value = ReviewSet( - warning_count=0, - error_count=0, - column_list=column_list, - rows=rows) - r = c.post('/simplecheck/', data=data) - self.assertListEqual(list(json.loads(r.content)['data'].keys()), - ["rows", "CheckWarningCount", "CheckErrorCount"]) - self.assertListEqual(list(json.loads(r.content)['data']['rows'][0].keys()), column_list) + warning_count=0, error_count=0, column_list=column_list, rows=rows + ) + r = c.post("/simplecheck/", data=data) + self.assertListEqual( + list(json.loads(r.content)["data"].keys()), + ["rows", "CheckWarningCount", "CheckErrorCount"], + ) + self.assertListEqual( + list(json.loads(r.content)["data"]["rows"][0].keys()), column_list + ) def test_submit_param_is_None(self): """测试SQL提交,参数内容为空""" c = Client() c.force_login(self.superuser1) - data = {"sql_content": "update sql_users set email='' where id>0;", - "workflow_name": "【回滚工单】原工单Id:163+,3434434343", - "group_name": self.resource_group1.group_name, - "instance_name": self.master1.instance_name, - "run_date_start": "", - "run_date_end": "", - "workflow_auditors": "11"} - r = c.post('/autoreview/', data=data) - self.assertContains(r, '页面提交参数可能为空') - - @patch('sql.sql_workflow.async_task') - @patch('sql.sql_workflow.Audit') - @patch('sql.sql_workflow.get_engine') - @patch('sql.sql_workflow.user_instances') - def test_submit_audit_wrong(self, _user_instances, _get_engine, _audit, _async_task): + data = { + "sql_content": "update sql_users set email='' where id>0;", + "workflow_name": "【回滚工单】原工单Id:163+,3434434343", + "group_name": self.resource_group1.group_name, + "instance_name": self.master1.instance_name, + "run_date_start": "", + "run_date_end": "", + "workflow_auditors": "11", + } + r = c.post("/autoreview/", data=data) + self.assertContains(r, "页面提交参数可能为空") + + @patch("sql.sql_workflow.async_task") + @patch("sql.sql_workflow.Audit") + @patch("sql.sql_workflow.get_engine") + @patch("sql.sql_workflow.user_instances") + def test_submit_audit_wrong( + self, _user_instances, _get_engine, _audit, _async_task + ): """测试SQL提交,获取审核信息报错""" c = Client() c.force_login(self.superuser1) - data = {"sql_content": "update sql_users set email='' where id>0;", - "workflow_name": "【回滚工单】原工单Id:163+,3434434343", - "group_name": self.resource_group1.group_name, - "instance_name": self.master1.instance_name, - "db_name": "archery", - "demand_url": 'test_url', - "run_date_start": "", - "run_date_end": "", - "workflow_auditors": "11"} + data = { + "sql_content": "update sql_users set email='' where id>0;", + "workflow_name": "【回滚工单】原工单Id:163+,3434434343", + "group_name": self.resource_group1.group_name, + "instance_name": self.master1.instance_name, + "db_name": "archery", + "demand_url": "test_url", + "run_date_start": "", + "run_date_end": "", + "workflow_auditors": "11", + } _user_instances.return_value.get.return_value = self.master1 _get_engine.return_value.execute_check.return_value = ReviewSet( - syntax_type=1, - warning_count=1, - error_count=1) - _audit.settings = ValueError('error') + syntax_type=1, warning_count=1, error_count=1 + ) + _audit.settings = ValueError("error") _audit.add.return_value = None _async_task.return_value = None - r = c.post('/autoreview/', data=data) - self.assertContains(r, 'ValueError') + r = c.post("/autoreview/", data=data) + self.assertContains(r, "ValueError") - @patch('sql.sql_workflow.async_task') - @patch('sql.sql_workflow.Audit') - @patch('sql.sql_workflow.get_engine') - @patch('sql.sql_workflow.user_instances') + @patch("sql.sql_workflow.async_task") + @patch("sql.sql_workflow.Audit") + @patch("sql.sql_workflow.get_engine") + @patch("sql.sql_workflow.user_instances") def test_submit(self, _user_instances, _get_engine, _audit, _async_task): """测试SQL提交,正常提交""" c = Client() c.force_login(self.superuser1) - data = {"sql_content": "update sql_users set email='' where id>0;", - "workflow_name": "【回滚工单】原工单Id:163+,3434434343", - "group_name": self.resource_group1.group_name, - "instance_name": self.master1.instance_name, - "db_name": "archery", - "demand_url": 'test_url', - "run_date_start": "", - "run_date_end": "", - "workflow_auditors": "11"} + data = { + "sql_content": "update sql_users set email='' where id>0;", + "workflow_name": "【回滚工单】原工单Id:163+,3434434343", + "group_name": self.resource_group1.group_name, + "instance_name": self.master1.instance_name, + "db_name": "archery", + "demand_url": "test_url", + "run_date_start": "", + "run_date_end": "", + "workflow_auditors": "11", + } _user_instances.return_value.get.return_value = self.master1 _get_engine.return_value.execute_check.return_value = ReviewSet( - syntax_type=1, - warning_count=1, - error_count=1) - _audit.settings.return_value = 'some_group,another_group' + syntax_type=1, warning_count=1, error_count=1 + ) + _audit.settings.return_value = "some_group,another_group" _audit.add.return_value = None _async_task.return_value = None - r = c.post('/autoreview/', data=data) - workflow_id = SqlWorkflow.objects.order_by('-id').first().id - self.assertRedirects(r, f'/detail/{workflow_id}/', fetch_redirect_response=False) + r = c.post("/autoreview/", data=data) + workflow_id = SqlWorkflow.objects.order_by("-id").first().id + self.assertRedirects( + r, f"/detail/{workflow_id}/", fetch_redirect_response=False + ) - @patch('sql.utils.workflow_audit.Audit.can_review') + @patch("sql.utils.workflow_audit.Audit.can_review") def test_alter_run_date_no_perm(self, _can_review): """测试修改可执行时间,无权限""" - sql_review = Permission.objects.get(codename='sql_review') + sql_review = Permission.objects.get(codename="sql_review") self.u1.user_permissions.add(sql_review) _can_review.return_value = False c = Client() c.force_login(self.u1) data = {"workflow_id": self.wf1.id} - r = c.post('/alter_run_date/', data=data) - self.assertContains(r, '你无权操作当前工单') + r = c.post("/alter_run_date/", data=data) + self.assertContains(r, "你无权操作当前工单") - @patch('sql.utils.workflow_audit.Audit.can_review') + @patch("sql.utils.workflow_audit.Audit.can_review") def test_alter_run_date(self, _can_review): """测试修改可执行时间,有权限""" - sql_review = Permission.objects.get(codename='sql_review') + sql_review = Permission.objects.get(codename="sql_review") self.u1.user_permissions.add(sql_review) _can_review.return_value = True c = Client() c.force_login(self.u1) data = {"workflow_id": self.wf1.id} - r = c.post('/alter_run_date/', data=data) - self.assertRedirects(r, f'/detail/{self.wf1.id}/', fetch_redirect_response=False) + r = c.post("/alter_run_date/", data=data) + self.assertRedirects( + r, f"/detail/{self.wf1.id}/", fetch_redirect_response=False + ) - @patch('sql.utils.workflow_audit.Audit.logs') - @patch('sql.utils.workflow_audit.Audit.detail_by_workflow_id') - @patch('sql.utils.workflow_audit.Audit.review_info') - @patch('sql.utils.workflow_audit.Audit.can_review') + @patch("sql.utils.workflow_audit.Audit.logs") + @patch("sql.utils.workflow_audit.Audit.detail_by_workflow_id") + @patch("sql.utils.workflow_audit.Audit.review_info") + @patch("sql.utils.workflow_audit.Audit.can_review") def testWorkflowDetailView(self, _can_review, _review_info, _detail_by_id, _logs): """测试工单详情""" - _review_info.return_value = ('some_auth_group', 'current_auth_group') + _review_info.return_value = ("some_auth_group", "current_auth_group") _can_review.return_value = False _detail_by_id.return_value.audit_id = 123 - _logs.return_value.latest('id').operation_info = '' + _logs.return_value.latest("id").operation_info = "" c = Client() c.force_login(self.u1) - r = c.get('/detail/{}/'.format(self.wf1.id)) + r = c.get("/detail/{}/".format(self.wf1.id)) expected_status_display = r"""id="workflow_detail_disaply">已正常结束""" self.assertContains(r, expected_status_display) exepcted_status = r"""id="workflow_detail_status">workflow_finish""" self.assertContains(r, exepcted_status) # 测试执行详情解析失败 - self.wfc1.execute_result = 'cannotbedecode:1,:' + self.wfc1.execute_result = "cannotbedecode:1,:" self.wfc1.save() - r = c.get('/detail/{}/'.format(self.wf1.id)) + r = c.get("/detail/{}/".format(self.wf1.id)) self.assertContains(r, expected_status_display) self.assertContains(r, exepcted_status) # 执行详情为空 self.wfc1.review_content = [ - {"id": 1, "stage": "CHECKED", "errlevel": 0, "stagestatus": "Audit completed", "errormessage": "None", - "sql": "use archery", "affected_rows": 0, "sequence": "'0_0_0'", "backup_dbname": "None", - "execute_time": "0", "sqlsha1": "", "actual_affected_rows": ""}] - self.wfc1.execute_result = '' + { + "id": 1, + "stage": "CHECKED", + "errlevel": 0, + "stagestatus": "Audit completed", + "errormessage": "None", + "sql": "use archery", + "affected_rows": 0, + "sequence": "'0_0_0'", + "backup_dbname": "None", + "execute_time": "0", + "sqlsha1": "", + "actual_affected_rows": "", + } + ] + self.wfc1.execute_result = "" self.wfc1.save() - r = c.get('/detail/{}/'.format(self.wf1.id)) + r = c.get("/detail/{}/".format(self.wf1.id)) def testWorkflowListView(self): """测试工单列表""" c = Client() c.force_login(self.superuser1) - r = c.post('/sqlworkflow_list/', {'limit': 10, 'offset': 0, 'navStatus': ''}) + r = c.post("/sqlworkflow_list/", {"limit": 10, "offset": 0, "navStatus": ""}) r_json = r.json() - self.assertEqual(r_json['total'], 2) + self.assertEqual(r_json["total"], 2) # 列表按创建时间倒序排列, 第二个是wf1 , 是已正常结束 - self.assertEqual(r_json['rows'][1]['status'], 'workflow_finish') + self.assertEqual(r_json["rows"][1]["status"], "workflow_finish") # u1拿到u1的 c.force_login(self.u1) - r = c.post('/sqlworkflow_list/', {'limit': 10, 'offset': 0, 'navStatus': ''}) + r = c.post("/sqlworkflow_list/", {"limit": 10, "offset": 0, "navStatus": ""}) r_json = r.json() - self.assertEqual(r_json['total'], 1) - self.assertEqual(r_json['rows'][0]['id'], self.wf1.id) + self.assertEqual(r_json["total"], 1) + self.assertEqual(r_json["rows"][0]["id"], self.wf1.id) # u3拿到None c.force_login(self.u3) - r = c.post('/sqlworkflow_list/', {'limit': 10, 'offset': 0, 'navStatus': ''}) + r = c.post("/sqlworkflow_list/", {"limit": 10, "offset": 0, "navStatus": ""}) r_json = r.json() - self.assertEqual(r_json['total'], 0) + self.assertEqual(r_json["total"], 0) def testWorkflowListViewFilter(self): """测试工单列表筛选""" c = Client() c.force_login(self.superuser1) # 工单状态 - r = c.post('/sqlworkflow_list/', {'limit': 10, 'offset': 0, 'navStatus': 'workflow_finish'}) + r = c.post( + "/sqlworkflow_list/", + {"limit": 10, "offset": 0, "navStatus": "workflow_finish"}, + ) r_json = r.json() - self.assertEqual(r_json['total'], 1) + self.assertEqual(r_json["total"], 1) # 列表按创建时间倒序排列 - self.assertEqual(r_json['rows'][0]['status'], 'workflow_finish') + self.assertEqual(r_json["rows"][0]["status"], "workflow_finish") # 实例 - r = c.post('/sqlworkflow_list/', {'limit': 10, 'offset': 0, 'instance_id': self.wf1.instance_id}) + r = c.post( + "/sqlworkflow_list/", + {"limit": 10, "offset": 0, "instance_id": self.wf1.instance_id}, + ) r_json = r.json() - self.assertEqual(r_json['total'], 2) + self.assertEqual(r_json["total"], 2) # 列表按创建时间倒序排列, 第二个是wf1 - self.assertEqual(r_json['rows'][1]['workflow_name'], self.wf1.workflow_name) + self.assertEqual(r_json["rows"][1]["workflow_name"], self.wf1.workflow_name) # 资源组 - r = c.post('/sqlworkflow_list/', {'limit': 10, 'offset': 0, 'resource_group_id': self.wf1.group_id}) + r = c.post( + "/sqlworkflow_list/", + {"limit": 10, "offset": 0, "resource_group_id": self.wf1.group_id}, + ) r_json = r.json() - self.assertEqual(r_json['total'], 2) + self.assertEqual(r_json["total"], 2) # 列表按创建时间倒序排列, 第二个是wf1 - self.assertEqual(r_json['rows'][1]['workflow_name'], self.wf1.workflow_name) + self.assertEqual(r_json["rows"][1]["workflow_name"], self.wf1.workflow_name) # 时间 - start_date = datetime.strftime(self.now, '%Y-%m-%d') - end_date = datetime.strftime(self.now, '%Y-%m-%d') - r = c.post('/sqlworkflow_list/', {'limit': 10, 'offset': 0, 'start_date': start_date, 'end_date': end_date}) + start_date = datetime.strftime(self.now, "%Y-%m-%d") + end_date = datetime.strftime(self.now, "%Y-%m-%d") + r = c.post( + "/sqlworkflow_list/", + {"limit": 10, "offset": 0, "start_date": start_date, "end_date": end_date}, + ) r_json = r.json() - self.assertEqual(r_json['total'], 2) + self.assertEqual(r_json["total"], 2) - @patch('sql.utils.workflow_audit.Audit.detail_by_workflow_id') - @patch('sql.utils.workflow_audit.Audit.audit') - @patch('sql.utils.workflow_audit.Audit.can_review') + @patch("sql.utils.workflow_audit.Audit.detail_by_workflow_id") + @patch("sql.utils.workflow_audit.Audit.audit") + @patch("sql.utils.workflow_audit.Audit.can_review") def testWorkflowPassedView(self, _can_review, _audit, _detail_by_id): """测试审核工单""" c = Client() c.force_login(self.superuser1) - r = c.post('/passed/') - self.assertContains(r, 'workflow_id参数为空.') + r = c.post("/passed/") + self.assertContains(r, "workflow_id参数为空.") _can_review.return_value = False - r = c.post('/passed/', {'workflow_id': self.wf1.id}) - self.assertContains(r, '你无权操作当前工单!') + r = c.post("/passed/", {"workflow_id": self.wf1.id}) + self.assertContains(r, "你无权操作当前工单!") _can_review.return_value = True _detail_by_id.return_value.audit_id = 123 - _audit.return_value = { - "data": { - "workflow_status": 1 # TODO 改为audit_success - } - } - r = c.post('/passed/', data={'workflow_id': self.wf1.id, 'audit_remark': 'some_audit'}, follow=False) - self.assertRedirects(r, '/detail/{}/'.format(self.wf1.id), fetch_redirect_response=False) + _audit.return_value = {"data": {"workflow_status": 1}} # TODO 改为audit_success + r = c.post( + "/passed/", + data={"workflow_id": self.wf1.id, "audit_remark": "some_audit"}, + follow=False, + ) + self.assertRedirects( + r, "/detail/{}/".format(self.wf1.id), fetch_redirect_response=False + ) self.wf1.refresh_from_db() - self.assertEqual(self.wf1.status, 'workflow_review_pass') + self.assertEqual(self.wf1.status, "workflow_review_pass") - @patch('sql.sql_workflow.Audit.add_log') - @patch('sql.sql_workflow.Audit.detail_by_workflow_id') - @patch('sql.sql_workflow.can_execute') + @patch("sql.sql_workflow.Audit.add_log") + @patch("sql.sql_workflow.Audit.detail_by_workflow_id") + @patch("sql.sql_workflow.can_execute") def test_workflow_execute(self, mock_can_excute, mock_detail_by_id, mock_add_log): """测试工单执行""" c = Client() c.force_login(self.executor1) - r = c.post('/execute/') - self.assertContains(r, 'workflow_id参数为空.') + r = c.post("/execute/") + self.assertContains(r, "workflow_id参数为空.") mock_can_excute.return_value = False - r = c.post('/execute/', data={'workflow_id': self.wf2.id}) - self.assertContains(r, '你无权操作当前工单!') + r = c.post("/execute/", data={"workflow_id": self.wf2.id}) + self.assertContains(r, "你无权操作当前工单!") mock_can_excute.return_value = True mock_detail_by_id = 123 - r = c.post('/execute/', data={'workflow_id': self.wf2.id, 'mode': 'manual'}) + r = c.post("/execute/", data={"workflow_id": self.wf2.id, "mode": "manual"}) self.wf2.refresh_from_db() - self.assertEqual('workflow_finish', self.wf2.status) + self.assertEqual("workflow_finish", self.wf2.status) - @patch('sql.sql_workflow.Audit.add_log') - @patch('sql.sql_workflow.Audit.detail_by_workflow_id') - @patch('sql.sql_workflow.Audit.audit') + @patch("sql.sql_workflow.Audit.add_log") + @patch("sql.sql_workflow.Audit.detail_by_workflow_id") + @patch("sql.sql_workflow.Audit.audit") # patch view里的can_cancel 而不是原始位置的can_cancel ,因为在调用时, 已经 import 了真的 can_cancel ,会导致mock失效 # 在import 静态函数时需要注意这一点, 动态对象因为每次都会重新生成,也可以 mock 原函数/方法/对象 # 参见 : https://docs.python.org/3/library/unittest.mock.html#where-to-patch - @patch('sql.sql_workflow.can_cancel') + @patch("sql.sql_workflow.can_cancel") def testWorkflowCancelView(self, _can_cancel, _audit, _detail_by_id, _add_log): """测试工单驳回、取消""" c = Client() c.force_login(self.u2) - r = c.post('/cancel/') - self.assertContains(r, 'workflow_id参数为空.') - r = c.post('/cancel/', data={'workflow_id': self.wf2.id}) - self.assertContains(r, '终止原因不能为空') + r = c.post("/cancel/") + self.assertContains(r, "workflow_id参数为空.") + r = c.post("/cancel/", data={"workflow_id": self.wf2.id}) + self.assertContains(r, "终止原因不能为空") _can_cancel.return_value = False - r = c.post('/cancel/', data={'workflow_id': self.wf2.id, 'cancel_remark': 'some_reason'}) - self.assertContains(r, '你无权操作当前工单!') + r = c.post( + "/cancel/", + data={"workflow_id": self.wf2.id, "cancel_remark": "some_reason"}, + ) + self.assertContains(r, "你无权操作当前工单!") _can_cancel.return_value = True _detail_by_id = 123 - r = c.post('/cancel/', data={'workflow_id': self.wf2.id, 'cancel_remark': 'some_reason'}) + r = c.post( + "/cancel/", + data={"workflow_id": self.wf2.id, "cancel_remark": "some_reason"}, + ) self.wf2.refresh_from_db() - self.assertEqual('workflow_abort', self.wf2.status) - - @patch('sql.sql_workflow.async_task') - @patch('sql.sql_workflow.Audit') - @patch('sql.sql_workflow.get_engine') - @patch('sql.sql_workflow.user_instances') - def test_workflow_auto_review_view(self, mock_user_instances, mock_get_engine, mock_audit, mock_async_task): + self.assertEqual("workflow_abort", self.wf2.status) + + @patch("sql.sql_workflow.async_task") + @patch("sql.sql_workflow.Audit") + @patch("sql.sql_workflow.get_engine") + @patch("sql.sql_workflow.user_instances") + def test_workflow_auto_review_view( + self, mock_user_instances, mock_get_engine, mock_audit, mock_async_task + ): """测试 autoreview/submit view""" c = Client() c.force_login(self.superuser1) request_data = { - 'sql_content': "update some_db set some_key=\'some value\';", - 'workflow_name': 'some_title', - 'group_name': self.resource_group1.group_name, - 'group_id': self.resource_group1.group_id, - 'instance_name': self.master1.instance_name, - "demand_url": 'test_url', - 'db_name': 'some_db', - 'is_backup': True, - 'notify_users': '' + "sql_content": "update some_db set some_key='some value';", + "workflow_name": "some_title", + "group_name": self.resource_group1.group_name, + "group_id": self.resource_group1.group_id, + "instance_name": self.master1.instance_name, + "demand_url": "test_url", + "db_name": "some_db", + "is_backup": True, + "notify_users": "", } mock_user_instances.return_value.get.return_value = None mock_get_engine.return_value.execute_check.return_value.warning_count = 0 mock_get_engine.return_value.execute_check.return_value.error_count = 0 mock_get_engine.return_value.execute_check.return_value.syntax_type = 0 mock_get_engine.return_value.execute_check.return_value.rows = [] - mock_get_engine.return_value.execute_check.return_value.json.return_value = json.dumps([{ - "id": 1, - "stage": "CHECKED", - "errlevel": 0, - "stagestatus": "Audit completed", - "errormessage": "None", "sql": "use thirdservice_db", "affected_rows": 0, - "sequence": "'0_0_0'", "backup_dbname": "None", "execute_time": "0", "sqlsha1": "", - "actual_affected_rows": None}]) - mock_audit.settings.return_value = 'some_group,another_group' + mock_get_engine.return_value.execute_check.return_value.json.return_value = ( + json.dumps( + [ + { + "id": 1, + "stage": "CHECKED", + "errlevel": 0, + "stagestatus": "Audit completed", + "errormessage": "None", + "sql": "use thirdservice_db", + "affected_rows": 0, + "sequence": "'0_0_0'", + "backup_dbname": "None", + "execute_time": "0", + "sqlsha1": "", + "actual_affected_rows": None, + } + ] + ) + ) + mock_audit.settings.return_value = "some_group,another_group" mock_audit.add.return_value = None mock_async_task.return_value = None - r = c.post('/autoreview/', data=request_data, follow=False) - self.assertIn('detail', r.url) - workflow_id = int(re.search(r'\/detail\/(\d+)\/', r.url).groups()[0]) - self.assertEqual(request_data['workflow_name'], SqlWorkflow.objects.get(id=workflow_id).workflow_name) + r = c.post("/autoreview/", data=request_data, follow=False) + self.assertIn("detail", r.url) + workflow_id = int(re.search(r"\/detail\/(\d+)\/", r.url).groups()[0]) + self.assertEqual( + request_data["workflow_name"], + SqlWorkflow.objects.get(id=workflow_id).workflow_name, + ) # 强制备份测试 # 打开备份开关, 对备份不要求 request_data_without_backup = { - 'sql_content': "update some_db set some_key=\'some value\';", - 'workflow_name': 'some_title_2', - 'group_name': self.resource_group1.group_name, - 'group_id': self.resource_group1.group_id, - 'instance_name': self.master1.instance_name, - 'db_name': 'some_db', - "demand_url": 'test_url', - 'is_backup': False, - 'notify_users': '' + "sql_content": "update some_db set some_key='some value';", + "workflow_name": "some_title_2", + "group_name": self.resource_group1.group_name, + "group_id": self.resource_group1.group_id, + "instance_name": self.master1.instance_name, + "db_name": "some_db", + "demand_url": "test_url", + "is_backup": False, + "notify_users": "", } archer_config = SysConfig() - archer_config.set('enable_backup_switch', 'true') - r = c.post('/autoreview/', data=request_data_without_backup, follow=False) - self.assertIn('detail', r.url) - workflow_id = int(re.search(r'\/detail\/(\d+)\/', r.url).groups()[0]) - self.assertEqual(request_data_without_backup['workflow_name'], - SqlWorkflow.objects.get(id=workflow_id).workflow_name) + archer_config.set("enable_backup_switch", "true") + r = c.post("/autoreview/", data=request_data_without_backup, follow=False) + self.assertIn("detail", r.url) + workflow_id = int(re.search(r"\/detail\/(\d+)\/", r.url).groups()[0]) + self.assertEqual( + request_data_without_backup["workflow_name"], + SqlWorkflow.objects.get(id=workflow_id).workflow_name, + ) # 关闭备份选项, 不允许不备份 - archer_config.set('enable_backup_switch', 'false') - r = c.post('/autoreview/', data=request_data_without_backup, follow=False) - self.assertIn('detail', r.url) - workflow_id = int(re.search(r'\/detail\/(\d+)\/', r.url).groups()[0]) + archer_config.set("enable_backup_switch", "false") + r = c.post("/autoreview/", data=request_data_without_backup, follow=False) + self.assertIn("detail", r.url) + workflow_id = int(re.search(r"\/detail\/(\d+)\/", r.url).groups()[0]) self.assertEqual(SqlWorkflow.objects.get(id=workflow_id).is_backup, True) - @patch('sql.sql_workflow.get_engine') + @patch("sql.sql_workflow.get_engine") def test_osc_control(self, _get_engine): """测试MySQL工单osc控制""" c = Client() c.force_login(self.superuser1) request_data = { - 'workflow_id': self.wf1.id, - 'sqlsha1': 'sqlsha1', - 'command': 'get', + "workflow_id": self.wf1.id, + "sqlsha1": "sqlsha1", + "command": "get", } _get_engine.return_value.osc_control.return_value = ResultSet() - r = c.post('/inception/osc_control/', data=request_data, follow=False) - self.assertDictEqual(json.loads(r.content), - {"total": 0, "rows": [], "msg": None}) + r = c.post("/inception/osc_control/", data=request_data, follow=False) + self.assertDictEqual( + json.loads(r.content), {"total": 0, "rows": [], "msg": None} + ) - @patch('sql.sql_workflow.get_engine') + @patch("sql.sql_workflow.get_engine") def test_osc_control_exception(self, _get_engine): """测试MySQL工单OSC控制异常""" c = Client() c.force_login(self.superuser1) request_data = { - 'workflow_id': self.wf1.id, - 'sqlsha1': 'sqlsha1', - 'command': 'get', + "workflow_id": self.wf1.id, + "sqlsha1": "sqlsha1", + "command": "get", } - _get_engine.return_value.osc_control.side_effect = RuntimeError('RuntimeError') - r = c.post('/inception/osc_control/', data=request_data, follow=False) - self.assertDictEqual(json.loads(r.content), - {"total": 0, "rows": [], "msg": "RuntimeError"}) + _get_engine.return_value.osc_control.side_effect = RuntimeError("RuntimeError") + r = c.post("/inception/osc_control/", data=request_data, follow=False) + self.assertDictEqual( + json.loads(r.content), {"total": 0, "rows": [], "msg": "RuntimeError"} + ) class TestOptimize(TestCase): @@ -1534,14 +1928,18 @@ class TestOptimize(TestCase): """ def setUp(self): - self.superuser = User(username='super', is_superuser=True) + self.superuser = User(username="super", is_superuser=True) self.superuser.save() # 使用 travis.ci 时实例和测试service保持一致 - self.master = Instance(instance_name='test_instance', type='master', db_type='mysql', - host=settings.DATABASES['default']['HOST'], - port=settings.DATABASES['default']['PORT'], - user=settings.DATABASES['default']['USER'], - password=settings.DATABASES['default']['PASSWORD']) + self.master = Instance( + instance_name="test_instance", + type="master", + db_type="mysql", + host=settings.DATABASES["default"]["HOST"], + port=settings.DATABASES["default"]["PORT"], + user=settings.DATABASES["default"]["USER"], + password=settings.DATABASES["default"]["PASSWORD"], + ) self.master.save() self.sys_config = SysConfig() self.client = Client() @@ -1557,70 +1955,105 @@ def test_sqladvisor(self): 测试SQLAdvisor报告 :return: """ - r = self.client.post(path='/slowquery/optimize_sqladvisor/') - self.assertEqual(json.loads(r.content), {'status': 1, 'msg': '页面提交参数可能为空', 'data': []}) - r = self.client.post(path='/slowquery/optimize_sqladvisor/', - data={"sql_content": "select 1;", "instance_name": "test_instance"}) - self.assertEqual(json.loads(r.content), {'status': 1, 'msg': '请配置SQLAdvisor路径!', 'data': []}) - self.sys_config.set('sqladvisor', '/opt/archery/src/plugins/sqladvisor') + r = self.client.post(path="/slowquery/optimize_sqladvisor/") + self.assertEqual( + json.loads(r.content), {"status": 1, "msg": "页面提交参数可能为空", "data": []} + ) + r = self.client.post( + path="/slowquery/optimize_sqladvisor/", + data={"sql_content": "select 1;", "instance_name": "test_instance"}, + ) + self.assertEqual( + json.loads(r.content), {"status": 1, "msg": "请配置SQLAdvisor路径!", "data": []} + ) + self.sys_config.set("sqladvisor", "/opt/archery/src/plugins/sqladvisor") self.sys_config.get_all_config() - r = self.client.post(path='/slowquery/optimize_sqladvisor/', - data={"sql_content": "select 1;", "instance_name": "test_instance"}) - self.assertEqual(json.loads(r.content)['status'], 0) + r = self.client.post( + path="/slowquery/optimize_sqladvisor/", + data={"sql_content": "select 1;", "instance_name": "test_instance"}, + ) + self.assertEqual(json.loads(r.content)["status"], 0) def test_soar(self): """ 测试SOAR报告 :return: """ - r = self.client.post(path='/slowquery/optimize_soar/') - self.assertEqual(json.loads(r.content), {'status': 1, 'msg': '页面提交参数可能为空', 'data': []}) - r = self.client.post(path='/slowquery/optimize_soar/', - data={"sql": "select 1;", "instance_name": "test_instance", "db_name": "mysql"}) - self.assertEqual(json.loads(r.content), {'status': 1, 'msg': '请配置soar_path和test_dsn!', 'data': []}) - self.sys_config.set('soar', '/opt/archery/src/plugins/soar') - self.sys_config.set('soar_test_dsn', 'root:@127.0.0.1:3306/information_schema') + r = self.client.post(path="/slowquery/optimize_soar/") + self.assertEqual( + json.loads(r.content), {"status": 1, "msg": "页面提交参数可能为空", "data": []} + ) + r = self.client.post( + path="/slowquery/optimize_soar/", + data={ + "sql": "select 1;", + "instance_name": "test_instance", + "db_name": "mysql", + }, + ) + self.assertEqual( + json.loads(r.content), + {"status": 1, "msg": "请配置soar_path和test_dsn!", "data": []}, + ) + self.sys_config.set("soar", "/opt/archery/src/plugins/soar") + self.sys_config.set("soar_test_dsn", "root:@127.0.0.1:3306/information_schema") self.sys_config.get_all_config() - r = self.client.post(path='/slowquery/optimize_soar/', - data={"sql": "select 1;", "instance_name": "test_instance", "db_name": "mysql"}) - self.assertEqual(json.loads(r.content)['status'], 0) + r = self.client.post( + path="/slowquery/optimize_soar/", + data={ + "sql": "select 1;", + "instance_name": "test_instance", + "db_name": "mysql", + }, + ) + self.assertEqual(json.loads(r.content)["status"], 0) def test_tuning(self): """ 测试SQLTuning报告 :return: """ - data = {"sql_content": "select * from test_archery.sql_users;", - "instance_name": "test_instance", - "db_name": settings.DATABASES['default']['TEST']['NAME'] - } - data['instance_name'] = 'test_instancex' - r = self.client.post(path='/slowquery/optimize_sqltuning/', data=data) - self.assertEqual(json.loads(r.content), {'status': 1, 'msg': '你所在组未关联该实例!', 'data': []}) + data = { + "sql_content": "select * from test_archery.sql_users;", + "instance_name": "test_instance", + "db_name": settings.DATABASES["default"]["TEST"]["NAME"], + } + data["instance_name"] = "test_instancex" + r = self.client.post(path="/slowquery/optimize_sqltuning/", data=data) + self.assertEqual( + json.loads(r.content), {"status": 1, "msg": "你所在组未关联该实例!", "data": []} + ) # 获取sys_parm - data['instance_name'] = 'test_instance' - data['option[]'] = 'sys_parm' - r = self.client.post(path='/slowquery/optimize_sqltuning/', data=data) - self.assertListEqual(list(json.loads(r.content)['data'].keys()), - ['basic_information', 'sys_parameter', 'optimizer_switch', 'sqltext']) + data["instance_name"] = "test_instance" + data["option[]"] = "sys_parm" + r = self.client.post(path="/slowquery/optimize_sqltuning/", data=data) + self.assertListEqual( + list(json.loads(r.content)["data"].keys()), + ["basic_information", "sys_parameter", "optimizer_switch", "sqltext"], + ) # 获取sql_plan - data['option[]'] = 'sql_plan' - r = self.client.post(path='/slowquery/optimize_sqltuning/', data=data) - self.assertListEqual(list(json.loads(r.content)['data'].keys()), - ['optimizer_rewrite_sql', 'plan', 'sqltext']) + data["option[]"] = "sql_plan" + r = self.client.post(path="/slowquery/optimize_sqltuning/", data=data) + self.assertListEqual( + list(json.loads(r.content)["data"].keys()), + ["optimizer_rewrite_sql", "plan", "sqltext"], + ) # 获取obj_stat - data['option[]'] = 'obj_stat' - r = self.client.post(path='/slowquery/optimize_sqltuning/', data=data) - self.assertListEqual(list(json.loads(r.content)['data'].keys()), - ['object_statistics', 'sqltext']) + data["option[]"] = "obj_stat" + r = self.client.post(path="/slowquery/optimize_sqltuning/", data=data) + self.assertListEqual( + list(json.loads(r.content)["data"].keys()), ["object_statistics", "sqltext"] + ) # 获取sql_profile - data['option[]'] = 'sql_profile' - r = self.client.post(path='/slowquery/optimize_sqltuning/', data=data) - self.assertListEqual(list(json.loads(r.content)['data'].keys()), ['session_status', 'sqltext']) + data["option[]"] = "sql_profile" + r = self.client.post(path="/slowquery/optimize_sqltuning/", data=data) + self.assertListEqual( + list(json.loads(r.content)["data"].keys()), ["session_status", "sqltext"] + ) class TestSchemaSync(TestCase): @@ -1629,14 +2062,18 @@ class TestSchemaSync(TestCase): """ def setUp(self): - self.superuser = User(username='super', is_superuser=True) + self.superuser = User(username="super", is_superuser=True) self.superuser.save() # 使用 travis.ci 时实例和测试service保持一致 - self.master = Instance(instance_name='test_instance', type='master', db_type='mysql', - host=settings.DATABASES['default']['HOST'], - port=settings.DATABASES['default']['PORT'], - user=settings.DATABASES['default']['USER'], - password=settings.DATABASES['default']['PASSWORD']) + self.master = Instance( + instance_name="test_instance", + type="master", + db_type="mysql", + host=settings.DATABASES["default"]["HOST"], + port=settings.DATABASES["default"]["PORT"], + user=settings.DATABASES["default"]["USER"], + password=settings.DATABASES["default"]["PASSWORD"], + ) self.master.save() self.sys_config = SysConfig() self.client = Client() @@ -1652,14 +2089,16 @@ def test_schema_sync(self): 测试SchemaSync :return: """ - data = {"instance_name": "test_instance", - "db_name": "test", - "target_instance_name": "test_instance", - "target_db_name": "test", - "sync_auto_inc": True, - "sync_comments": False} - r = self.client.post(path='/instance/schemasync/', data=data) - self.assertEqual(json.loads(r.content)['status'], 0) + data = { + "instance_name": "test_instance", + "db_name": "test", + "target_instance_name": "test_instance", + "target_db_name": "test", + "sync_auto_inc": True, + "sync_comments": False, + } + r = self.client.post(path="/instance/schemasync/", data=data) + self.assertEqual(json.loads(r.content)["status"], 0) class TestArchiver(TestCase): @@ -1668,39 +2107,45 @@ class TestArchiver(TestCase): """ def setUp(self): - self.superuser = User.objects.create(username='super', is_superuser=True) - self.u1 = User.objects.create(username='u1', is_superuser=False) - self.u2 = User.objects.create(username='u2', is_superuser=False) - menu_archive = Permission.objects.get(codename='menu_archive') - archive_review = Permission.objects.get(codename='archive_review') + self.superuser = User.objects.create(username="super", is_superuser=True) + self.u1 = User.objects.create(username="u1", is_superuser=False) + self.u2 = User.objects.create(username="u2", is_superuser=False) + menu_archive = Permission.objects.get(codename="menu_archive") + archive_review = Permission.objects.get(codename="archive_review") self.u1.user_permissions.add(menu_archive) self.u2.user_permissions.add(menu_archive) self.u2.user_permissions.add(archive_review) # 使用 travis.ci 时实例和测试service保持一致 - self.ins = Instance.objects.create(instance_name='test_instance', type='master', db_type='mysql', - host=settings.DATABASES['default']['HOST'], - port=settings.DATABASES['default']['PORT'], - user=settings.DATABASES['default']['USER'], - password=settings.DATABASES['default']['PASSWORD']) - self.res_group = ResourceGroup.objects.create(group_id=1, group_name='group_name') + self.ins = Instance.objects.create( + instance_name="test_instance", + type="master", + db_type="mysql", + host=settings.DATABASES["default"]["HOST"], + port=settings.DATABASES["default"]["PORT"], + user=settings.DATABASES["default"]["USER"], + password=settings.DATABASES["default"]["PASSWORD"], + ) + self.res_group = ResourceGroup.objects.create( + group_id=1, group_name="group_name" + ) self.archive_apply = ArchiveConfig.objects.create( - title='title', + title="title", resource_group=self.res_group, - audit_auth_groups='some_audit_group', + audit_auth_groups="some_audit_group", src_instance=self.ins, - src_db_name='src_db_name', - src_table_name='src_table_name', + src_db_name="src_db_name", + src_table_name="src_table_name", dest_instance=self.ins, - dest_db_name='src_db_name', - dest_table_name='src_table_name', - condition='1=1', - mode='file', + dest_db_name="src_db_name", + dest_table_name="src_table_name", + condition="1=1", + mode="file", no_delete=True, sleep=1, - status=WorkflowDict.workflow_status['audit_wait'], + status=WorkflowDict.workflow_status["audit_wait"], state=False, - user_name='some_user', - user_display='display', + user_name="some_user", + user_display="display", ) self.sys_config = SysConfig() self.client = Client() @@ -1718,11 +2163,9 @@ def test_archive_list_super(self): 测试管理员获取归档申请列表 :return: """ - data = {"filter_instance_id": self.ins.id, - "state": 'false', - "search": "text"} + data = {"filter_instance_id": self.ins.id, "state": "false", "search": "text"} self.client.force_login(self.superuser) - r = self.client.get(path='/archive/list/', data=data) + r = self.client.get(path="/archive/list/", data=data) self.assertDictEqual(json.loads(r.content), {"total": 0, "rows": []}) def test_archive_list_own(self): @@ -1730,11 +2173,9 @@ def test_archive_list_own(self): 测试非管理员和审核人获取归档申请列表 :return: """ - data = {"filter_instance_id": self.ins.id, - "state": 'false', - "search": "text"} + data = {"filter_instance_id": self.ins.id, "state": "false", "search": "text"} self.client.force_login(self.u1) - r = self.client.get(path='/archive/list/', data=data) + r = self.client.get(path="/archive/list/", data=data) self.assertDictEqual(json.loads(r.content), {"total": 0, "rows": []}) def test_archive_list_review(self): @@ -1742,11 +2183,9 @@ def test_archive_list_review(self): 测试审核人获取归档申请列表 :return: """ - data = {"filter_instance_id": self.ins.id, - "state": 'false', - "search": "text"} + data = {"filter_instance_id": self.ins.id, "state": "false", "search": "text"} self.client.force_login(self.u2) - r = self.client.get(path='/archive/list/', data=data) + r = self.client.get(path="/archive/list/", data=data) self.assertDictEqual(json.loads(r.content), {"total": 0, "rows": []}) def test_archive_apply_not_param(self): @@ -1757,19 +2196,21 @@ def test_archive_apply_not_param(self): data = { "group_name": self.res_group.group_name, "src_instance_name": self.ins.instance_name, - "src_db_name": 'src_db_name', - "src_table_name": 'src_table_name', - "mode": 'dest', + "src_db_name": "src_db_name", + "src_table_name": "src_table_name", + "mode": "dest", "dest_instance_name": self.ins.instance_name, - "dest_db_name": 'dest_db_name', - "dest_table_name": 'dest_table_name', - "condition": '1=1', - "no_delete": 'true', - "sleep": 10 + "dest_db_name": "dest_db_name", + "dest_table_name": "dest_table_name", + "condition": "1=1", + "no_delete": "true", + "sleep": 10, } self.client.force_login(self.superuser) - r = self.client.post(path='/archive/apply/', data=data) - self.assertDictEqual(json.loads(r.content), {'status': 1, 'msg': '请填写完整!', 'data': {}}) + r = self.client.post(path="/archive/apply/", data=data) + self.assertDictEqual( + json.loads(r.content), {"status": 1, "msg": "请填写完整!", "data": {}} + ) def test_archive_apply_not_dest_param(self): """ @@ -1777,85 +2218,99 @@ def test_archive_apply_not_dest_param(self): :return: """ data = { - "title": 'title', + "title": "title", "group_name": self.res_group.group_name, "src_instance_name": self.ins.instance_name, - "src_db_name": 'src_db_name', - "src_table_name": 'src_table_name', - "mode": 'dest', - "condition": '1=1', - "no_delete": 'true', - "sleep": 10 + "src_db_name": "src_db_name", + "src_table_name": "src_table_name", + "mode": "dest", + "condition": "1=1", + "no_delete": "true", + "sleep": 10, } self.client.force_login(self.superuser) - r = self.client.post(path='/archive/apply/', data=data) - self.assertDictEqual(json.loads(r.content), {'status': 1, 'msg': '归档到实例时目标实例信息必选!', 'data': {}}) + r = self.client.post(path="/archive/apply/", data=data) + self.assertDictEqual( + json.loads(r.content), {"status": 1, "msg": "归档到实例时目标实例信息必选!", "data": {}} + ) def test_archive_apply_not_exist_review(self): """ 测试申请归档实例数据,未配置审批流程 :return: """ - data = {"title": 'title', - "group_name": self.res_group.group_name, - "src_instance_name": self.ins.instance_name, - "src_db_name": 'src_db_name', - "src_table_name": 'src_table_name', - "mode": 'dest', - "dest_instance_name": self.ins.instance_name, - "dest_db_name": 'dest_db_name', - "dest_table_name": 'dest_table_name', - "condition": '1=1', - "no_delete": 'true', - "sleep": 10 - } + data = { + "title": "title", + "group_name": self.res_group.group_name, + "src_instance_name": self.ins.instance_name, + "src_db_name": "src_db_name", + "src_table_name": "src_table_name", + "mode": "dest", + "dest_instance_name": self.ins.instance_name, + "dest_db_name": "dest_db_name", + "dest_table_name": "dest_table_name", + "condition": "1=1", + "no_delete": "true", + "sleep": 10, + } self.client.force_login(self.superuser) - r = self.client.post(path='/archive/apply/', data=data) - self.assertDictEqual(json.loads(r.content), {'data': {}, 'msg': '审批流程不能为空,请先配置审批流程', 'status': 1}) + r = self.client.post(path="/archive/apply/", data=data) + self.assertDictEqual( + json.loads(r.content), {"data": {}, "msg": "审批流程不能为空,请先配置审批流程", "status": 1} + ) - @patch('sql.archiver.async_task') + @patch("sql.archiver.async_task") def test_archive_apply(self, _async_task): """ 测试申请归档实例数据 :return: """ - WorkflowAuditSetting.objects.create(workflow_type=3, group_id=1, audit_auth_groups='1') - data = {"title": 'title', - "group_name": self.res_group.group_name, - "src_instance_name": self.ins.instance_name, - "src_db_name": 'src_db_name', - "src_table_name": 'src_table_name', - "mode": 'dest', - "dest_instance_name": self.ins.instance_name, - "dest_db_name": 'dest_db_name', - "dest_table_name": 'dest_table_name', - "condition": '1=1', - "no_delete": 'true', - "sleep": 10 - } + WorkflowAuditSetting.objects.create( + workflow_type=3, group_id=1, audit_auth_groups="1" + ) + data = { + "title": "title", + "group_name": self.res_group.group_name, + "src_instance_name": self.ins.instance_name, + "src_db_name": "src_db_name", + "src_table_name": "src_table_name", + "mode": "dest", + "dest_instance_name": self.ins.instance_name, + "dest_db_name": "dest_db_name", + "dest_table_name": "dest_table_name", + "condition": "1=1", + "no_delete": "true", + "sleep": 10, + } self.client.force_login(self.superuser) - r = self.client.post(path='/archive/apply/', data=data) - self.assertEqual(json.loads(r.content)['status'], 0) + r = self.client.post(path="/archive/apply/", data=data) + self.assertEqual(json.loads(r.content)["status"], 0) - @patch('sql.archiver.Audit') - @patch('sql.archiver.async_task') + @patch("sql.archiver.Audit") + @patch("sql.archiver.async_task") def test_archive_audit(self, _async_task, _audit): """ 测试审核归档实例数据 :return: """ _audit.detail_by_workflow_id.return_value.audit_id = 1 - _audit.audit.return_value = {'status': 0, 'msg': 'ok', 'data': {'workflow_status': 1}} + _audit.audit.return_value = { + "status": 0, + "msg": "ok", + "data": {"workflow_status": 1}, + } data = { "archive_id": self.archive_apply.id, - "audit_status": WorkflowDict.workflow_status['audit_success'], - "audit_remark": 'xxxx' + "audit_status": WorkflowDict.workflow_status["audit_success"], + "audit_remark": "xxxx", } self.client.force_login(self.superuser) - r = self.client.post(path='/archive/audit/', data=data) - self.assertRedirects(r, f'/archive/{self.archive_apply.id}/', fetch_redirect_response=False) + r = self.client.post(path="/archive/audit/", data=data) + self.assertRedirects( + r, f"/archive/{self.archive_apply.id}/", fetch_redirect_response=False + ) - @patch('sql.archiver.async_task') + @patch("sql.archiver.async_task") def test_add_archive_task(self, _async_task): """ 测试添加异步归档任务 @@ -1863,7 +2318,7 @@ def test_add_archive_task(self, _async_task): """ add_archive_task() - @patch('sql.archiver.async_task') + @patch("sql.archiver.async_task") def test_add_archive(self, _async_task): """ 测试执行归档任务 @@ -1872,7 +2327,7 @@ def test_add_archive(self, _async_task): with self.assertRaises(Exception): archive(self.archive_apply.id) - @patch('sql.archiver.async_task') + @patch("sql.archiver.async_task") def test_archive_log(self, _async_task): """ 测试获取归档日志 @@ -1882,48 +2337,52 @@ def test_archive_log(self, _async_task): "archive_id": self.archive_apply.id, } self.client.force_login(self.superuser) - r = self.client.post(path='/archive/log/', data=data) + r = self.client.post(path="/archive/log/", data=data) self.assertDictEqual(json.loads(r.content), {"total": 0, "rows": []}) class TestAsync(TestCase): - def setUp(self): self.now = datetime.now() - self.u1 = User(username='some_user', display='用户1') + self.u1 = User(username="some_user", display="用户1") self.u1.save() - self.master1 = Instance(instance_name='test_master_instance', type='master', db_type='mysql', - host='testhost', port=3306, user='mysql_user', password='mysql_password') + self.master1 = Instance( + instance_name="test_master_instance", + type="master", + db_type="mysql", + host="testhost", + port=3306, + user="mysql_user", + password="mysql_password", + ) self.master1.save() self.wf1 = SqlWorkflow.objects.create( - workflow_name='some_name2', + workflow_name="some_name2", group_id=1, - group_name='g1', + group_name="g1", engineer=self.u1.username, engineer_display=self.u1.display, - audit_auth_groups='some_group', + audit_auth_groups="some_group", create_time=self.now - timedelta(days=1), - status='workflow_executing', + status="workflow_executing", is_backup=True, instance=self.master1, - db_name='some_db', + db_name="some_db", syntax_type=1, ) self.wfc1 = SqlWorkflowContent.objects.create( - workflow=self.wf1, - sql_content='some_sql', - execute_result='' + workflow=self.wf1, sql_content="some_sql", execute_result="" ) # 初始化工单执行返回对象 self.task_result = MagicMock() self.task_result.args = [self.wf1.id] self.task_result.success = True self.task_result.stopped = self.now - self.task_result.result.json.return_value = json.dumps([{ - 'id': 1, - 'sql': 'some_content'}]) - self.task_result.result.warning = '' - self.task_result.result.error = '' + self.task_result.result.json.return_value = json.dumps( + [{"id": 1, "sql": "some_content"}] + ) + self.task_result.result.warning = "" + self.task_result.result.error = "" def tearDown(self): self.wf1.delete() @@ -1931,13 +2390,15 @@ def tearDown(self): self.task_result = None self.master1.delete() - @patch('sql.utils.execute_sql.notify_for_execute') - @patch('sql.utils.execute_sql.Audit') + @patch("sql.utils.execute_sql.notify_for_execute") + @patch("sql.utils.execute_sql.Audit") def test_call_back(self, mock_audit, mock_notify): mock_audit.detail_by_workflow_id.return_value.audit_id = 123 - mock_audit.add_log.return_value = 'any thing' + mock_audit.add_log.return_value = "any thing" execute_callback(self.task_result) - mock_audit.detail_by_workflow_id.assert_called_with(workflow_id=self.wf1.id, workflow_type=ANY) + mock_audit.detail_by_workflow_id.assert_called_with( + workflow_id=self.wf1.id, workflow_type=ANY + ) mock_audit.add_log.assert_called_with( audit_id=123, operation_type=ANY, @@ -1955,14 +2416,18 @@ class TestSQLAnalyze(TestCase): """ def setUp(self): - self.superuser = User(username='super', is_superuser=True) + self.superuser = User(username="super", is_superuser=True) self.superuser.save() # 使用 travis.ci 时实例和测试service保持一致 - self.master = Instance(instance_name='test_instance', type='master', db_type='mysql', - host=settings.DATABASES['default']['HOST'], - port=settings.DATABASES['default']['PORT'], - user=settings.DATABASES['default']['USER'], - password=settings.DATABASES['default']['PASSWORD']) + self.master = Instance( + instance_name="test_instance", + type="master", + db_type="mysql", + host=settings.DATABASES["default"]["HOST"], + port=settings.DATABASES["default"]["PORT"], + user=settings.DATABASES["default"]["USER"], + password=settings.DATABASES["default"]["PASSWORD"], + ) self.master.save() self.sys_config = SysConfig() self.client = Client() @@ -1978,8 +2443,8 @@ def test_generate_text_None(self): 测试解析SQL,text为空 :return: """ - r = self.client.post(path='/sql_analyze/generate/', data={}) - self.assertEqual(json.loads(r.content), {'rows': [], 'total': 0}) + r = self.client.post(path="/sql_analyze/generate/", data={}) + self.assertEqual(json.loads(r.content), {"rows": [], "total": 0}) def test_generate_text_not_None(self): """ @@ -1987,19 +2452,25 @@ def test_generate_text_not_None(self): :return: """ text = "select * from sql_user;select * from sql_workflow;" - r = self.client.post(path='/sql_analyze/generate/', data={"text": text}) - self.assertEqual(json.loads(r.content), - {"total": 2, "rows": [{"sql_id": 1, "sql": "select * from sql_user;"}, - {"sql_id": 2, "sql": "select * from sql_workflow;"}]} - ) + r = self.client.post(path="/sql_analyze/generate/", data={"text": text}) + self.assertEqual( + json.loads(r.content), + { + "total": 2, + "rows": [ + {"sql_id": 1, "sql": "select * from sql_user;"}, + {"sql_id": 2, "sql": "select * from sql_workflow;"}, + ], + }, + ) def test_analyze_text_None(self): """ 测试分析SQL,text为空 :return: """ - r = self.client.post(path='/sql_analyze/analyze/', data={}) - self.assertEqual(json.loads(r.content), {'rows': [], 'total': 0}) + r = self.client.post(path="/sql_analyze/analyze/", data={}) + self.assertEqual(json.loads(r.content), {"rows": [], "total": 0}) def test_analyze_text_not_None(self): """ @@ -2008,10 +2479,14 @@ def test_analyze_text_not_None(self): """ text = "select * from sql_user;select * from sql_workflow;" instance_name = self.master.instance_name - db_name = settings.DATABASES['default']['TEST']['NAME'] - r = self.client.post(path='/sql_analyze/analyze/', - data={"text": text, "instance_name": instance_name, "db_name": db_name}) - self.assertListEqual(list(json.loads(r.content)['rows'][0].keys()), ['sql_id', 'sql', 'report']) + db_name = settings.DATABASES["default"]["TEST"]["NAME"] + r = self.client.post( + path="/sql_analyze/analyze/", + data={"text": text, "instance_name": instance_name, "db_name": db_name}, + ) + self.assertListEqual( + list(json.loads(r.content)["rows"][0].keys()), ["sql_id", "sql", "report"] + ) class TestBinLog(TestCase): @@ -2020,14 +2495,18 @@ class TestBinLog(TestCase): """ def setUp(self): - self.superuser = User(username='super', is_superuser=True) + self.superuser = User(username="super", is_superuser=True) self.superuser.save() # 使用 travis.ci 时实例和测试service保持一致 - self.master = Instance(instance_name='test_instance', type='master', db_type='mysql', - host=settings.DATABASES['default']['HOST'], - port=settings.DATABASES['default']['PORT'], - user=settings.DATABASES['default']['USER'], - password=settings.DATABASES['default']['PASSWORD']) + self.master = Instance( + instance_name="test_instance", + type="master", + db_type="mysql", + host=settings.DATABASES["default"]["HOST"], + port=settings.DATABASES["default"]["PORT"], + user=settings.DATABASES["default"]["USER"], + password=settings.DATABASES["default"]["PASSWORD"], + ) self.master.save() self.sys_config = SysConfig() self.client = Client() @@ -2043,21 +2522,19 @@ def test_binlog_list_instance_not_exist(self): 测试获取binlog列表,实例不存在 :return: """ - data = { - "instance_name": 'some_instance' - } - r = self.client.post(path='/binlog/list/', data=data) - self.assertEqual(json.loads(r.content), {'status': 1, 'msg': '实例不存在', 'data': []}) + data = {"instance_name": "some_instance"} + r = self.client.post(path="/binlog/list/", data=data) + self.assertEqual( + json.loads(r.content), {"status": 1, "msg": "实例不存在", "data": []} + ) def test_binlog_list_instance(self): """ 测试获取binlog列表,实例存在 :return: """ - data = { - "instance_name": 'test_instance' - } - r = self.client.post(path='/binlog/list/', data=data) + data = {"instance_name": "test_instance"} + r = self.client.post(path="/binlog/list/", data=data) # self.assertEqual(json.loads(r.content).get('status'), 1) def test_my2sql_path_not_exist(self): @@ -2065,83 +2542,91 @@ def test_my2sql_path_not_exist(self): 测试获取解析binlog,path未设置 :return: """ - data = {"instance_name": "test_instance", - "save_sql": "false", - "rollback": "2sql", - "num": "", - "threads": 1, - "extra_info": "false", - "ignore_primary_key": "false", - "full_columns": "false", - "no_db_prefix": "false", - "file_per_table": "false", - "start_file": "mysql-bin.000045", - "start_pos": "", - "end_file": "mysql-bin.000045", - "end_pos": "", - "stop_time": "", - "start_time": "", - "only_schemas": "", - "sql_type": ""} - r = self.client.post(path='/binlog/my2sql/', data=data) - self.assertEqual(json.loads(r.content), {'status': 1, 'msg': '可执行文件路径不能为空!', 'data': {}}) - - @patch('sql.plugins.plugin.subprocess') + data = { + "instance_name": "test_instance", + "save_sql": "false", + "rollback": "2sql", + "num": "", + "threads": 1, + "extra_info": "false", + "ignore_primary_key": "false", + "full_columns": "false", + "no_db_prefix": "false", + "file_per_table": "false", + "start_file": "mysql-bin.000045", + "start_pos": "", + "end_file": "mysql-bin.000045", + "end_pos": "", + "stop_time": "", + "start_time": "", + "only_schemas": "", + "sql_type": "", + } + r = self.client.post(path="/binlog/my2sql/", data=data) + self.assertEqual( + json.loads(r.content), {"status": 1, "msg": "可执行文件路径不能为空!", "data": {}} + ) + + @patch("sql.plugins.plugin.subprocess") def test_my2sql(self, _subprocess): """ 测试获取解析binlog,path设置 :param _subprocess: :return: """ - self.sys_config.set('my2sql', '/opt/my2sql') + self.sys_config.set("my2sql", "/opt/my2sql") self.sys_config.get_all_config() - data = {"instance_name": "test_instance", - "save_sql": "1", - "rollback": "2sql", - "num": "1", - "threads": 1, - "extra_info": "false", - "ignore_primary_key": "false", - "full_columns": "false", - "no_db_prefix": "false", - "file_per_table": "false", - "start_file": "mysql-bin.000045", - "start_pos": "", - "end_file": "mysql-bin.000046", - "end_pos": "", - "stop_time": "", - "start_time": "", - "only_schemas": "", - "sql_type": ""} - r = self.client.post(path='/binlog/my2sql/', data=data) + data = { + "instance_name": "test_instance", + "save_sql": "1", + "rollback": "2sql", + "num": "1", + "threads": 1, + "extra_info": "false", + "ignore_primary_key": "false", + "full_columns": "false", + "no_db_prefix": "false", + "file_per_table": "false", + "start_file": "mysql-bin.000045", + "start_pos": "", + "end_file": "mysql-bin.000046", + "end_pos": "", + "stop_time": "", + "start_time": "", + "only_schemas": "", + "sql_type": "", + } + r = self.client.post(path="/binlog/my2sql/", data=data) self.assertEqual(json.loads(r.content), {"status": 0, "msg": "ok", "data": []}) - @patch('builtins.open') + @patch("builtins.open") def test_my2sql_file(self, _open): """ 测试保存文件 :param _subprocess: :return: """ - args = {"instance_name": "test_instance", - "save_sql": "1", - "rollback": "2sql", - "num": "1", - "threads": 1, - "add-extraInfo": "false", - "ignore-primaryKey-forInsert": "false", - "full-columns": "false", - "do-not-add-prifixDb": "false", - "file-per-table": "false", - "start-file": "mysql-bin.000045", - "start-pos": "", - "stop-file": "mysql-bin.000045", - "stop-pos": "", - "stop-datetime": "", - "start-datetime": "", - "databases": "", - "sql": "", - "instance": self.master} + args = { + "instance_name": "test_instance", + "save_sql": "1", + "rollback": "2sql", + "num": "1", + "threads": 1, + "add-extraInfo": "false", + "ignore-primaryKey-forInsert": "false", + "full-columns": "false", + "do-not-add-prifixDb": "false", + "file-per-table": "false", + "start-file": "mysql-bin.000045", + "start-pos": "", + "stop-file": "mysql-bin.000045", + "stop-pos": "", + "stop-datetime": "", + "start-datetime": "", + "databases": "", + "sql": "", + "instance": self.master, + } r = my2sql_file(args=args, user=self.superuser) self.assertEqual(self.superuser, r[0]) @@ -2154,51 +2639,50 @@ def test_del_binlog_instance_not_exist(self): "instance_id": 0, "binlog": "mysql-bin.000001", } - r = self.client.post(path='/binlog/del_log/', data=data) - self.assertEqual(json.loads(r.content), {'status': 1, 'msg': '实例不存在', 'data': []}) + r = self.client.post(path="/binlog/del_log/", data=data) + self.assertEqual( + json.loads(r.content), {"status": 1, "msg": "实例不存在", "data": []} + ) def test_del_binlog_binlog_not_exist(self): """ 测试删除binlog,实例存在,binlog 不存在 :return: """ - data = { - "instance_id": self.master.id, - "binlog": '' - } - r = self.client.post(path='/binlog/del_log/', data=data) - self.assertEqual(json.loads(r.content), {'status': 1, 'msg': 'Error:未选择binlog!', 'data': ''}) + data = {"instance_id": self.master.id, "binlog": ""} + r = self.client.post(path="/binlog/del_log/", data=data) + self.assertEqual( + json.loads(r.content), {"status": 1, "msg": "Error:未选择binlog!", "data": ""} + ) - @patch('sql.engines.mysql.MysqlEngine.query') - @patch('sql.engines.get_engine') + @patch("sql.engines.mysql.MysqlEngine.query") + @patch("sql.engines.get_engine") def test_del_binlog(self, _get_engine, _query): """ 测试删除binlog :return: """ - data = { - "instance_id": self.master.id, - "binlog": "mysql-bin.000001" - } - _query.return_value = ResultSet(full_sql='select 1') - r = self.client.post(path='/binlog/del_log/', data=data) - self.assertEqual(json.loads(r.content), {'status': 0, 'msg': '清理成功', 'data': ''}) + data = {"instance_id": self.master.id, "binlog": "mysql-bin.000001"} + _query.return_value = ResultSet(full_sql="select 1") + r = self.client.post(path="/binlog/del_log/", data=data) + self.assertEqual( + json.loads(r.content), {"status": 0, "msg": "清理成功", "data": ""} + ) - @patch('sql.engines.mysql.MysqlEngine.query') - @patch('sql.engines.get_engine') + @patch("sql.engines.mysql.MysqlEngine.query") + @patch("sql.engines.get_engine") def test_del_binlog_wrong(self, _get_engine, _query): """ 测试删除binlog :return: """ - data = { - "instance_id": self.master.id, - "binlog": "mysql-bin.000001" - } - _query.return_value = ResultSet(full_sql='select 1') - _query.return_value.error = '清理失败' - r = self.client.post(path='/binlog/del_log/', data=data) - self.assertEqual(json.loads(r.content), {'status': 2, 'msg': '清理失败,Error:清理失败', 'data': ''}) + data = {"instance_id": self.master.id, "binlog": "mysql-bin.000001"} + _query.return_value = ResultSet(full_sql="select 1") + _query.return_value.error = "清理失败" + r = self.client.post(path="/binlog/del_log/", data=data) + self.assertEqual( + json.loads(r.content), {"status": 2, "msg": "清理失败,Error:清理失败", "data": ""} + ) class TestParam(TestCase): @@ -2207,14 +2691,18 @@ class TestParam(TestCase): """ def setUp(self): - self.superuser = User(username='super', is_superuser=True) + self.superuser = User(username="super", is_superuser=True) self.superuser.save() # 使用 travis.ci 时实例和测试service保持一致 - self.master = Instance(instance_name='test_instance', type='master', db_type='mysql', - host=settings.DATABASES['default']['HOST'], - port=settings.DATABASES['default']['PORT'], - user=settings.DATABASES['default']['USER'], - password=settings.DATABASES['default']['PASSWORD']) + self.master = Instance( + instance_name="test_instance", + type="master", + db_type="mysql", + host=settings.DATABASES["default"]["HOST"], + port=settings.DATABASES["default"]["PORT"], + user=settings.DATABASES["default"]["USER"], + password=settings.DATABASES["default"]["PASSWORD"], + ) self.master.save() self.client = Client() self.client.force_login(self.superuser) @@ -2229,24 +2717,21 @@ def test_param_list_instance_not_exist(self): 测试获取参数列表,实例不存在 :return: """ - data = { - "instance_id": 0 - } - r = self.client.post(path='/param/list/', data=data) - self.assertEqual(json.loads(r.content), {'status': 1, 'msg': '实例不存在', 'data': []}) + data = {"instance_id": 0} + r = self.client.post(path="/param/list/", data=data) + self.assertEqual( + json.loads(r.content), {"status": 1, "msg": "实例不存在", "data": []} + ) - @patch('sql.engines.mysql.MysqlEngine.get_variables') - @patch('sql.engines.get_engine') + @patch("sql.engines.mysql.MysqlEngine.get_variables") + @patch("sql.engines.get_engine") def test_param_list_instance_exist(self, _get_engine, _get_variables): """ 测试获取参数列表,实例存在 :return: """ - data = { - "instance_id": self.master.id, - "editable": True - } - r = self.client.post(path='/param/list/', data=data) + data = {"instance_id": self.master.id, "editable": True} + r = self.client.post(path="/param/list/", data=data) self.assertIsInstance(json.loads(r.content), list) def test_param_history(self): @@ -2254,92 +2739,124 @@ def test_param_history(self): 测试获取参数修改历史 :return: """ - data = {"instance_id": self.master.id, - "search": "binlog", - "limit": 14, - "offset": 0} - r = self.client.post(path='/param/history/', data=data) - self.assertEqual(json.loads(r.content), {'rows': [], 'total': 0}) + data = { + "instance_id": self.master.id, + "search": "binlog", + "limit": 14, + "offset": 0, + } + r = self.client.post(path="/param/history/", data=data) + self.assertEqual(json.loads(r.content), {"rows": [], "total": 0}) - @patch('sql.engines.mysql.MysqlEngine.set_variable') - @patch('sql.engines.mysql.MysqlEngine.get_variables') - @patch('sql.engines.get_engine') - def test_param_edit_variable_not_config(self, _get_engine, _get_variables, _set_variable): + @patch("sql.engines.mysql.MysqlEngine.set_variable") + @patch("sql.engines.mysql.MysqlEngine.get_variables") + @patch("sql.engines.get_engine") + def test_param_edit_variable_not_config( + self, _get_engine, _get_variables, _set_variable + ): """ 测试参数修改,参数未在模板配置 :return: """ - data = {"instance_id": self.master.id, - "variable_name": "1", - "variable_value": "false"} - r = self.client.post(path='/param/edit/', data=data) - self.assertEqual(json.loads(r.content), {'data': [], 'msg': '请先在参数模板中配置该参数!', 'status': 1}) + data = { + "instance_id": self.master.id, + "variable_name": "1", + "variable_value": "false", + } + r = self.client.post(path="/param/edit/", data=data) + self.assertEqual( + json.loads(r.content), {"data": [], "msg": "请先在参数模板中配置该参数!", "status": 1} + ) - @patch('sql.engines.mysql.MysqlEngine.set_variable') - @patch('sql.engines.mysql.MysqlEngine.get_variables') - @patch('sql.engines.get_engine') - def test_param_edit_variable_not_change(self, _get_engine, _get_variables, _set_variable): + @patch("sql.engines.mysql.MysqlEngine.set_variable") + @patch("sql.engines.mysql.MysqlEngine.get_variables") + @patch("sql.engines.get_engine") + def test_param_edit_variable_not_change( + self, _get_engine, _get_variables, _set_variable + ): """ 测试参数修改,已在参数模板配置,但是值无变化 :return: """ - _get_variables.return_value.rows = (('binlog_format', 'ROW'),) + _get_variables.return_value.rows = (("binlog_format", "ROW"),) _set_variable.return_value.error = None _set_variable.return_value.full_sql = "set global binlog_format='STATEMENT';" - ParamTemplate.objects.create(db_type='mysql', - variable_name='binlog_format', - default_value='ROW', - editable=True) - data = {"instance_id": self.master.id, - "variable_name": "binlog_format", - "runtime_value": "ROW"} - r = self.client.post(path='/param/edit/', data=data) - self.assertEqual(json.loads(r.content), {'status': 1, 'msg': '参数值与实际运行值一致,未调整!', 'data': []}) - - @patch('sql.engines.mysql.MysqlEngine.set_variable') - @patch('sql.engines.mysql.MysqlEngine.get_variables') - @patch('sql.engines.get_engine') - def test_param_edit_variable_change(self, _get_engine, _get_variables, _set_variable): + ParamTemplate.objects.create( + db_type="mysql", + variable_name="binlog_format", + default_value="ROW", + editable=True, + ) + data = { + "instance_id": self.master.id, + "variable_name": "binlog_format", + "runtime_value": "ROW", + } + r = self.client.post(path="/param/edit/", data=data) + self.assertEqual( + json.loads(r.content), {"status": 1, "msg": "参数值与实际运行值一致,未调整!", "data": []} + ) + + @patch("sql.engines.mysql.MysqlEngine.set_variable") + @patch("sql.engines.mysql.MysqlEngine.get_variables") + @patch("sql.engines.get_engine") + def test_param_edit_variable_change( + self, _get_engine, _get_variables, _set_variable + ): """ 测试参数修改,已在参数模板配置,且值有变化 :return: """ - _get_variables.return_value.rows = (('binlog_format', 'ROW'),) + _get_variables.return_value.rows = (("binlog_format", "ROW"),) _set_variable.return_value.error = None _set_variable.return_value.full_sql = "set global binlog_format='STATEMENT';" - ParamTemplate.objects.create(db_type='mysql', - variable_name='binlog_format', - default_value='ROW', - editable=True) - data = {"instance_id": self.master.id, - "variable_name": "binlog_format", - "runtime_value": "STATEMENT"} - r = self.client.post(path='/param/edit/', data=data) - self.assertEqual(json.loads(r.content), {'status': 0, 'msg': '修改成功,请手动持久化到配置文件!', 'data': []}) - - @patch('sql.engines.mysql.MysqlEngine.set_variable') - @patch('sql.engines.mysql.MysqlEngine.get_variables') - @patch('sql.engines.get_engine') - def test_param_edit_variable_error(self, _get_engine, _get_variables, _set_variable): + ParamTemplate.objects.create( + db_type="mysql", + variable_name="binlog_format", + default_value="ROW", + editable=True, + ) + data = { + "instance_id": self.master.id, + "variable_name": "binlog_format", + "runtime_value": "STATEMENT", + } + r = self.client.post(path="/param/edit/", data=data) + self.assertEqual( + json.loads(r.content), {"status": 0, "msg": "修改成功,请手动持久化到配置文件!", "data": []} + ) + + @patch("sql.engines.mysql.MysqlEngine.set_variable") + @patch("sql.engines.mysql.MysqlEngine.get_variables") + @patch("sql.engines.get_engine") + def test_param_edit_variable_error( + self, _get_engine, _get_variables, _set_variable + ): """ 测试参数修改,已在参数模板配置,修改抛错 :return: """ - _get_variables.return_value.rows = (('binlog_format', 'ROW'),) - _set_variable.return_value.error = '修改报错' + _get_variables.return_value.rows = (("binlog_format", "ROW"),) + _set_variable.return_value.error = "修改报错" _set_variable.return_value.full_sql = "set global binlog_format='STATEMENT';" - ParamTemplate.objects.create(db_type='mysql', - variable_name='binlog_format', - default_value='ROW', - editable=True) - data = {"instance_id": self.master.id, - "variable_name": "binlog_format", - "runtime_value": "STATEMENT"} - r = self.client.post(path='/param/edit/', data=data) - self.assertEqual(json.loads(r.content), {'status': 1, 'msg': '设置错误,错误信息:修改报错', 'data': []}) + ParamTemplate.objects.create( + db_type="mysql", + variable_name="binlog_format", + default_value="ROW", + editable=True, + ) + data = { + "instance_id": self.master.id, + "variable_name": "binlog_format", + "runtime_value": "STATEMENT", + } + r = self.client.post(path="/param/edit/", data=data) + self.assertEqual( + json.loads(r.content), {"status": 1, "msg": "设置错误,错误信息:修改报错", "data": []} + ) class TestNotify(TestCase): @@ -2349,55 +2866,66 @@ class TestNotify(TestCase): def setUp(self): self.sys_config = SysConfig() - self.user = User.objects.create(username='test_user', display='中文显示', is_active=True) - self.su = User.objects.create(username='s_user', display='中文显示', is_active=True, is_superuser=True) + self.user = User.objects.create( + username="test_user", display="中文显示", is_active=True + ) + self.su = User.objects.create( + username="s_user", display="中文显示", is_active=True, is_superuser=True + ) tomorrow = datetime.today() + timedelta(days=1) - self.ins = Instance.objects.create(instance_name='some_ins', type='slave', db_type='mysql', - host='some_host', - port=3306, user='ins_user', password='some_str') + self.ins = Instance.objects.create( + instance_name="some_ins", + type="slave", + db_type="mysql", + host="some_host", + port=3306, + user="ins_user", + password="some_str", + ) self.wf = SqlWorkflow.objects.create( - workflow_name='some_name', + workflow_name="some_name", group_id=1, - group_name='g1', + group_name="g1", engineer=self.user.username, engineer_display=self.user.display, - audit_auth_groups='some_audit_group', + audit_auth_groups="some_audit_group", create_time=datetime.now(), - status='workflow_timingtask', + status="workflow_timingtask", is_backup=True, instance=self.ins, - db_name='some_db', - syntax_type=1 + db_name="some_db", + syntax_type=1, + ) + SqlWorkflowContent.objects.create( + workflow=self.wf, sql_content="some_sql", execute_result="" ) - SqlWorkflowContent.objects.create(workflow=self.wf, - sql_content='some_sql', - execute_result='') self.query_apply_1 = QueryPrivilegesApply.objects.create( group_id=1, - group_name='some_name', - title='some_title1', - user_name='some_user', + group_name="some_name", + title="some_title1", + user_name="some_user", instance=self.ins, - db_list='some_db,some_db2', + db_list="some_db,some_db2", limit_num=100, valid_date=tomorrow, priv_type=1, status=0, - audit_auth_groups='some_audit_group' + audit_auth_groups="some_audit_group", ) self.audit = WorkflowAudit.objects.create( group_id=1, - group_name='some_group', + group_name="some_group", workflow_id=1, workflow_type=1, - workflow_title='申请标题', - workflow_remark='申请备注', - audit_auth_groups='1,2,3', - current_audit='1', - next_audit='2', - current_status=0) - self.aug = Group.objects.create(id=1, name='auth_group') - self.rs = ResourceGroup.objects.create(group_id=1, ding_webhook='url') + workflow_title="申请标题", + workflow_remark="申请备注", + audit_auth_groups="1,2,3", + current_audit="1", + next_audit="2", + current_status=0, + ) + self.aug = Group.objects.create(id=1, name="auth_group") + self.rs = ResourceGroup.objects.create(group_id=1, ding_webhook="url") def tearDown(self): self.sys_config.purge() @@ -2413,13 +2941,13 @@ def test_notify_disable(self): :return: """ # 关闭消息通知 - self.sys_config.set('mail', 'false') - self.sys_config.set('ding', 'false') + self.sys_config.set("mail", "false") + self.sys_config.set("ding", "false") r = notify_for_audit(audit_id=self.audit.audit_id) self.assertIsNone(r) - @patch('sql.notify.MsgSender') - @patch('sql.notify.auth_group_users') + @patch("sql.notify.MsgSender") + @patch("sql.notify.auth_group_users") def test_notify_for_sqlreview_audit_wait(self, _auth_group_users, _msg_sender): """ 测试SQL上线申请审核通知 @@ -2428,19 +2956,19 @@ def test_notify_for_sqlreview_audit_wait(self, _auth_group_users, _msg_sender): # 通知人修改 _auth_group_users.return_value = [self.user] # 开启消息通知 - self.sys_config.set('mail', 'true') - self.sys_config.set('ding', 'true') + self.sys_config.set("mail", "true") + self.sys_config.set("ding", "true") # 修改工单状态为待审核 - self.audit.workflow_type = WorkflowDict.workflow_type['sqlreview'] + self.audit.workflow_type = WorkflowDict.workflow_type["sqlreview"] self.audit.workflow_id = self.wf.id - self.audit.current_status = WorkflowDict.workflow_status['audit_wait'] + self.audit.current_status = WorkflowDict.workflow_status["audit_wait"] self.audit.save() r = notify_for_audit(audit_id=self.audit.audit_id) self.assertIsNone(r) _msg_sender.assert_called_once() - @patch('sql.notify.MsgSender') - @patch('sql.notify.auth_group_users') + @patch("sql.notify.MsgSender") + @patch("sql.notify.auth_group_users") def test_notify_for_sqlreview_audit_success(self, _auth_group_users, _msg_sender): """ 测试SQL上线申请审核通过通知 @@ -2449,20 +2977,20 @@ def test_notify_for_sqlreview_audit_success(self, _auth_group_users, _msg_sender # 通知人修改 _auth_group_users.return_value = [self.user] # 开启消息通知 - self.sys_config.set('mail', 'true') - self.sys_config.set('ding', 'true') + self.sys_config.set("mail", "true") + self.sys_config.set("ding", "true") # 修改工单状态审核通过 - self.audit.workflow_type = WorkflowDict.workflow_type['sqlreview'] + self.audit.workflow_type = WorkflowDict.workflow_type["sqlreview"] self.audit.workflow_id = self.wf.id - self.audit.current_status = WorkflowDict.workflow_status['audit_success'] + self.audit.current_status = WorkflowDict.workflow_status["audit_success"] self.audit.create_user = self.user.username self.audit.save() r = notify_for_audit(audit_id=self.audit.audit_id) self.assertIsNone(r) _msg_sender.assert_called_once() - @patch('sql.notify.MsgSender') - @patch('sql.notify.auth_group_users') + @patch("sql.notify.MsgSender") + @patch("sql.notify.auth_group_users") def test_notify_for_sqlreview_audit_reject(self, _auth_group_users, _msg_sender): """ 测试SQL上线申请审核驳回通知 @@ -2471,20 +2999,20 @@ def test_notify_for_sqlreview_audit_reject(self, _auth_group_users, _msg_sender) # 通知人修改 _auth_group_users.return_value = [self.user] # 开启消息通知 - self.sys_config.set('mail', 'true') - self.sys_config.set('ding', 'true') + self.sys_config.set("mail", "true") + self.sys_config.set("ding", "true") # 修改工单状态审核通过 - self.audit.workflow_type = WorkflowDict.workflow_type['sqlreview'] + self.audit.workflow_type = WorkflowDict.workflow_type["sqlreview"] self.audit.workflow_id = self.wf.id - self.audit.current_status = WorkflowDict.workflow_status['audit_reject'] + self.audit.current_status = WorkflowDict.workflow_status["audit_reject"] self.audit.create_user = self.user.username self.audit.save() r = notify_for_audit(audit_id=self.audit.audit_id) self.assertIsNone(r) _msg_sender.assert_called_once() - @patch('sql.notify.MsgSender') - @patch('sql.notify.auth_group_users') + @patch("sql.notify.MsgSender") + @patch("sql.notify.auth_group_users") def test_notify_for_sqlreview_audit_abort(self, _auth_group_users, _msg_sender): """ 测试SQL上线申请审核取消通知 @@ -2493,12 +3021,12 @@ def test_notify_for_sqlreview_audit_abort(self, _auth_group_users, _msg_sender): # 通知人修改 _auth_group_users.return_value = [self.user] # 开启消息通知 - self.sys_config.set('mail', 'true') - self.sys_config.set('ding', 'true') + self.sys_config.set("mail", "true") + self.sys_config.set("ding", "true") # 修改工单状态审核取消 - self.audit.workflow_type = WorkflowDict.workflow_type['sqlreview'] + self.audit.workflow_type = WorkflowDict.workflow_type["sqlreview"] self.audit.workflow_id = self.wf.id - self.audit.current_status = WorkflowDict.workflow_status['audit_abort'] + self.audit.current_status = WorkflowDict.workflow_status["audit_abort"] self.audit.create_user = self.user.username self.audit.audit_auth_groups = self.aug.id self.audit.save() @@ -2506,9 +3034,11 @@ def test_notify_for_sqlreview_audit_abort(self, _auth_group_users, _msg_sender): self.assertIsNone(r) _msg_sender.assert_called_once() - @patch('sql.notify.MsgSender') - @patch('sql.notify.auth_group_users') - def test_notify_for_sqlreview_wrong_workflow_type(self, _auth_group_users, _msg_sender): + @patch("sql.notify.MsgSender") + @patch("sql.notify.auth_group_users") + def test_notify_for_sqlreview_wrong_workflow_type( + self, _auth_group_users, _msg_sender + ): """ 测试不存在的工单类型 :return: @@ -2516,17 +3046,19 @@ def test_notify_for_sqlreview_wrong_workflow_type(self, _auth_group_users, _msg_ # 通知人修改 _auth_group_users.return_value = [self.user] # 开启消息通知 - self.sys_config.set('mail', 'true') - self.sys_config.set('ding', 'true') + self.sys_config.set("mail", "true") + self.sys_config.set("ding", "true") # 修改工单状态审核取消 self.audit.workflow_type = 10 self.audit.save() - with self.assertRaisesMessage(Exception, '工单类型不正确'): + with self.assertRaisesMessage(Exception, "工单类型不正确"): notify_for_audit(audit_id=self.audit.audit_id) - @patch('sql.notify.MsgSender') - @patch('sql.notify.auth_group_users') - def test_notify_for_query_audit_wait_apply_db_perm(self, _auth_group_users, _msg_sender): + @patch("sql.notify.MsgSender") + @patch("sql.notify.auth_group_users") + def test_notify_for_query_audit_wait_apply_db_perm( + self, _auth_group_users, _msg_sender + ): """ 测试查询申请库权限 :return: @@ -2534,12 +3066,12 @@ def test_notify_for_query_audit_wait_apply_db_perm(self, _auth_group_users, _msg # 通知人修改 _auth_group_users.return_value = [self.user] # 开启消息通知 - self.sys_config.set('mail', 'true') - self.sys_config.set('ding', 'true') + self.sys_config.set("mail", "true") + self.sys_config.set("ding", "true") # 修改工单状态为待审核 - self.audit.workflow_type = WorkflowDict.workflow_type['query'] + self.audit.workflow_type = WorkflowDict.workflow_type["query"] self.audit.workflow_id = self.query_apply_1.apply_id - self.audit.current_status = WorkflowDict.workflow_status['audit_wait'] + self.audit.current_status = WorkflowDict.workflow_status["audit_wait"] self.audit.save() # 修改工单为库权限申请 self.query_apply_1.priv_type = 1 @@ -2548,9 +3080,11 @@ def test_notify_for_query_audit_wait_apply_db_perm(self, _auth_group_users, _msg self.assertIsNone(r) _msg_sender.assert_called_once() - @patch('sql.notify.MsgSender') - @patch('sql.notify.auth_group_users') - def test_notify_for_query_audit_wait_apply_tb_perm(self, _auth_group_users, _msg_sender): + @patch("sql.notify.MsgSender") + @patch("sql.notify.auth_group_users") + def test_notify_for_query_audit_wait_apply_tb_perm( + self, _auth_group_users, _msg_sender + ): """ 测试查询申请表权限 :return: @@ -2558,12 +3092,12 @@ def test_notify_for_query_audit_wait_apply_tb_perm(self, _auth_group_users, _msg # 通知人修改 _auth_group_users.return_value = [self.user] # 开启消息通知 - self.sys_config.set('mail', 'true') - self.sys_config.set('ding', 'true') + self.sys_config.set("mail", "true") + self.sys_config.set("ding", "true") # 修改工单状态为待审核 - self.audit.workflow_type = WorkflowDict.workflow_type['query'] + self.audit.workflow_type = WorkflowDict.workflow_type["query"] self.audit.workflow_id = self.query_apply_1.apply_id - self.audit.current_status = WorkflowDict.workflow_status['audit_wait'] + self.audit.current_status = WorkflowDict.workflow_status["audit_wait"] self.audit.save() # 修改工单为表权限申请 self.query_apply_1.priv_type = 2 @@ -2572,21 +3106,21 @@ def test_notify_for_query_audit_wait_apply_tb_perm(self, _auth_group_users, _msg self.assertIsNone(r) _msg_sender.assert_called_once() - @patch('sql.notify.MsgSender') + @patch("sql.notify.MsgSender") def test_notify_for_execute_disable(self, _msg_sender): """ 测试执行消息关闭 :return: """ # 开启消息通知 - self.sys_config.set('mail', 'false') - self.sys_config.set('ding', 'false') + self.sys_config.set("mail", "false") + self.sys_config.set("ding", "false") r = notify_for_execute(self.wf) self.assertIsNone(r) - @patch('sql.notify.auth_group_users') - @patch('sql.notify.Audit') - @patch('sql.notify.MsgSender') + @patch("sql.notify.auth_group_users") + @patch("sql.notify.Audit") + @patch("sql.notify.MsgSender") def test_notify_for_execute(self, _msg_sender, _audit, _auth_group_users): """ 测试执行消息 @@ -2594,42 +3128,43 @@ def test_notify_for_execute(self, _msg_sender, _audit, _auth_group_users): """ _auth_group_users.return_value = [self.user] # 处理工单信息 - _audit.review_info.return_value = self.audit.audit_auth_groups, self.audit.current_audit + _audit.review_info.return_value = ( + self.audit.audit_auth_groups, + self.audit.current_audit, + ) # 开启消息通知 - self.sys_config.set('mail', 'true') - self.sys_config.set('ding', 'true') - self.sys_config.set('ddl_notify_auth_group', self.aug.name) + self.sys_config.set("mail", "true") + self.sys_config.set("ding", "true") + self.sys_config.set("ddl_notify_auth_group", self.aug.name) # 修改工单状态为执行结束,修改为DDL工单 - self.wf.status = 'workflow_finish' + self.wf.status = "workflow_finish" self.wf.syntax_type = 1 self.wf.save() r = notify_for_execute(self.wf) self.assertIsNone(r) _msg_sender.assert_called() - - @patch('sql.notify.MsgSender') + @patch("sql.notify.MsgSender") def test_notify_for_my2sql_disable(self, _msg_sender): """ 测试执行消息关闭 :return: """ # 开启消息通知 - self.sys_config.set('mail', 'false') - self.sys_config.set('ding', 'false') + self.sys_config.set("mail", "false") + self.sys_config.set("ding", "false") r = notify_for_execute(self.wf) self.assertIsNone(r) - - @patch('django_q.tasks.async_task') - @patch('sql.notify.MsgSender') + @patch("django_q.tasks.async_task") + @patch("sql.notify.MsgSender") def test_notify_for_my2sql(self, _msg_sender, _async_task): """ 测试执行消息 :return: """ # 开启消息通知 - self.sys_config.set('mail', 'true') + self.sys_config.set("mail", "true") # 设置为task成功 _async_task.return_value.success.return_value = True r = notify_for_my2sql(_async_task) @@ -2644,17 +3179,23 @@ class TestDataDictionary(TestCase): def setUp(self): self.sys_config = SysConfig() - self.su = User.objects.create(username='s_user', display='中文显示', is_active=True, is_superuser=True) - self.u1 = User.objects.create(username='user1', display='中文显示', is_active=True) + self.su = User.objects.create( + username="s_user", display="中文显示", is_active=True, is_superuser=True + ) + self.u1 = User.objects.create(username="user1", display="中文显示", is_active=True) self.client = Client() self.client.force_login(self.su) # 使用 travis.ci 时实例和测试service保持一致 - self.ins = Instance.objects.create(instance_name='test_instance', type='slave', db_type='mysql', - host=settings.DATABASES['default']['HOST'], - port=settings.DATABASES['default']['PORT'], - user=settings.DATABASES['default']['USER'], - password=settings.DATABASES['default']['PASSWORD']) - self.db_name = settings.DATABASES['default']['TEST']['NAME'] + self.ins = Instance.objects.create( + instance_name="test_instance", + type="slave", + db_type="mysql", + host=settings.DATABASES["default"]["HOST"], + port=settings.DATABASES["default"]["PORT"], + user=settings.DATABASES["default"]["USER"], + password=settings.DATABASES["default"]["PASSWORD"], + ) + self.db_name = settings.DATABASES["default"]["TEST"]["NAME"] def tearDown(self): self.sys_config.purge() @@ -2666,38 +3207,39 @@ def test_data_dictionary_view(self): 测试访问数据字典页面 :return: """ - r = self.client.get(path='/data_dictionary/') + r = self.client.get(path="/data_dictionary/") self.assertEqual(r.status_code, 200) - @patch('sql.data_dictionary.get_engine') + @patch("sql.data_dictionary.get_engine") def test_table_list(self, _get_engine): """ 测试获取表清单 :return: """ - _get_engine.return_value.get_group_tables_by_db.return_value = {'t': [['test1', '测试表1'], ['test2', '测试表2']]} + _get_engine.return_value.get_group_tables_by_db.return_value = { + "t": [["test1", "测试表1"], ["test2", "测试表2"]] + } data = { - 'instance_name': self.ins.instance_name, - 'db_name': self.db_name, - 'db_type': 'mysql' + "instance_name": self.ins.instance_name, + "db_name": self.db_name, + "db_type": "mysql", } - r = self.client.get(path='/data_dictionary/table_list/', data=data) + r = self.client.get(path="/data_dictionary/table_list/", data=data) self.assertEqual(r.status_code, 200) - self.assertDictEqual(json.loads(r.content), - {'data': {'t': [['test1', '测试表1'], ['test2', '测试表2']]}, 'status': 0}) + self.assertDictEqual( + json.loads(r.content), + {"data": {"t": [["test1", "测试表1"], ["test2", "测试表2"]]}, "status": 0}, + ) def test_table_list_not_param(self): """ 测试获取表清单,参数不完整 :return: """ - data = { - 'instance_name': 'not exist ins', - 'db_type': 'mysql' - } - r = self.client.get(path='/data_dictionary/table_list/', data=data) + data = {"instance_name": "not exist ins", "db_type": "mysql"} + r = self.client.get(path="/data_dictionary/table_list/", data=data) self.assertEqual(r.status_code, 200) - self.assertDictEqual(json.loads(r.content), {'msg': '非法调用!', 'status': 1}) + self.assertDictEqual(json.loads(r.content), {"msg": "非法调用!", "status": 1}) def test_table_list_instance_does_not_exist(self): """ @@ -2705,46 +3247,53 @@ def test_table_list_instance_does_not_exist(self): :return: """ data = { - 'instance_name': 'not exist ins', - 'db_name': self.db_name, - 'db_type': 'mysql' + "instance_name": "not exist ins", + "db_name": self.db_name, + "db_type": "mysql", } - r = self.client.get(path='/data_dictionary/table_list/', data=data) + r = self.client.get(path="/data_dictionary/table_list/", data=data) self.assertEqual(r.status_code, 200) - self.assertDictEqual(json.loads(r.content), {'msg': 'Instance.DoesNotExist', 'status': 1}) + self.assertDictEqual( + json.loads(r.content), {"msg": "Instance.DoesNotExist", "status": 1} + ) - @patch('sql.data_dictionary.get_engine') + @patch("sql.data_dictionary.get_engine") def test_table_list_exception(self, _get_engine): """ 测试获取表清单,异常 :return: """ - _get_engine.side_effect = RuntimeError('test error') + _get_engine.side_effect = RuntimeError("test error") data = { - 'instance_name': self.ins.instance_name, - 'db_name': self.db_name, - 'db_type': 'mysql' + "instance_name": self.ins.instance_name, + "db_name": self.db_name, + "db_type": "mysql", } - r = self.client.get(path='/data_dictionary/table_list/', data=data) + r = self.client.get(path="/data_dictionary/table_list/", data=data) self.assertEqual(r.status_code, 200) - self.assertDictEqual(json.loads(r.content), {'msg': 'test error', 'status': 1}) + self.assertDictEqual(json.loads(r.content), {"msg": "test error", "status": 1}) - @patch('sql.data_dictionary.get_engine') + @patch("sql.data_dictionary.get_engine") def test_table_info(self, _get_engine): """ 测试获取表信息 :return: """ - _get_engine.return_value.query.return_value = ResultSet(rows=(('test1', '测试表1'), ('test2', '测试表2'))) + _get_engine.return_value.query.return_value = ResultSet( + rows=(("test1", "测试表1"), ("test2", "测试表2")) + ) data = { - 'instance_name': self.ins.instance_name, - 'db_name': self.db_name, - 'tb_name': 'sql_instance', - 'db_type': 'mysql' + "instance_name": self.ins.instance_name, + "db_name": self.db_name, + "tb_name": "sql_instance", + "db_type": "mysql", } - r = self.client.get(path='/data_dictionary/table_info/', data=data) + r = self.client.get(path="/data_dictionary/table_info/", data=data) self.assertEqual(r.status_code, 200) - self.assertListEqual(list(json.loads(r.content)['data'].keys()), ['meta_data', 'desc', 'index', 'create_sql']) + self.assertListEqual( + list(json.loads(r.content)["data"].keys()), + ["meta_data", "desc", "index", "create_sql"], + ) def test_table_info_not_param(self): """ @@ -2752,11 +3301,11 @@ def test_table_info_not_param(self): :return: """ data = { - 'instance_name': 'not exist ins', + "instance_name": "not exist ins", } - r = self.client.get(path='/data_dictionary/table_info/', data=data) + r = self.client.get(path="/data_dictionary/table_info/", data=data) self.assertEqual(r.status_code, 200) - self.assertDictEqual(json.loads(r.content), {'msg': '非法调用!', 'status': 1}) + self.assertDictEqual(json.loads(r.content), {"msg": "非法调用!", "status": 1}) def test_table_info_instance_does_not_exist(self): """ @@ -2764,31 +3313,33 @@ def test_table_info_instance_does_not_exist(self): :return: """ data = { - 'instance_name': 'not exist ins', - 'db_name': self.db_name, - 'tb_name': 'sql_instance', - 'db_type': 'mysql' + "instance_name": "not exist ins", + "db_name": self.db_name, + "tb_name": "sql_instance", + "db_type": "mysql", } - r = self.client.get(path='/data_dictionary/table_info/', data=data) + r = self.client.get(path="/data_dictionary/table_info/", data=data) self.assertEqual(r.status_code, 200) - self.assertDictEqual(json.loads(r.content), {'msg': 'Instance.DoesNotExist', 'status': 1}) + self.assertDictEqual( + json.loads(r.content), {"msg": "Instance.DoesNotExist", "status": 1} + ) - @patch('sql.data_dictionary.get_engine') + @patch("sql.data_dictionary.get_engine") def test_table_info_exception(self, _get_engine): """ 测试获取表清单,异常 :return: """ - _get_engine.side_effect = RuntimeError('test error') + _get_engine.side_effect = RuntimeError("test error") data = { - 'instance_name': self.ins.instance_name, - 'db_name': self.db_name, - 'tb_name': 'sql_instance', - 'db_type': 'mysql' + "instance_name": self.ins.instance_name, + "db_name": self.db_name, + "tb_name": "sql_instance", + "db_type": "mysql", } - r = self.client.get(path='/data_dictionary/table_info/', data=data) + r = self.client.get(path="/data_dictionary/table_info/", data=data) self.assertEqual(r.status_code, 200) - self.assertDictEqual(json.loads(r.content), {'msg': 'test error', 'status': 1}) + self.assertDictEqual(json.loads(r.content), {"msg": "test error", "status": 1}) def test_export_instance_does_not_exist(self): """ @@ -2796,137 +3347,261 @@ def test_export_instance_does_not_exist(self): :return: """ data = { - 'instance_name': 'not_exist', - 'db_name': self.db_name, - 'db_type': 'mysql' + "instance_name": "not_exist", + "db_name": self.db_name, + "db_type": "mysql", } - r = self.client.get(path='/data_dictionary/export/', data=data) - self.assertDictEqual(json.loads(r.content), {'data': [], 'msg': '你所在组未关联该实例!', 'status': 1}) + r = self.client.get(path="/data_dictionary/export/", data=data) + self.assertDictEqual( + json.loads(r.content), {"data": [], "msg": "你所在组未关联该实例!", "status": 1} + ) - @patch('sql.data_dictionary.user_instances') - @patch('sql.data_dictionary.get_engine') + @patch("sql.data_dictionary.user_instances") + @patch("sql.data_dictionary.get_engine") def test_export_ins_no_perm(self, _get_engine, _user_instances): """ 测试导出实例无权限 :return: """ self.client.force_login(self.u1) - data_dictionary_export = Permission.objects.get(codename='data_dictionary_export') + data_dictionary_export = Permission.objects.get( + codename="data_dictionary_export" + ) self.u1.user_permissions.add(data_dictionary_export) _user_instances.return_value.get.return_value = self.ins - data = { - 'instance_name': self.ins.instance_name, - 'db_type': 'mysql' - - } - r = self.client.get(path='/data_dictionary/export/', data=data) - self.assertDictEqual(json.loads(r.content), - {'status': 1, 'msg': f'仅管理员可以导出整个实例的字典信息!', 'data': []}) + data = {"instance_name": self.ins.instance_name, "db_type": "mysql"} + r = self.client.get(path="/data_dictionary/export/", data=data) + self.assertDictEqual( + json.loads(r.content), + {"status": 1, "msg": f"仅管理员可以导出整个实例的字典信息!", "data": []}, + ) - @patch('sql.data_dictionary.get_engine') + @patch("sql.data_dictionary.get_engine") def test_export_db(self, _get_engine): """ 测试导出 :return: """ - _get_engine.return_value.get_all_databases.return_value.rows.return_value = ResultSet( - rows=(('test1',), ('test2',))) - _get_engine.return_value.query.return_value = ResultSet(rows=( - {'TABLE_CATALOG': 'def', 'TABLE_SCHEMA': 'archer', 'TABLE_NAME': 'aliyun_rds_config', - 'TABLE_TYPE': 'BASE TABLE', 'ENGINE': 'InnoDB', 'VERSION': 10, 'ROW_FORMAT': 'Dynamic', 'TABLE_ROWS': 0, - 'AVG_ROW_LENGTH': 0, 'DATA_LENGTH': 16384, 'MAX_DATA_LENGTH': 0, 'INDEX_LENGTH': 32768, 'DATA_FREE': 0, - 'AUTO_INCREMENT': 1, 'CREATE_TIME': datetime(2019, 5, 28, 9, 25, 41), 'UPDATE_TIME': None, - 'CHECK_TIME': None, 'TABLE_COLLATION': 'utf8_general_ci', 'CHECKSUM': None, 'CREATE_OPTIONS': '', - 'TABLE_COMMENT': ''}, - {'TABLE_CATALOG': 'def', 'TABLE_SCHEMA': 'archer', 'TABLE_NAME': 'auth_group', 'TABLE_TYPE': 'BASE TABLE', - 'ENGINE': 'InnoDB', 'VERSION': 10, 'ROW_FORMAT': 'Dynamic', 'TABLE_ROWS': 8, 'AVG_ROW_LENGTH': 2048, - 'DATA_LENGTH': 16384, 'MAX_DATA_LENGTH': 0, 'INDEX_LENGTH': 16384, 'DATA_FREE': 0, 'AUTO_INCREMENT': 9, - 'CREATE_TIME': datetime(2019, 5, 28, 9, 4, 11), 'UPDATE_TIME': None, 'CHECK_TIME': None, - 'TABLE_COLLATION': 'utf8_general_ci', 'CHECKSUM': None, 'CREATE_OPTIONS': '', 'TABLE_COMMENT': ''})) + _get_engine.return_value.get_all_databases.return_value.rows.return_value = ( + ResultSet(rows=(("test1",), ("test2",))) + ) + _get_engine.return_value.query.return_value = ResultSet( + rows=( + { + "TABLE_CATALOG": "def", + "TABLE_SCHEMA": "archer", + "TABLE_NAME": "aliyun_rds_config", + "TABLE_TYPE": "BASE TABLE", + "ENGINE": "InnoDB", + "VERSION": 10, + "ROW_FORMAT": "Dynamic", + "TABLE_ROWS": 0, + "AVG_ROW_LENGTH": 0, + "DATA_LENGTH": 16384, + "MAX_DATA_LENGTH": 0, + "INDEX_LENGTH": 32768, + "DATA_FREE": 0, + "AUTO_INCREMENT": 1, + "CREATE_TIME": datetime(2019, 5, 28, 9, 25, 41), + "UPDATE_TIME": None, + "CHECK_TIME": None, + "TABLE_COLLATION": "utf8_general_ci", + "CHECKSUM": None, + "CREATE_OPTIONS": "", + "TABLE_COMMENT": "", + }, + { + "TABLE_CATALOG": "def", + "TABLE_SCHEMA": "archer", + "TABLE_NAME": "auth_group", + "TABLE_TYPE": "BASE TABLE", + "ENGINE": "InnoDB", + "VERSION": 10, + "ROW_FORMAT": "Dynamic", + "TABLE_ROWS": 8, + "AVG_ROW_LENGTH": 2048, + "DATA_LENGTH": 16384, + "MAX_DATA_LENGTH": 0, + "INDEX_LENGTH": 16384, + "DATA_FREE": 0, + "AUTO_INCREMENT": 9, + "CREATE_TIME": datetime(2019, 5, 28, 9, 4, 11), + "UPDATE_TIME": None, + "CHECK_TIME": None, + "TABLE_COLLATION": "utf8_general_ci", + "CHECKSUM": None, + "CREATE_OPTIONS": "", + "TABLE_COMMENT": "", + }, + ) + ) data = { - 'instance_name': self.ins.instance_name, - 'db_name': self.db_name, - 'db_type': 'mysql' + "instance_name": self.ins.instance_name, + "db_name": self.db_name, + "db_type": "mysql", } - r = self.client.get(path='/data_dictionary/export/', data=data) + r = self.client.get(path="/data_dictionary/export/", data=data) self.assertEqual(r.status_code, 200) self.assertTrue(r.streaming) - - @patch('sql.data_dictionary.get_engine') + @patch("sql.data_dictionary.get_engine") def oracle_test_export_db(self, _get_engine): """ oracle测试导出 :return: """ - _get_engine.return_value.get_all_databases.return_value.rows.return_value = ResultSet( - rows=(('test1',), ('test2',))) - _get_engine.return_value.query.return_value = ResultSet(rows=( - { 'TABLE_NAME': 'aliyun_rds_config', 'TABLE_COMMENTS': 'TABLE', 'COLUMN_NAME':'t1', 'data_type': 'varcher2(20)', 'DATA_DEFAULT': 'Dynamic', 'NULLABLE': 'Y', 'INDEX_NAME': 'SYS_01', 'COMMENTS': 'SYS_01' - }, - { 'TABLE_NAME': 'auth_group', 'TABLE_COMMENTS': 'TABLE', 'COLUMN_NAME': 't1', 'data_type': 'varcher2(20)', 'DATA_DEFAULT': 'Dynamic', 'NULLABLE': 'N', 'INDEX_NAME': 'SYS_01','COMMENTS': 'SYS_01' - })) + _get_engine.return_value.get_all_databases.return_value.rows.return_value = ( + ResultSet(rows=(("test1",), ("test2",))) + ) + _get_engine.return_value.query.return_value = ResultSet( + rows=( + { + "TABLE_NAME": "aliyun_rds_config", + "TABLE_COMMENTS": "TABLE", + "COLUMN_NAME": "t1", + "data_type": "varcher2(20)", + "DATA_DEFAULT": "Dynamic", + "NULLABLE": "Y", + "INDEX_NAME": "SYS_01", + "COMMENTS": "SYS_01", + }, + { + "TABLE_NAME": "auth_group", + "TABLE_COMMENTS": "TABLE", + "COLUMN_NAME": "t1", + "data_type": "varcher2(20)", + "DATA_DEFAULT": "Dynamic", + "NULLABLE": "N", + "INDEX_NAME": "SYS_01", + "COMMENTS": "SYS_01", + }, + ) + ) data = { - 'instance_name': self.ins.instance_name, - 'db_name': self.db_name, - 'db_type': 'oracle' + "instance_name": self.ins.instance_name, + "db_name": self.db_name, + "db_type": "oracle", } - r = self.client.get(path='/data_dictionary/export/', data=data) - print("oracle_test_export_db" ) + r = self.client.get(path="/data_dictionary/export/", data=data) + print("oracle_test_export_db") self.assertEqual(r.status_code, 200) self.assertTrue(r.streaming) - @patch('sql.data_dictionary.get_engine') + @patch("sql.data_dictionary.get_engine") def test_export_instance(self, _get_engine): """ 测试导出 :return: """ - _get_engine.return_value.get_all_databases.return_value.rows.return_value = ResultSet( - rows=(('test1',), ('test2',))) - _get_engine.return_value.query.return_value = ResultSet(rows=( - {'TABLE_CATALOG': 'def', 'TABLE_SCHEMA': 'archer', 'TABLE_NAME': 'aliyun_rds_config', - 'TABLE_TYPE': 'BASE TABLE', 'ENGINE': 'InnoDB', 'VERSION': 10, 'ROW_FORMAT': 'Dynamic', 'TABLE_ROWS': 0, - 'AVG_ROW_LENGTH': 0, 'DATA_LENGTH': 16384, 'MAX_DATA_LENGTH': 0, 'INDEX_LENGTH': 32768, 'DATA_FREE': 0, - 'AUTO_INCREMENT': 1, 'CREATE_TIME': datetime(2019, 5, 28, 9, 25, 41), 'UPDATE_TIME': None, - 'CHECK_TIME': None, 'TABLE_COLLATION': 'utf8_general_ci', 'CHECKSUM': None, 'CREATE_OPTIONS': '', - 'TABLE_COMMENT': ''}, - {'TABLE_CATALOG': 'def', 'TABLE_SCHEMA': 'archer', 'TABLE_NAME': 'auth_group', 'TABLE_TYPE': 'BASE TABLE', - 'ENGINE': 'InnoDB', 'VERSION': 10, 'ROW_FORMAT': 'Dynamic', 'TABLE_ROWS': 8, 'AVG_ROW_LENGTH': 2048, - 'DATA_LENGTH': 16384, 'MAX_DATA_LENGTH': 0, 'INDEX_LENGTH': 16384, 'DATA_FREE': 0, 'AUTO_INCREMENT': 9, - 'CREATE_TIME': datetime(2019, 5, 28, 9, 4, 11), 'UPDATE_TIME': None, 'CHECK_TIME': None, - 'TABLE_COLLATION': 'utf8_general_ci', 'CHECKSUM': None, 'CREATE_OPTIONS': '', 'TABLE_COMMENT': ''})) - data = { - 'instance_name': self.ins.instance_name, - 'db_type':'mysql' - } - r = self.client.get(path='/data_dictionary/export/', data=data) + _get_engine.return_value.get_all_databases.return_value.rows.return_value = ( + ResultSet(rows=(("test1",), ("test2",))) + ) + _get_engine.return_value.query.return_value = ResultSet( + rows=( + { + "TABLE_CATALOG": "def", + "TABLE_SCHEMA": "archer", + "TABLE_NAME": "aliyun_rds_config", + "TABLE_TYPE": "BASE TABLE", + "ENGINE": "InnoDB", + "VERSION": 10, + "ROW_FORMAT": "Dynamic", + "TABLE_ROWS": 0, + "AVG_ROW_LENGTH": 0, + "DATA_LENGTH": 16384, + "MAX_DATA_LENGTH": 0, + "INDEX_LENGTH": 32768, + "DATA_FREE": 0, + "AUTO_INCREMENT": 1, + "CREATE_TIME": datetime(2019, 5, 28, 9, 25, 41), + "UPDATE_TIME": None, + "CHECK_TIME": None, + "TABLE_COLLATION": "utf8_general_ci", + "CHECKSUM": None, + "CREATE_OPTIONS": "", + "TABLE_COMMENT": "", + }, + { + "TABLE_CATALOG": "def", + "TABLE_SCHEMA": "archer", + "TABLE_NAME": "auth_group", + "TABLE_TYPE": "BASE TABLE", + "ENGINE": "InnoDB", + "VERSION": 10, + "ROW_FORMAT": "Dynamic", + "TABLE_ROWS": 8, + "AVG_ROW_LENGTH": 2048, + "DATA_LENGTH": 16384, + "MAX_DATA_LENGTH": 0, + "INDEX_LENGTH": 16384, + "DATA_FREE": 0, + "AUTO_INCREMENT": 9, + "CREATE_TIME": datetime(2019, 5, 28, 9, 4, 11), + "UPDATE_TIME": None, + "CHECK_TIME": None, + "TABLE_COLLATION": "utf8_general_ci", + "CHECKSUM": None, + "CREATE_OPTIONS": "", + "TABLE_COMMENT": "", + }, + ) + ) + data = {"instance_name": self.ins.instance_name, "db_type": "mysql"} + r = self.client.get(path="/data_dictionary/export/", data=data) self.assertEqual(r.status_code, 200) - self.assertDictEqual(json.loads(r.content), - {'data': [], 'msg': '实例test_instance数据字典导出成功,请到downloads目录下载!', 'status': 0}) + self.assertDictEqual( + json.loads(r.content), + { + "data": [], + "msg": "实例test_instance数据字典导出成功,请到downloads目录下载!", + "status": 0, + }, + ) - @patch('sql.data_dictionary.get_engine') + @patch("sql.data_dictionary.get_engine") def oracle_test_export_instance(self, _get_engine): """ oracle元数据测试导出 :return: """ - _get_engine.return_value.get_all_databases.return_value.rows.return_value = ResultSet( - rows=(('test1',), ('test2',))) - _get_engine.return_value.query.return_value = ResultSet(rows=( - { 'TABLE_NAME': 'aliyun_rds_config', 'TABLE_COMMENTS': 'TABLE', 'COLUMN_NAME':'t1', 'data_type': 'varcher2(20)', 'DATA_DEFAULT': 'Dynamic', 'NULLABLE': 'Y', 'INDEX_NAME': 'SYS_01', 'COMMENTS': 'SYS_01' - }, - { 'TABLE_NAME': 'auth_group', 'TABLE_COMMENTS': 'TABLE', 'COLUMN_NAME': 't1', 'data_type': 'varcher2(20)', 'DATA_DEFAULT': 'Dynamic', 'NULLABLE': 'N', 'INDEX_NAME': 'SYS_01','COMMENTS': 'SYS_01' - })) - data = { - 'instance_name': self.ins.instance_name, - 'db_type':'oracle' - } - r = self.client.get(path='/data_dictionary/export/', data=data) + _get_engine.return_value.get_all_databases.return_value.rows.return_value = ( + ResultSet(rows=(("test1",), ("test2",))) + ) + _get_engine.return_value.query.return_value = ResultSet( + rows=( + { + "TABLE_NAME": "aliyun_rds_config", + "TABLE_COMMENTS": "TABLE", + "COLUMN_NAME": "t1", + "data_type": "varcher2(20)", + "DATA_DEFAULT": "Dynamic", + "NULLABLE": "Y", + "INDEX_NAME": "SYS_01", + "COMMENTS": "SYS_01", + }, + { + "TABLE_NAME": "auth_group", + "TABLE_COMMENTS": "TABLE", + "COLUMN_NAME": "t1", + "data_type": "varcher2(20)", + "DATA_DEFAULT": "Dynamic", + "NULLABLE": "N", + "INDEX_NAME": "SYS_01", + "COMMENTS": "SYS_01", + }, + ) + ) + data = {"instance_name": self.ins.instance_name, "db_type": "oracle"} + r = self.client.get(path="/data_dictionary/export/", data=data) print(r.status_code) - print("oracle_test_export_instance" ) + print("oracle_test_export_instance") self.assertEqual(r.status_code, 200) - self.assertDictEqual(json.loads(r.content), - {'data': [], 'msg': '实例test_instance数据字典导出成功,请到downloads目录下载!', 'status': 0}) - + self.assertDictEqual( + json.loads(r.content), + { + "data": [], + "msg": "实例test_instance数据字典导出成功,请到downloads目录下载!", + "status": 0, + }, + ) diff --git a/sql/urls.py b/sql/urls.py index 71202950ef..084313569f 100644 --- a/sql/urls.py +++ b/sql/urls.py @@ -8,157 +8,160 @@ import sql.sql_optimize from common import auth, config, workflow, dashboard, check from common.twofa import totp -from sql import views, sql_workflow, sql_analyze, query, slowlog, instance, instance_account, db_diagnostic, \ - resource_group, binlog, data_dictionary, archiver, audit_log, user +from sql import ( + views, + sql_workflow, + sql_analyze, + query, + slowlog, + instance, + instance_account, + db_diagnostic, + resource_group, + binlog, + data_dictionary, + archiver, + audit_log, + user, +) from sql.utils import tasks from common.utils import ding_api urlpatterns = [ - path('', views.index), - path('jsi18n/', JavaScriptCatalog.as_view(), name='javascript-catalog'), - path('index/', views.index), - path('login/', views.login, name='login'), - path('login/2fa/', views.twofa, name='twofa'), - path('logout/', auth.sign_out), - path('signup/', auth.sign_up), - path('sqlworkflow/', views.sqlworkflow), - path('submitsql/', views.submit_sql), - path('editsql/', views.submit_sql), - path('submitotherinstance/', views.submit_sql), - path('detail//', views.detail, name='detail'), - path('autoreview/', sql_workflow.submit), - path('passed/', sql_workflow.passed), - path('execute/', sql_workflow.execute), - path('timingtask/', sql_workflow.timing_task), - path('alter_run_date/', sql_workflow.alter_run_date), - path('cancel/', sql_workflow.cancel), - path('rollback/', views.rollback), - path('sqlanalyze/', views.sqlanalyze), - path('sqlquery/', views.sqlquery), - path('slowquery/', views.slowquery), - path('sqladvisor/', views.sqladvisor), - path('slowquery_advisor/', views.sqladvisor), - path('queryapplylist/', views.queryapplylist), - path('queryapplydetail//', views.queryapplydetail, name='queryapplydetail'), - path('queryuserprivileges/', views.queryuserprivileges), - path('dbdiagnostic/', views.dbdiagnostic), - path('workflow/', views.workflows), - path('workflow//', views.workflowsdetail), - path('dbaprinciples/', views.dbaprinciples), - path('dashboard/', dashboard.pyecharts), - path('group/', views.group), - path('grouprelations//', views.groupmgmt), - path('instance/', views.instance), - path('instanceaccount/', views.instanceaccount), - path('database/', views.database), - path('instanceparam/', views.instance_param), - path('my2sql/', views.my2sql), - path('schemasync/', views.schemasync), - path('archive/', views.archive), - path('archive//', views.archive_detail, name='archive_detail'), - path('config/', views.config), - path('audit/', views.audit), - path('audit_sqlquery/', views.audit_sqlquery), - path('audit_sqlworkflow/', views.audit_sqlworkflow), - - path('authenticate/', auth.authenticate_entry), - path('sqlworkflow_list/', sql_workflow.sql_workflow_list), - path('sqlworkflow_list_audit/', sql_workflow.sql_workflow_list_audit), - path('sqlworkflow/detail_content/', sql_workflow.detail_content), - path('sqlworkflow/backup_sql/', sql_workflow.backup_sql), - path('simplecheck/', sql_workflow.check), - path('getWorkflowStatus/', sql_workflow.get_workflow_status), - path('del_sqlcronjob/', tasks.del_schedule), - path('inception/osc_control/', sql_workflow.osc_control), - - path('sql_analyze/generate/', sql_analyze.generate), - path('sql_analyze/analyze/', sql_analyze.analyze), - - path('workflow/list/', workflow.lists), - path('workflow/log/', workflow.log), - path('config/change/', config.change_config), - - path('check/go_inception/', check.go_inception), - path('check/email/', check.email), - path('check/instance/', check.instance), - - path('group/group/', resource_group.group), - path('group/addrelation/', resource_group.addrelation), - path('group/relations/', resource_group.associated_objects), - path('group/instances/', resource_group.instances), - path('group/unassociated/', resource_group.unassociated_objects), - path('group/auditors/', resource_group.auditors), - path('group/changeauditors/', resource_group.changeauditors), - path('group/user_all_instances/', resource_group.user_all_instances), - - path('instance/list/', instance.lists), - - path('instance/user/list', instance_account.users), - path('instance/user/create/', instance_account.create), - path('instance/user/edit/', instance_account.edit), - path('instance/user/grant/', instance_account.grant), - path('instance/user/reset_pwd/', instance_account.reset_pwd), - path('instance/user/lock/', instance_account.lock), - path('instance/user/delete/', instance_account.delete), - - path('instance/database/list/', sql.instance_database.databases), - path('instance/database/create/', sql.instance_database.create), - path('instance/database/edit/', sql.instance_database.edit), - - path('instance/schemasync/', instance.schemasync), - path('instance/instance_resource/', instance.instance_resource), - path('instance/describetable/', instance.describe), - - path('data_dictionary/', views.data_dictionary), - path('data_dictionary/table_list/', data_dictionary.table_list), - path('data_dictionary/table_info/', data_dictionary.table_info), - path('data_dictionary/export/', data_dictionary.export), - - path('param/list/', instance.param_list), - path('param/history/', instance.param_history), - path('param/edit/', instance.param_edit), - - path('query/', query.query), - path('query/querylog/', query.querylog), - path('query/querylog_audit/', query.querylog_audit), - path('query/favorite/', query.favorite), - path('query/explain/', sql.sql_optimize.explain), - path('query/applylist/', sql.query_privileges.query_priv_apply_list), - path('query/userprivileges/', sql.query_privileges.user_query_priv), - path('query/applyforprivileges/', sql.query_privileges.query_priv_apply), - path('query/modifyprivileges/', sql.query_privileges.query_priv_modify), - path('query/privaudit/', sql.query_privileges.query_priv_audit), - - path('binlog/list/', binlog.binlog_list), - path('binlog/my2sql/', binlog.my2sql), - path('binlog/del_log/', binlog.del_binlog), - - path('slowquery/review/', slowlog.slowquery_review), - path('slowquery/review_history/', slowlog.slowquery_review_history), - path('slowquery/optimize_sqladvisor/', sql.sql_optimize.optimize_sqladvisor), - path('slowquery/optimize_sqltuning/', sql.sql_optimize.optimize_sqltuning), - path('slowquery/optimize_soar/', sql.sql_optimize.optimize_soar), - path('slowquery/optimize_sqltuningadvisor/', sql.sql_optimize.optimize_sqltuningadvisor), - path('slowquery/report/', slowlog.report), - - path('db_diagnostic/process/', db_diagnostic.process), - path('db_diagnostic/create_kill_session/', db_diagnostic.create_kill_session), - path('db_diagnostic/kill_session/', db_diagnostic.kill_session), - path('db_diagnostic/tablesapce/', db_diagnostic.tablesapce), - path('db_diagnostic/trxandlocks/', db_diagnostic.trxandlocks), - path('db_diagnostic/innodb_trx/', db_diagnostic.innodb_trx), - - path('archive/list/', archiver.archive_list), - path('archive/apply/', archiver.archive_apply), - path('archive/audit/', archiver.archive_audit), - path('archive/switch/', archiver.archive_switch), - path('archive/once/', archiver.archive_once), - path('archive/log/', archiver.archive_log), - - path('4admin/sync_ding_user/', ding_api.sync_ding_user), - - path('audit/log/', audit_log.audit_log), - path('audit/input/', audit_log.audit_input), - path('user/list/', user.lists), - path('user/qrcode//', totp.generate_qrcode), + path("", views.index), + path("jsi18n/", JavaScriptCatalog.as_view(), name="javascript-catalog"), + path("index/", views.index), + path("login/", views.login, name="login"), + path("login/2fa/", views.twofa, name="twofa"), + path("logout/", auth.sign_out), + path("signup/", auth.sign_up), + path("sqlworkflow/", views.sqlworkflow), + path("submitsql/", views.submit_sql), + path("editsql/", views.submit_sql), + path("submitotherinstance/", views.submit_sql), + path("detail//", views.detail, name="detail"), + path("autoreview/", sql_workflow.submit), + path("passed/", sql_workflow.passed), + path("execute/", sql_workflow.execute), + path("timingtask/", sql_workflow.timing_task), + path("alter_run_date/", sql_workflow.alter_run_date), + path("cancel/", sql_workflow.cancel), + path("rollback/", views.rollback), + path("sqlanalyze/", views.sqlanalyze), + path("sqlquery/", views.sqlquery), + path("slowquery/", views.slowquery), + path("sqladvisor/", views.sqladvisor), + path("slowquery_advisor/", views.sqladvisor), + path("queryapplylist/", views.queryapplylist), + path( + "queryapplydetail//", + views.queryapplydetail, + name="queryapplydetail", + ), + path("queryuserprivileges/", views.queryuserprivileges), + path("dbdiagnostic/", views.dbdiagnostic), + path("workflow/", views.workflows), + path("workflow//", views.workflowsdetail), + path("dbaprinciples/", views.dbaprinciples), + path("dashboard/", dashboard.pyecharts), + path("group/", views.group), + path("grouprelations//", views.groupmgmt), + path("instance/", views.instance), + path("instanceaccount/", views.instanceaccount), + path("database/", views.database), + path("instanceparam/", views.instance_param), + path("my2sql/", views.my2sql), + path("schemasync/", views.schemasync), + path("archive/", views.archive), + path("archive//", views.archive_detail, name="archive_detail"), + path("config/", views.config), + path("audit/", views.audit), + path("audit_sqlquery/", views.audit_sqlquery), + path("audit_sqlworkflow/", views.audit_sqlworkflow), + path("authenticate/", auth.authenticate_entry), + path("sqlworkflow_list/", sql_workflow.sql_workflow_list), + path("sqlworkflow_list_audit/", sql_workflow.sql_workflow_list_audit), + path("sqlworkflow/detail_content/", sql_workflow.detail_content), + path("sqlworkflow/backup_sql/", sql_workflow.backup_sql), + path("simplecheck/", sql_workflow.check), + path("getWorkflowStatus/", sql_workflow.get_workflow_status), + path("del_sqlcronjob/", tasks.del_schedule), + path("inception/osc_control/", sql_workflow.osc_control), + path("sql_analyze/generate/", sql_analyze.generate), + path("sql_analyze/analyze/", sql_analyze.analyze), + path("workflow/list/", workflow.lists), + path("workflow/log/", workflow.log), + path("config/change/", config.change_config), + path("check/go_inception/", check.go_inception), + path("check/email/", check.email), + path("check/instance/", check.instance), + path("group/group/", resource_group.group), + path("group/addrelation/", resource_group.addrelation), + path("group/relations/", resource_group.associated_objects), + path("group/instances/", resource_group.instances), + path("group/unassociated/", resource_group.unassociated_objects), + path("group/auditors/", resource_group.auditors), + path("group/changeauditors/", resource_group.changeauditors), + path("group/user_all_instances/", resource_group.user_all_instances), + path("instance/list/", instance.lists), + path("instance/user/list", instance_account.users), + path("instance/user/create/", instance_account.create), + path("instance/user/edit/", instance_account.edit), + path("instance/user/grant/", instance_account.grant), + path("instance/user/reset_pwd/", instance_account.reset_pwd), + path("instance/user/lock/", instance_account.lock), + path("instance/user/delete/", instance_account.delete), + path("instance/database/list/", sql.instance_database.databases), + path("instance/database/create/", sql.instance_database.create), + path("instance/database/edit/", sql.instance_database.edit), + path("instance/schemasync/", instance.schemasync), + path("instance/instance_resource/", instance.instance_resource), + path("instance/describetable/", instance.describe), + path("data_dictionary/", views.data_dictionary), + path("data_dictionary/table_list/", data_dictionary.table_list), + path("data_dictionary/table_info/", data_dictionary.table_info), + path("data_dictionary/export/", data_dictionary.export), + path("param/list/", instance.param_list), + path("param/history/", instance.param_history), + path("param/edit/", instance.param_edit), + path("query/", query.query), + path("query/querylog/", query.querylog), + path("query/querylog_audit/", query.querylog_audit), + path("query/favorite/", query.favorite), + path("query/explain/", sql.sql_optimize.explain), + path("query/applylist/", sql.query_privileges.query_priv_apply_list), + path("query/userprivileges/", sql.query_privileges.user_query_priv), + path("query/applyforprivileges/", sql.query_privileges.query_priv_apply), + path("query/modifyprivileges/", sql.query_privileges.query_priv_modify), + path("query/privaudit/", sql.query_privileges.query_priv_audit), + path("binlog/list/", binlog.binlog_list), + path("binlog/my2sql/", binlog.my2sql), + path("binlog/del_log/", binlog.del_binlog), + path("slowquery/review/", slowlog.slowquery_review), + path("slowquery/review_history/", slowlog.slowquery_review_history), + path("slowquery/optimize_sqladvisor/", sql.sql_optimize.optimize_sqladvisor), + path("slowquery/optimize_sqltuning/", sql.sql_optimize.optimize_sqltuning), + path("slowquery/optimize_soar/", sql.sql_optimize.optimize_soar), + path( + "slowquery/optimize_sqltuningadvisor/", + sql.sql_optimize.optimize_sqltuningadvisor, + ), + path("slowquery/report/", slowlog.report), + path("db_diagnostic/process/", db_diagnostic.process), + path("db_diagnostic/create_kill_session/", db_diagnostic.create_kill_session), + path("db_diagnostic/kill_session/", db_diagnostic.kill_session), + path("db_diagnostic/tablesapce/", db_diagnostic.tablesapce), + path("db_diagnostic/trxandlocks/", db_diagnostic.trxandlocks), + path("db_diagnostic/innodb_trx/", db_diagnostic.innodb_trx), + path("archive/list/", archiver.archive_list), + path("archive/apply/", archiver.archive_apply), + path("archive/audit/", archiver.archive_audit), + path("archive/switch/", archiver.archive_switch), + path("archive/once/", archiver.archive_once), + path("archive/log/", archiver.archive_log), + path("4admin/sync_ding_user/", ding_api.sync_ding_user), + path("audit/log/", audit_log.audit_log), + path("audit/input/", audit_log.audit_input), + path("user/list/", user.lists), + path("user/qrcode//", totp.generate_qrcode), ] diff --git a/sql/user.py b/sql/user.py index 8dcfb52d06..b7b8ff6053 100644 --- a/sql/user.py +++ b/sql/user.py @@ -9,11 +9,15 @@ @superuser_required def lists(request): """获取用户列表""" - users = Users.objects.order_by('username') - users = users.values("id", "username", "display", "is_superuser", "is_staff", "is_active", "email") + users = Users.objects.order_by("username") + users = users.values( + "id", "username", "display", "is_superuser", "is_staff", "is_active", "email" + ) rows = [row for row in users] result = {"status": 0, "msg": "ok", "data": rows} - return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), - content_type='application/json') + return HttpResponse( + json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True), + content_type="application/json", + ) diff --git a/sql/utils/data_masking.py b/sql/utils/data_masking.py index d774105fdb..f4c4e4a91b 100644 --- a/sql/utils/data_masking.py +++ b/sql/utils/data_masking.py @@ -11,7 +11,7 @@ import re import traceback -logger = logging.getLogger('default') +logger = logging.getLogger("default") def data_masking(instance, db_name, sql, sql_result): @@ -21,22 +21,28 @@ def data_masking(instance, db_name, sql, sql_result): # 解析查询语句,判断UNION需要单独处理 p = sqlparse.parse(sql)[0] for token in p.tokens: - if token.ttype is Keyword and token.value.upper() in ['UNION', 'UNION ALL']: - keywords_count['UNION'] = keywords_count.get('UNION', 0) + 1 + if token.ttype is Keyword and token.value.upper() in ["UNION", "UNION ALL"]: + keywords_count["UNION"] = keywords_count.get("UNION", 0) + 1 # 通过goInception获取select list inception_engine = GoInceptionEngine() - select_list = inception_engine.query_data_masking(instance=instance, db_name=db_name, sql=sql) + select_list = inception_engine.query_data_masking( + instance=instance, db_name=db_name, sql=sql + ) # 如果UNION存在,那么调用去重函数 - select_list = del_repeat(select_list, keywords_count) if keywords_count else select_list + select_list = ( + del_repeat(select_list, keywords_count) if keywords_count else select_list + ) # 分析语法树获取命中脱敏规则的列数据 hit_columns = analyze_query_tree(select_list, instance) sql_result.mask_rule_hit = True if hit_columns else False # 对命中规则列hit_columns的数据进行脱敏 - masking_rules = {i.rule_type: model_to_dict(i) for i in DataMaskingRules.objects.all()} + masking_rules = { + i.rule_type: model_to_dict(i) for i in DataMaskingRules.objects.all() + } if hit_columns and sql_result.rows: rows = list(sql_result.rows) for column in hit_columns: - index, rule_type = column['index'], column['rule_type'] + index, rule_type = column["index"], column["rule_type"] masking_rule = masking_rules.get(rule_type) if not masking_rule: continue @@ -47,7 +53,7 @@ def data_masking(instance, db_name, sql, sql_result): # 脱敏结果 sql_result.is_masked = True except Exception as msg: - logger.warning(f'数据脱敏异常,错误信息:{traceback.format_exc()}') + logger.warning(f"数据脱敏异常,错误信息:{traceback.format_exc()}") sql_result.error = str(msg) sql_result.status = 1 return sql_result @@ -65,15 +71,15 @@ def del_repeat(select_list, keywords_count): # 先将query_tree转换成表,方便统计 df = pd.DataFrame(select_list) - #从原来的库、表、字段去重改为字段 - #result_index = df.groupby(['field', 'table', 'schema']).filter(lambda g: len(g) > 1).to_dict('records') - result_index = df.groupby(['field']).filter(lambda g: len(g) > 1).to_dict('records') + # 从原来的库、表、字段去重改为字段 + # result_index = df.groupby(['field', 'table', 'schema']).filter(lambda g: len(g) > 1).to_dict('records') + result_index = df.groupby(["field"]).filter(lambda g: len(g) > 1).to_dict("records") # 再统计重复数量 result_len = len(result_index) - + # 再计算取列表前多少的值=重复数量/(union次数+1) - group_count = int(result_len / (keywords_count['UNION'] + 1)) + group_count = int(result_len / (keywords_count["UNION"] + 1)) result = result_index[:group_count] return result @@ -83,39 +89,49 @@ def analyze_query_tree(select_list, instance): """解析select list, 返回命中脱敏规则的列信息""" # 获取实例全部激活的脱敏字段信息,减少循环查询,提升效率 masking_columns = { - f"{i.instance}-{i.table_schema}-{i.table_name}-{i.column_name}": model_to_dict(i) for i in - DataMaskingColumns.objects.filter(instance=instance, active=True) + f"{i.instance}-{i.table_schema}-{i.table_name}-{i.column_name}": model_to_dict( + i + ) + for i in DataMaskingColumns.objects.filter(instance=instance, active=True) } # 遍历select_list 格式化命中的列信息 hit_columns = [] for column in select_list: - table_schema, table, field = column.get('schema'), column.get('table'), column.get('field') - masking_column = masking_columns.get(f"{instance}-{table_schema}-{table}-{field}") + table_schema, table, field = ( + column.get("schema"), + column.get("table"), + column.get("field"), + ) + masking_column = masking_columns.get( + f"{instance}-{table_schema}-{table}-{field}" + ) if masking_column: - hit_columns.append({ - "instance_name": instance.instance_name, - "table_schema": table_schema, - "table_name": table, - "column_name": field, - "rule_type": masking_column['rule_type'], - "is_hit": True, - "index": column['index'] - }) + hit_columns.append( + { + "instance_name": instance.instance_name, + "table_schema": table_schema, + "table_name": table, + "column_name": field, + "rule_type": masking_column["rule_type"], + "is_hit": True, + "index": column["index"], + } + ) return hit_columns def regex(masking_rule, value): """利用正则表达式脱敏数据""" - rule_regex = masking_rule['rule_regex'] - hide_group = masking_rule['hide_group'] + rule_regex = masking_rule["rule_regex"] + hide_group = masking_rule["hide_group"] # 正则匹配必须分组,隐藏的组会使用****代替 try: p = re.compile(rule_regex, re.I) m = p.search(str(value)) - masking_str = '' + masking_str = "" for i in range(m.lastindex): if i == hide_group - 1: - group = '****' + group = "****" else: group = m.group(i + 1) masking_str = masking_str + group @@ -132,7 +148,11 @@ def brute_mask(instance, sql_result): 返回同样结构的sql_result , error 中写入脱敏时产生的错误. """ # 读取所有关联实例的脱敏规则,去重后应用到结果集,不会按照具体配置的字段匹配 - rule_types = DataMaskingColumns.objects.filter(instance=instance).values_list('rule_type', flat=True).distinct() + rule_types = ( + DataMaskingColumns.objects.filter(instance=instance) + .values_list("rule_type", flat=True) + .distinct() + ) masking_rules = DataMaskingRules.objects.filter(rule_type__in=rule_types) for reg in masking_rules: compiled_r = re.compile(reg.rule_regex, re.I) @@ -147,7 +167,9 @@ def brute_mask(instance, sql_result): temp_value_list = [] for j in range(len(sql_result.rows[i])): # 进行正则替换 - temp_value_list += [compiled_r.sub(replace_pattern, str(sql_result.rows[i][j]))] + temp_value_list += [ + compiled_r.sub(replace_pattern, str(sql_result.rows[i][j])) + ] rows[i] = tuple(temp_value_list) sql_result.rows = rows return sql_result @@ -172,22 +194,30 @@ def simple_column_mask(instance, sql_result): # 脱敏规则字段索引信息 _masking_column_index = [] if column_name in sql_result_column_list: - _masking_column_index.append(sql_result_column_list.index(column_name)) + _masking_column_index.append( + sql_result_column_list.index(column_name) + ) # 别名字段脱敏处理 try: for _c in sql_result_column_list: - alias_column_regex = r'"?([^\s"]+)"?\s+(as\s+)?"?({})[",\s+]?'.format(re.escape(_c)) + alias_column_regex = ( + r'"?([^\s"]+)"?\s+(as\s+)?"?({})[",\s+]?'.format( + re.escape(_c) + ) + ) alias_column_r = re.compile(alias_column_regex, re.I) # 解析原SQL查询别名字段 search_data = re.search(alias_column_r, sql_result.full_sql) # 字段名 _column_name = search_data.group(1).lower() - s_column_name = re.sub(r'^"?\w+"?\."?|\.|"$', '', _column_name) + s_column_name = re.sub(r'^"?\w+"?\."?|\.|"$', "", _column_name) # 别名 alias_name = search_data.group(3).lower() # 如果字段名匹配脱敏配置字段,对此字段进行脱敏处理 if s_column_name == column_name: - _masking_column_index.append(sql_result_column_list.index(alias_name)) + _masking_column_index.append( + sql_result_column_list.index(alias_name) + ) except: pass @@ -209,7 +239,9 @@ def simple_column_mask(instance, sql_result): for j in range(len(sql_result.rows[i])): column_data = sql_result.rows[i][j] if j == masking_column_index: - column_data = compiled_r.sub(replace_pattern, str(sql_result.rows[i][j])) + column_data = compiled_r.sub( + replace_pattern, str(sql_result.rows[i][j]) + ) temp_value_list += [column_data] rows[i] = tuple(temp_value_list) sql_result.rows = rows diff --git a/sql/utils/execute_sql.py b/sql/utils/execute_sql.py index 0ba93436c2..4e2c084a53 100644 --- a/sql/utils/execute_sql.py +++ b/sql/utils/execute_sql.py @@ -12,7 +12,7 @@ from sql.utils.workflow_audit import Audit from sql.engines import get_engine -logger = logging.getLogger('default') +logger = logging.getLogger("default") def execute(workflow_id, user=None): @@ -21,21 +21,25 @@ def execute(workflow_id, user=None): with transaction.atomic(): workflow_detail = SqlWorkflow.objects.select_for_update().get(id=workflow_id) # 只有排队中和定时执行的数据才可以继续执行,否则直接抛错 - if workflow_detail.status not in ['workflow_queuing', 'workflow_timingtask']: - raise Exception('工单状态不正确,禁止执行!') + if workflow_detail.status not in ["workflow_queuing", "workflow_timingtask"]: + raise Exception("工单状态不正确,禁止执行!") # 将工单状态修改为执行中 else: - SqlWorkflow(id=workflow_id, status='workflow_executing').save(update_fields=['status']) + SqlWorkflow(id=workflow_id, status="workflow_executing").save( + update_fields=["status"] + ) # 增加执行日志 - audit_id = Audit.detail_by_workflow_id(workflow_id=workflow_id, - workflow_type=WorkflowDict.workflow_type['sqlreview']).audit_id - Audit.add_log(audit_id=audit_id, - operation_type=5, - operation_type_desc='执行工单', - operation_info='工单开始执行' if user else '系统定时执行工单', - operator=user.username if user else '', - operator_display=user.display if user else '系统' - ) + audit_id = Audit.detail_by_workflow_id( + workflow_id=workflow_id, workflow_type=WorkflowDict.workflow_type["sqlreview"] + ).audit_id + Audit.add_log( + audit_id=audit_id, + operation_type=5, + operation_type_desc="执行工单", + operation_info="工单开始执行" if user else "系统定时执行工单", + operator=user.username if user else "", + operator_display=user.display if user else "系统", + ) execute_engine = get_engine(instance=workflow_detail.instance) return execute_engine.execute_workflow(workflow=workflow_detail) @@ -52,60 +56,68 @@ def execute_callback(task): # 判断工单状态,如果不是执行中的,不允许更新信息,直接抛错记录日志 with transaction.atomic(): workflow = SqlWorkflow.objects.get(id=workflow_id) - if workflow.status != 'workflow_executing': - raise Exception(f'工单{workflow.id}状态不正确,禁止重复更新执行结果!') + if workflow.status != "workflow_executing": + raise Exception(f"工单{workflow.id}状态不正确,禁止重复更新执行结果!") workflow.finish_time = task.stopped if not task.success: # 不成功会返回错误堆栈信息,构造一个错误信息 - workflow.status = 'workflow_exception' + workflow.status = "workflow_exception" execute_result = ReviewSet(full_sql=workflow.sqlworkflowcontent.sql_content) - execute_result.rows = [ReviewResult( - stage='Execute failed', - errlevel=2, - stagestatus='异常终止', - errormessage=task.result, - sql=workflow.sqlworkflowcontent.sql_content)] + execute_result.rows = [ + ReviewResult( + stage="Execute failed", + errlevel=2, + stagestatus="异常终止", + errormessage=task.result, + sql=workflow.sqlworkflowcontent.sql_content, + ) + ] elif task.result.warning or task.result.error: execute_result = task.result - workflow.status = 'workflow_exception' + workflow.status = "workflow_exception" else: execute_result = task.result - workflow.status = 'workflow_finish' + workflow.status = "workflow_finish" try: # 保存执行结果 workflow.sqlworkflowcontent.execute_result = execute_result.json() workflow.sqlworkflowcontent.save() workflow.save() except Exception as e: - logger.error(f'SQL工单回调异常: {workflow_id} {traceback.format_exc()}') + logger.error(f"SQL工单回调异常: {workflow_id} {traceback.format_exc()}") SqlWorkflow.objects.filter(id=workflow_id).update( finish_time=task.stopped, - status='workflow_exception', + status="workflow_exception", ) - workflow.sqlworkflowcontent.execute_result = {f'{e}'} + workflow.sqlworkflowcontent.execute_result = {f"{e}"} workflow.sqlworkflowcontent.save() # 增加工单日志 - audit_id = Audit.detail_by_workflow_id(workflow_id=workflow_id, - workflow_type=WorkflowDict.workflow_type['sqlreview']).audit_id - Audit.add_log(audit_id=audit_id, - operation_type=6, - operation_type_desc='执行结束', - operation_info='执行结果:{}'.format(workflow.get_status_display()), - operator='', - operator_display='系统' - ) + audit_id = Audit.detail_by_workflow_id( + workflow_id=workflow_id, workflow_type=WorkflowDict.workflow_type["sqlreview"] + ).audit_id + Audit.add_log( + audit_id=audit_id, + operation_type=6, + operation_type_desc="执行结束", + operation_info="执行结果:{}".format(workflow.get_status_display()), + operator="", + operator_display="系统", + ) # DDL工单结束后清空实例资源缓存 if workflow.syntax_type == 1: r = get_redis_connection("default") - for key in r.scan_iter(match='*insRes*', count=2000): + for key in r.scan_iter(match="*insRes*", count=2000): r.delete(key) # 开启了Execute阶段通知参数才发送消息通知 sys_config = SysConfig() - is_notified = 'Execute' in sys_config.get('notify_phase_control').split(',') \ - if sys_config.get('notify_phase_control') else True + is_notified = ( + "Execute" in sys_config.get("notify_phase_control").split(",") + if sys_config.get("notify_phase_control") + else True + ) if is_notified: notify_for_execute(workflow) diff --git a/sql/utils/extract_tables.py b/sql/utils/extract_tables.py index 73aed5d522..9ba7972d3f 100644 --- a/sql/utils/extract_tables.py +++ b/sql/utils/extract_tables.py @@ -53,9 +53,12 @@ def is_subselect(parsed): if not parsed.is_group: return False for item in parsed.tokens: - if ( - item.ttype is DML - and item.value.upper() in ("SELECT", "INSERT", "UPDATE", "CREATE", "DELETE") + if item.ttype is DML and item.value.upper() in ( + "SELECT", + "INSERT", + "UPDATE", + "CREATE", + "DELETE", ): return True return False @@ -82,18 +85,23 @@ def extract_from_part(parsed, stop_at_punctuation=True): # Also 'SELECT * FROM abc JOIN def' will trigger this elif # condition. So we need to ignore the keyword JOIN and its variants # INNER JOIN, FULL OUTER JOIN, etc. - elif item.ttype is Keyword and (not item.value.upper() == "FROM") and ( - not item.value.upper().endswith("JOIN") + elif ( + item.ttype is Keyword + and (not item.value.upper() == "FROM") + and (not item.value.upper().endswith("JOIN")) ): tbl_prefix_seen = False else: yield item elif item.ttype is Keyword or item.ttype is Keyword.DML: item_val = item.value.upper() - if ( - item_val in ("COPY", "FROM", "INTO", "UPDATE", "TABLE") - or item_val.endswith("JOIN") - ): + if item_val in ( + "COPY", + "FROM", + "INTO", + "UPDATE", + "TABLE", + ) or item_val.endswith("JOIN"): tbl_prefix_seen = True # 'SELECT a, FROM abc' will detect FROM as part of the column list. # So this check here is necessary. @@ -139,8 +147,8 @@ def parse_identifier(item): try: schema_name = identifier.get_parent_name() real_name = identifier.get_real_name() - is_function = ( - allow_functions and _identifier_is_function(identifier) + is_function = allow_functions and _identifier_is_function( + identifier ) except AttributeError: continue diff --git a/sql/utils/resource_group.py b/sql/utils/resource_group.py index f3349753c5..2ccf46a04c 100644 --- a/sql/utils/resource_group.py +++ b/sql/utils/resource_group.py @@ -12,7 +12,12 @@ def user_groups(user): if user.is_superuser: group_list = [group for group in ResourceGroup.objects.filter(is_deleted=0)] else: - group_list = [group for group in Users.objects.get(id=user.id).resource_group.filter(is_deleted=0)] + group_list = [ + group + for group in Users.objects.get(id=user.id).resource_group.filter( + is_deleted=0 + ) + ] return group_list @@ -26,7 +31,7 @@ def user_instances(user, type=None, db_type=None, tag_codes=None): :return: """ # 拥有所有实例权限的用户 - if user.has_perm('sql.query_all_instances'): + if user.has_perm("sql.query_all_instances"): instances = Instance.objects.all() else: # 先获取用户关联的资源组 @@ -44,7 +49,9 @@ def user_instances(user, type=None, db_type=None, tag_codes=None): # 过滤tag if tag_codes: for tag_code in tag_codes: - instances = instances.filter(instance_tag__tag_code=tag_code, instance_tag__active=True) + instances = instances.filter( + instance_tag__tag_code=tag_code, instance_tag__active=True + ) return instances.distinct() diff --git a/sql/utils/sql_review.py b/sql/utils/sql_review.py index 7cd38eb59a..771747f30d 100644 --- a/sql/utils/sql_review.py +++ b/sql/utils/sql_review.py @@ -18,14 +18,19 @@ def is_auto_review(workflow_id): """ workflow = SqlWorkflow.objects.get(id=workflow_id) - auto_review_tags = SysConfig().get('auto_review_tag', '').split(',') - auto_review_db_type = SysConfig().get('auto_review_db_type', '').split(',') + auto_review_tags = SysConfig().get("auto_review_tag", "").split(",") + auto_review_db_type = SysConfig().get("auto_review_db_type", "").split(",") # TODO 这里也可以放到engine中实现,但是配置项可能会相对复杂 - if workflow.instance.db_type in auto_review_db_type and workflow.instance.instance_tag.filter( - tag_code__in=auto_review_tags).exists(): + if ( + workflow.instance.db_type in auto_review_db_type + and workflow.instance.instance_tag.filter( + tag_code__in=auto_review_tags + ).exists() + ): # 获取正则表达式 auto_review_regex = SysConfig().get( - 'auto_review_regex', '^alter|^create|^drop|^truncate|^rename|^delete') + "auto_review_regex", "^alter|^create|^drop|^truncate|^rename|^delete" + ) p = re.compile(auto_review_regex, re.I) # 判断是否匹配到需要手动审核的语句 @@ -35,14 +40,14 @@ def is_auto_review(workflow_id): for review_row in json.loads(review_content): review_result = ReviewResult(**review_row) # 去除SQL注释 https://github.com/hhyo/Archery/issues/949 - sql = remove_comments(review_result.sql).replace("\n","").replace("\r", "") + sql = remove_comments(review_result.sql).replace("\n", "").replace("\r", "") # 正则匹配 if p.match(sql): auto_review = False break # 影响行数加测, 总语句影响行数超过指定数量则需要人工审核 all_affected_rows += int(review_result.affected_rows) - if all_affected_rows > int(SysConfig().get('auto_review_max_update_rows', 50)): + if all_affected_rows > int(SysConfig().get("auto_review_max_update_rows", 50)): auto_review = False else: auto_review = False @@ -63,14 +68,19 @@ def can_execute(user, workflow_id): with transaction.atomic(): workflow_detail = SqlWorkflow.objects.select_for_update().get(id=workflow_id) # 只有审核通过和定时执行的数据才可以立即执行 - if workflow_detail.status not in ['workflow_review_pass', 'workflow_timingtask']: + if workflow_detail.status not in [ + "workflow_review_pass", + "workflow_timingtask", + ]: return False # 当前登录用户有资源组粒度执行权限,并且为组内用户 group_ids = [group.group_id for group in user_groups(user)] - if workflow_detail.group_id in group_ids and user.has_perm('sql.sql_execute_for_resource_group'): + if workflow_detail.group_id in group_ids and user.has_perm( + "sql.sql_execute_for_resource_group" + ): result = True # 当前登录用户为提交人,并且有执行权限 - if workflow_detail.engineer == user.username and user.has_perm('sql.sql_execute'): + if workflow_detail.engineer == user.username and user.has_perm("sql.sql_execute"): result = True return result @@ -104,13 +114,17 @@ def can_timingtask(user, workflow_id): workflow_detail = SqlWorkflow.objects.get(id=workflow_id) result = False # 只有审核通过和定时执行的数据才可以执行 - if workflow_detail.status in ['workflow_review_pass', 'workflow_timingtask']: + if workflow_detail.status in ["workflow_review_pass", "workflow_timingtask"]: # 当前登录用户有资源组粒度执行权限,并且为组内用户 group_ids = [group.group_id for group in user_groups(user)] - if workflow_detail.group_id in group_ids and user.has_perm('sql.sql_execute_for_resource_group'): + if workflow_detail.group_id in group_ids and user.has_perm( + "sql.sql_execute_for_resource_group" + ): result = True # 当前登录用户为提交人,并且有执行权限 - if workflow_detail.engineer == user.username and user.has_perm('sql.sql_execute'): + if workflow_detail.engineer == user.username and user.has_perm( + "sql.sql_execute" + ): result = True return result @@ -126,11 +140,19 @@ def can_cancel(user, workflow_id): workflow_detail = SqlWorkflow.objects.get(id=workflow_id) result = False # 审核中的工单,审核人和提交人可终止 - if workflow_detail.status == 'workflow_manreviewing': + if workflow_detail.status == "workflow_manreviewing": from sql.utils.workflow_audit import Audit - return any([Audit.can_review(user, workflow_id, 2), user.username == workflow_detail.engineer]) - elif workflow_detail.status in ['workflow_review_pass', 'workflow_timingtask']: - return any([can_execute(user, workflow_id), user.username == workflow_detail.engineer]) + + return any( + [ + Audit.can_review(user, workflow_id, 2), + user.username == workflow_detail.engineer, + ] + ) + elif workflow_detail.status in ["workflow_review_pass", "workflow_timingtask"]: + return any( + [can_execute(user, workflow_id), user.username == workflow_detail.engineer] + ) return result @@ -147,7 +169,9 @@ def can_view(user, workflow_id): if user.is_superuser: result = True # 非管理员,拥有审核权限、资源组粒度执行权限的,可以查看组内所有工单 - elif user.has_perm('sql.sql_review') or user.has_perm('sql.sql_execute_for_resource_group'): + elif user.has_perm("sql.sql_review") or user.has_perm( + "sql.sql_execute_for_resource_group" + ): # 先获取用户所在资源组列表 group_list = user_groups(user) group_ids = [group.group_id for group in group_list] @@ -171,6 +195,9 @@ def can_rollback(user, workflow_id): workflow_detail = SqlWorkflow.objects.get(id=workflow_id) result = False # 执行结束并且开启备份的工单可以查看回滚信息 - if workflow_detail.is_backup and workflow_detail.status in ('workflow_finish', 'workflow_exception'): + if workflow_detail.is_backup and workflow_detail.status in ( + "workflow_finish", + "workflow_exception", + ): return can_view(user, workflow_id) return result diff --git a/sql/utils/sql_utils.py b/sql/utils/sql_utils.py index 55e32cea3b..201f0c20f0 100644 --- a/sql/utils/sql_utils.py +++ b/sql/utils/sql_utils.py @@ -13,10 +13,10 @@ from sql.engines.models import SqlItem from sql.utils.extract_tables import extract_tables as extract_tables_by_sql_parse -__author__ = 'hhyo' +__author__ = "hhyo" -def get_syntax_type(sql, parser=True, db_type='mysql'): +def get_syntax_type(sql, parser=True, db_type="mysql"): """ 返回SQL语句类型,仅判断DDL和DML :param sql: @@ -29,32 +29,32 @@ def get_syntax_type(sql, parser=True, db_type='mysql'): try: statement = sqlparse.parse(sql)[0] syntax_type = statement.token_first(skip_cm=True).ttype.__str__() - if syntax_type == 'Token.Keyword.DDL': - syntax_type = 'DDL' - elif syntax_type == 'Token.Keyword.DML': - syntax_type = 'DML' + if syntax_type == "Token.Keyword.DDL": + syntax_type = "DDL" + elif syntax_type == "Token.Keyword.DML": + syntax_type = "DML" except Exception: syntax_type = None else: - if db_type == 'mysql': + if db_type == "mysql": ddl_re = r"^alter|^create|^drop|^rename|^truncate" dml_re = r"^call|^delete|^do|^handler|^insert|^load\s+data|^load\s+xml|^replace|^select|^update" - elif db_type == 'oracle': + elif db_type == "oracle": ddl_re = r"^alter|^create|^drop|^rename|^truncate" dml_re = r"^delete|^exec|^insert|^select|^update|^with|^merge" else: # TODO 其他数据库的解析正则 return None if re.match(ddl_re, sql, re.I): - syntax_type = 'DDL' + syntax_type = "DDL" elif re.match(dml_re, sql, re.I): - syntax_type = 'DML' + syntax_type = "DML" else: syntax_type = None return syntax_type -def remove_comments(sql, db_type='mysql'): +def remove_comments(sql, db_type="mysql"): """ 去除SQL语句中的注释信息 来源:https://stackoverflow.com/questions/35647841/parse-sql-file-with-comments-into-sqlite-with-python @@ -63,10 +63,8 @@ def remove_comments(sql, db_type='mysql'): :return: """ sql_comments_re = { - 'oracle': - [r'(?:--)[^\n]*\n', r'(?:\W|^)(?:remark|rem)\s+[^\n]*\n'], - 'mysql': - [r'(?:#|--\s)[^\n]*\n'] + "oracle": [r"(?:--)[^\n]*\n", r"(?:\W|^)(?:remark|rem)\s+[^\n]*\n"], + "mysql": [r"(?:#|--\s)[^\n]*\n"], } specific_comment_re = sql_comments_re[db_type] additional_patterns = "|" @@ -94,10 +92,12 @@ def extract_tables(sql): """ tables = list() for i in extract_tables_by_sql_parse(sql): - tables.append({ - "schema": i.schema, - "name": i.name, - }) + tables.append( + { + "schema": i.schema, + "name": i.name, + } + ) return tables @@ -110,7 +110,7 @@ def generate_sql(text): # 尝试XML解析 try: mapper, xml_raw_text = mybatis_mapper2sql.create_mapper(xml_raw_text=text) - statements = mybatis_mapper2sql.get_statement(mapper, result_type='list') + statements = mybatis_mapper2sql.get_statement(mapper, result_type="list") rows = [] # 压缩SQL语句,方便展示 for statement in statements: @@ -131,13 +131,15 @@ def generate_sql(text): def get_base_sqlitem_list(full_sql): - ''' 把参数 full_sql 转变为 SqlItem列表 + """把参数 full_sql 转变为 SqlItem列表 :param full_sql: 完整sql字符串, 每个SQL以分号;间隔, 不包含plsql执行块和plsql对象定义块 :return: SqlItem对象列表 - ''' + """ list = [] for statement in sqlparse.split(full_sql): - statement = sqlparse.format(statement, strip_comments=True, reindent=True, keyword_case='lower') + statement = sqlparse.format( + statement, strip_comments=True, reindent=True, keyword_case="lower" + ) if len(statement) <= 0: continue item = SqlItem(statement=statement) @@ -146,15 +148,15 @@ def get_base_sqlitem_list(full_sql): def get_full_sqlitem_list(full_sql, db_name): - ''' 获取Sql对应的SqlItem列表, 包括PLSQL部分 + """获取Sql对应的SqlItem列表, 包括PLSQL部分 PLSQL语句块由delimiter $$作为开始间隔符,以$$作为结束间隔符 :param full_sql: 全部sql内容 :return: SqlItem 列表 - ''' + """ list = [] # 定义开始分隔符,两端用括号,是为了re.split()返回列表包含分隔符 - regex_delimiter = r'(delimiter\s*\$\$)' + regex_delimiter = r"(delimiter\s*\$\$)" # 注意:必须把package body置于package之前,否则将永远匹配不上package body regex_objdefine = r'create\s+or\s+replace\s+(function|procedure|trigger|package\s+body|package|view)\s+("?\w+"?\.)?"?\w+"?[\s+|\(]' # 对象命名,两端有双引号 @@ -192,7 +194,7 @@ def get_full_sqlitem_list(full_sql, db_name): plsql_block = sql[0:pos].strip() # 如果plsql_area字符串最后一个字符为/,则把/给去掉 while True: - if plsql_block[-1:] == '/': + if plsql_block[-1:] == "/": plsql_block = plsql_block[:-1].strip() else: break @@ -211,11 +213,11 @@ def get_full_sqlitem_list(full_sql, db_name): str_plsql_type = search_result.groups()[0] idx = str_plsql_match.index(str_plsql_type) - nm_str = str_plsql_match[idx + len(str_plsql_type):].strip() + nm_str = str_plsql_match[idx + len(str_plsql_type) :].strip() - if nm_str[-1:] == '(': + if nm_str[-1:] == "(": nm_str = nm_str[:-1] - nm_list = nm_str.split('.') + nm_list = nm_str.split(".") if len(nm_list) > 1: # 带有属主的对象名, 形如object_owner.object_name @@ -246,28 +248,32 @@ def get_full_sqlitem_list(full_sql, db_name): object_name = nm_list[0].upper().strip() tmp_object_type = str_plsql_type.upper() - tmp_stmt_type = 'PLSQL' - if tmp_object_type == 'VIEW': - tmp_stmt_type = 'SQL' - - item = SqlItem(statement=plsql_block, - stmt_type=tmp_stmt_type, - object_owner=object_owner, - object_type=tmp_object_type, - object_name=object_name) + tmp_stmt_type = "PLSQL" + if tmp_object_type == "VIEW": + tmp_stmt_type = "SQL" + + item = SqlItem( + statement=plsql_block, + stmt_type=tmp_stmt_type, + object_owner=object_owner, + object_type=tmp_object_type, + object_name=object_name, + ) list.append(item) else: # 未检索到关键字, 属于情况2, 匿名可执行块 it's ANONYMOUS - item = SqlItem(statement=plsql_block.strip(), - stmt_type='PLSQL', - object_owner=db_name, - object_type='ANONYMOUS', - object_name='ANONYMOUS') + item = SqlItem( + statement=plsql_block.strip(), + stmt_type="PLSQL", + object_owner=db_name, + object_type="ANONYMOUS", + object_name="ANONYMOUS", + ) list.append(item) if length > pos + 2: # 处理$$之后的那些语句, 默认为单条可执行SQL的集合 - sql_area = sql[pos + 2:].strip() + sql_area = sql[pos + 2 :].strip() if len(sql_area) > 0: tmp_list = get_base_sqlitem_list(sql_area) list.extend(tmp_list) @@ -287,18 +293,22 @@ def get_full_sqlitem_list(full_sql, db_name): def get_exec_sqlitem_list(reviewResult, db_name): - """ 根据审核结果生成新的SQL列表 + """根据审核结果生成新的SQL列表 :param reviewResult: SQL审核结果列表 :param db_name: :return: """ list = [] - list.append(SqlItem(statement=f" ALTER SESSION SET CURRENT_SCHEMA = \"{db_name}\" ")) + list.append(SqlItem(statement=f' ALTER SESSION SET CURRENT_SCHEMA = "{db_name}" ')) for item in reviewResult: - list.append(SqlItem(statement=item['sql'], - stmt_type=item['stmt_type'], - object_owner=item['object_owner'], - object_type=item['object_type'], - object_name=item['object_name'])) + list.append( + SqlItem( + statement=item["sql"], + stmt_type=item["stmt_type"], + object_owner=item["object_owner"], + object_type=item["object_type"], + object_name=item["object_name"], + ) + ) return list diff --git a/sql/utils/ssh_tunnel.py b/sql/utils/ssh_tunnel.py index e40c9861e7..2946eb5f1f 100644 --- a/sql/utils/ssh_tunnel.py +++ b/sql/utils/ssh_tunnel.py @@ -14,7 +14,18 @@ class SSHConnection(object): """ ssh隧道连接类,用于映射ssh隧道端口到本地,连接结束时需要清理 """ - def __init__(self, host, port, tun_host, tun_port, tun_user, tun_password, pkey, pkey_password): + + def __init__( + self, + host, + port, + tun_host, + tun_port, + tun_user, + tun_password, + pkey, + pkey_password, + ): self.host = host self.port = int(port) self.tun_host = tun_host @@ -26,7 +37,9 @@ def __init__(self, host, port, tun_host, tun_port, tun_user, tun_password, pkey, private_key_file_obj = io.StringIO() private_key_file_obj.write(pkey) private_key_file_obj.seek(0) - self.private_key = RSAKey.from_private_key(private_key_file_obj, password=pkey_password) + self.private_key = RSAKey.from_private_key( + private_key_file_obj, password=pkey_password + ) self.server = SSHTunnelForwarder( ssh_address_or_host=(self.tun_host, self.tun_port), ssh_username=self.tun_user, diff --git a/sql/utils/tasks.py b/sql/utils/tasks.py index 97de47f108..793feced50 100644 --- a/sql/utils/tasks.py +++ b/sql/utils/tasks.py @@ -4,30 +4,50 @@ import logging -logger = logging.getLogger('default') +logger = logging.getLogger("default") def add_sql_schedule(name, run_date, workflow_id): """添加/修改sql定时任务""" del_schedule(name) - schedule('sql.utils.execute_sql.execute', workflow_id, - hook='sql.utils.execute_sql.execute_callback', - name=name, schedule_type='O', next_run=run_date, repeats=1, timeout=-1) + schedule( + "sql.utils.execute_sql.execute", + workflow_id, + hook="sql.utils.execute_sql.execute_callback", + name=name, + schedule_type="O", + next_run=run_date, + repeats=1, + timeout=-1, + ) logger.debug(f"添加SQL定时执行任务:{name} 执行时间:{run_date}") def add_kill_conn_schedule(name, run_date, instance_id, thread_id): """添加/修改终止数据库连接的定时任务""" del_schedule(name) - schedule('sql.query.kill_query_conn', instance_id, thread_id, - name=name, schedule_type='O', next_run=run_date, repeats=1, timeout=-1) + schedule( + "sql.query.kill_query_conn", + instance_id, + thread_id, + name=name, + schedule_type="O", + next_run=run_date, + repeats=1, + timeout=-1, + ) def add_sync_ding_user_schedule(): """添加钉钉同步用户定时任务""" - del_schedule(name='同步钉钉用户ID') - schedule('common.utils.ding_api.sync_ding_user_id', - name='同步钉钉用户ID', schedule_type='D', repeats=-1, timeout=-1) + del_schedule(name="同步钉钉用户ID") + schedule( + "common.utils.ding_api.sync_ding_user_id", + name="同步钉钉用户ID", + schedule_type="D", + repeats=-1, + timeout=-1, + ) def del_schedule(name): @@ -35,7 +55,7 @@ def del_schedule(name): try: sql_schedule = Schedule.objects.get(name=name) Schedule.delete(sql_schedule) - logger.debug(f'删除schedule:{name}') + logger.debug(f"删除schedule:{name}") except Schedule.DoesNotExist: pass diff --git a/sql/utils/tests.py b/sql/utils/tests.py index 185224a41b..404fb4c356 100644 --- a/sql/utils/tests.py +++ b/sql/utils/tests.py @@ -23,11 +23,30 @@ from common.config import SysConfig from common.utils.const import WorkflowDict from sql.engines.models import ReviewResult, ReviewSet -from sql.models import Users, SqlWorkflow, SqlWorkflowContent, Instance, ResourceGroup, \ - WorkflowLog, WorkflowAudit, WorkflowAuditDetail, WorkflowAuditSetting, \ - QueryPrivilegesApply, DataMaskingRules, DataMaskingColumns, InstanceTag, ArchiveConfig +from sql.models import ( + Users, + SqlWorkflow, + SqlWorkflowContent, + Instance, + ResourceGroup, + WorkflowLog, + WorkflowAudit, + WorkflowAuditDetail, + WorkflowAuditSetting, + QueryPrivilegesApply, + DataMaskingRules, + DataMaskingColumns, + InstanceTag, + ArchiveConfig, +) from sql.utils.resource_group import user_groups, user_instances, auth_group_users -from sql.utils.sql_review import is_auto_review, can_execute, can_timingtask, can_cancel, on_correct_time_period +from sql.utils.sql_review import ( + is_auto_review, + can_execute, + can_timingtask, + can_cancel, + on_correct_time_period, +) from sql.utils.sql_utils import * from sql.utils.execute_sql import execute, execute_callback from sql.utils.tasks import add_sql_schedule, del_schedule, task_info @@ -35,7 +54,7 @@ from sql.utils.data_masking import data_masking, brute_mask, simple_column_mask User = Users -__author__ = 'hhyo' +__author__ = "hhyo" class TestSQLUtils(TestCase): @@ -46,8 +65,8 @@ def test_get_syntax_type(self): """ dml_sql = "select * from users;" ddl_sql = "alter table users add id not null default 0 comment 'id' " - self.assertEqual(get_syntax_type(dml_sql), 'DML') - self.assertEqual(get_syntax_type(ddl_sql), 'DDL') + self.assertEqual(get_syntax_type(dml_sql), "DML") + self.assertEqual(get_syntax_type(ddl_sql), "DDL") def test_get_syntax_type_by_re(self): """ @@ -56,10 +75,10 @@ def test_get_syntax_type_by_re(self): """ dml_sql = "select * from users;" ddl_sql = "alter table users add id int not null default 0 comment 'id' " - other_sql = 'show engine innodb status' - self.assertEqual(get_syntax_type(dml_sql, parser=False, db_type='mysql'), 'DML') - self.assertEqual(get_syntax_type(ddl_sql, parser=False, db_type='mysql'), 'DDL') - self.assertIsNone(get_syntax_type(other_sql, parser=False, db_type='mysql')) + other_sql = "show engine innodb status" + self.assertEqual(get_syntax_type(dml_sql, parser=False, db_type="mysql"), "DML") + self.assertEqual(get_syntax_type(ddl_sql, parser=False, db_type="mysql"), "DDL") + self.assertIsNone(get_syntax_type(other_sql, parser=False, db_type="mysql")) def test_remove_comments(self): """ @@ -72,12 +91,15 @@ def test_remove_comments(self): SELECT 1+1; -- This comment continues to the end of line""" sql3 = """/* this is an in-line comment */ SELECT 1 /* this is an in-line comment */ + 1;/* this is an in-line comment */""" - self.assertEqual(remove_comments(sql1, db_type='mysql'), - 'SELECT 1+1; # This comment continues to the end of line') - self.assertEqual(remove_comments(sql2, db_type='mysql'), - 'SELECT 1+1; -- This comment continues to the end of line') - self.assertEqual(remove_comments(sql3, db_type='mysql'), - 'SELECT 1 + 1;') + self.assertEqual( + remove_comments(sql1, db_type="mysql"), + "SELECT 1+1; # This comment continues to the end of line", + ) + self.assertEqual( + remove_comments(sql2, db_type="mysql"), + "SELECT 1+1; -- This comment continues to the end of line", + ) + self.assertEqual(remove_comments(sql3, db_type="mysql"), "SELECT 1 + 1;") def test_extract_tables_by_sql_parse(self): """ @@ -85,7 +107,10 @@ def test_extract_tables_by_sql_parse(self): :return: """ sql = "select * from user.users a join logs.log b on a.id=b.id;" - self.assertEqual(extract_tables(sql), [{'name': 'users', 'schema': 'user'}, {'name': 'log', 'schema': 'logs'}]) + self.assertEqual( + extract_tables(sql), + [{"name": "users", "schema": "user"}, {"name": "log", "schema": "logs"}], + ) def test_generate_sql_from_sql(self): """ @@ -94,9 +119,13 @@ def test_generate_sql_from_sql(self): """ text = "select * from sql_user;select * from sql_workflow;" rows = generate_sql(text) - self.assertListEqual(rows, [{'sql_id': 1, 'sql': 'select * from sql_user;'}, - {'sql_id': 2, 'sql': 'select * from sql_workflow;'}] - ) + self.assertListEqual( + rows, + [ + {"sql_id": 1, "sql": "select * from sql_user;"}, + {"sql_id": 2, "sql": "select * from sql_workflow;"}, + ], + ) def test_generate_sql_from_xml(self): """ @@ -120,9 +149,15 @@ def test_generate_sql_from_xml(self): """ rows = generate_sql(text) - self.assertEqual(rows, [{'sql_id': 'testParameters', - 'sql': '\nSELECT name,\n category,\n price\nFROM fruits\nWHERE category = ?\n AND price > ?'}] - ) + self.assertEqual( + rows, + [ + { + "sql_id": "testParameters", + "sql": "\nSELECT name,\n category,\n price\nFROM fruits\nWHERE category = ?\n AND price > ?", + } + ], + ) class TestSQLReview(TestCase): @@ -131,36 +166,38 @@ class TestSQLReview(TestCase): """ def setUp(self): - self.superuser = User.objects.create(username='super', is_superuser=True) - self.user = User.objects.create(username='user') + self.superuser = User.objects.create(username="super", is_superuser=True) + self.user = User.objects.create(username="user") # 使用 travis.ci 时实例和测试service保持一致 - self.master = Instance(instance_name='test_instance', type='master', db_type='mysql', - host=settings.DATABASES['default']['HOST'], - port=settings.DATABASES['default']['PORT'], - user=settings.DATABASES['default']['USER'], - password=settings.DATABASES['default']['PASSWORD']) + self.master = Instance( + instance_name="test_instance", + type="master", + db_type="mysql", + host=settings.DATABASES["default"]["HOST"], + port=settings.DATABASES["default"]["PORT"], + user=settings.DATABASES["default"]["USER"], + password=settings.DATABASES["default"]["PASSWORD"], + ) self.master.save() self.sys_config = SysConfig() self.client = Client() - self.group = ResourceGroup.objects.create(group_id=1, group_name='group_name') + self.group = ResourceGroup.objects.create(group_id=1, group_name="group_name") self.wf1 = SqlWorkflow.objects.create( - workflow_name='workflow_name', + workflow_name="workflow_name", group_id=self.group.group_id, group_name=self.group.group_name, engineer=self.superuser.username, engineer_display=self.superuser.display, - audit_auth_groups='audit_auth_groups', + audit_auth_groups="audit_auth_groups", create_time=datetime.datetime.now(), - status='workflow_review_pass', + status="workflow_review_pass", is_backup=True, instance=self.master, - db_name='db_name', + db_name="db_name", syntax_type=1, ) self.wfc1 = SqlWorkflowContent.objects.create( - workflow=self.wf1, - sql_content='some_sql', - execute_result='' + workflow=self.wf1, sql_content="some_sql", execute_result="" ) def tearDown(self): @@ -171,275 +208,370 @@ def tearDown(self): self.master.delete() self.sys_config.replace(json.dumps({})) - @patch('sql.engines.get_engine') - def test_auto_review_hit_review_regex(self, _get_engine, ): + @patch("sql.engines.get_engine") + def test_auto_review_hit_review_regex( + self, + _get_engine, + ): """ 测试自动审批通过的判定条件,命中判断正则 :return: """ # 开启自动审批设置 - self.sys_config.set('auto_review', 'true') - self.sys_config.set('auto_review_db_type', 'mysql') - self.sys_config.set('auto_review_regex', '^drop') # drop语句需要审批 - self.sys_config.set('auto_review_max_update_rows', '50') # update影响行数大于50需要审批 + self.sys_config.set("auto_review", "true") + self.sys_config.set("auto_review_db_type", "mysql") + self.sys_config.set("auto_review_regex", "^drop") # drop语句需要审批 + self.sys_config.set("auto_review_max_update_rows", "50") # update影响行数大于50需要审批 self.sys_config.get_all_config() # 修改工单为drop self.wfc1.sql_content = "drop table users;" - self.wfc1.save(update_fields=('sql_content',)) + self.wfc1.save(update_fields=("sql_content",)) r = is_auto_review(self.wfc1.workflow_id) self.assertFalse(r) - @patch('sql.engines.mysql.MysqlEngine.execute_check') - @patch('sql.engines.get_engine') + @patch("sql.engines.mysql.MysqlEngine.execute_check") + @patch("sql.engines.get_engine") def test_auto_review_gt_max_update_rows(self, _get_engine, _execute_check): """ 测试自动审批通过的判定条件,影响行数大于auto_review_max_update_rows :return: """ # 开启自动审批设置 - self.sys_config.set('auto_review', 'true') - self.sys_config.set('auto_review_db_type', 'mysql') - self.sys_config.set('auto_review_regex', '^drop') # drop语句需要审批 - self.sys_config.set('auto_review_max_update_rows', '2') # update影响行数大于2需要审批 + self.sys_config.set("auto_review", "true") + self.sys_config.set("auto_review_db_type", "mysql") + self.sys_config.set("auto_review_regex", "^drop") # drop语句需要审批 + self.sys_config.set("auto_review_max_update_rows", "2") # update影响行数大于2需要审批 self.sys_config.get_all_config() # 修改工单为update self.wfc1.sql_content = "update table users set email='';" - self.wfc1.save(update_fields=('sql_content',)) + self.wfc1.save(update_fields=("sql_content",)) # mock返回值,update影响行数=3 _execute_check.return_value.to_dict.return_value = [ - {"id": 1, "stage": "CHECKED", "errlevel": 0, "stagestatus": "Audit completed", "errormessage": "None", - "sql": "use archer_test", "affected_rows": 0, "sequence": "'0_0_0'", "backup_dbname": "None", - "execute_time": "0", "sqlsha1": "", "actual_affected_rows": 'null'}, - {"id": 2, "stage": "CHECKED", "errlevel": 0, "stagestatus": "Audit completed", "errormessage": "None", - "sql": "update table users set email=''", "affected_rows": 3, "sequence": "'0_0_1'", - "backup_dbname": "mysql_3306_archer_test", "execute_time": "0", "sqlsha1": "", - "actual_affected_rows": 'null'}] + { + "id": 1, + "stage": "CHECKED", + "errlevel": 0, + "stagestatus": "Audit completed", + "errormessage": "None", + "sql": "use archer_test", + "affected_rows": 0, + "sequence": "'0_0_0'", + "backup_dbname": "None", + "execute_time": "0", + "sqlsha1": "", + "actual_affected_rows": "null", + }, + { + "id": 2, + "stage": "CHECKED", + "errlevel": 0, + "stagestatus": "Audit completed", + "errormessage": "None", + "sql": "update table users set email=''", + "affected_rows": 3, + "sequence": "'0_0_1'", + "backup_dbname": "mysql_3306_archer_test", + "execute_time": "0", + "sqlsha1": "", + "actual_affected_rows": "null", + }, + ] r = is_auto_review(self.wfc1.workflow_id) self.assertFalse(r) - @patch('sql.engines.get_engine') + @patch("sql.engines.get_engine") def test_auto_review_true(self, _get_engine): """ 测试自动审批通过的判定条件, :return: """ # 开启自动审批设置 - self.sys_config.set('auto_review', 'true') - self.sys_config.set('auto_review_db_type', 'mysql') - self.sys_config.set('auto_review_regex', '^drop') # drop语句需要审批 - self.sys_config.set('auto_review_max_update_rows', '2') # update影响行数大于2需要审批 - self.sys_config.set('auto_review_tag', 'GA') # 仅GA开启自动审批 + self.sys_config.set("auto_review", "true") + self.sys_config.set("auto_review_db_type", "mysql") + self.sys_config.set("auto_review_regex", "^drop") # drop语句需要审批 + self.sys_config.set("auto_review_max_update_rows", "2") # update影响行数大于2需要审批 + self.sys_config.set("auto_review_tag", "GA") # 仅GA开启自动审批 self.sys_config.get_all_config() # 修改工单为update,mock返回值,update影响行数=3 self.wfc1.sql_content = "update table users set email='';" - self.wfc1.review_content = json.dumps([ - {"id": 1, "stage": "CHECKED", "errlevel": 0, "stagestatus": "Audit completed", "errormessage": "None", - "sql": "use archer_test", "affected_rows": 0, "sequence": "'0_0_0'", "backup_dbname": "None", - "execute_time": "0", "sqlsha1": "", "actual_affected_rows": 'null'}, - {"id": 2, "stage": "CHECKED", "errlevel": 0, "stagestatus": "Audit completed", "errormessage": "None", - "sql": "update table users set email=''", "affected_rows": 1, "sequence": "'0_0_1'", - "backup_dbname": "mysql_3306_archer_test", "execute_time": "0", "sqlsha1": "", - "actual_affected_rows": 'null'}]) - self.wfc1.save(update_fields=('sql_content', 'review_content')) + self.wfc1.review_content = json.dumps( + [ + { + "id": 1, + "stage": "CHECKED", + "errlevel": 0, + "stagestatus": "Audit completed", + "errormessage": "None", + "sql": "use archer_test", + "affected_rows": 0, + "sequence": "'0_0_0'", + "backup_dbname": "None", + "execute_time": "0", + "sqlsha1": "", + "actual_affected_rows": "null", + }, + { + "id": 2, + "stage": "CHECKED", + "errlevel": 0, + "stagestatus": "Audit completed", + "errormessage": "None", + "sql": "update table users set email=''", + "affected_rows": 1, + "sequence": "'0_0_1'", + "backup_dbname": "mysql_3306_archer_test", + "execute_time": "0", + "sqlsha1": "", + "actual_affected_rows": "null", + }, + ] + ) + self.wfc1.save(update_fields=("sql_content", "review_content")) # 修改工单实例标签 - tag, is_created = InstanceTag.objects.get_or_create(tag_code='GA', - defaults={'tag_name': '生产环境', 'active': True}) + tag, is_created = InstanceTag.objects.get_or_create( + tag_code="GA", defaults={"tag_name": "生产环境", "active": True} + ) self.wf1.instance.instance_tag.add(tag) r = is_auto_review(self.wfc1.workflow_id) self.assertTrue(r) - @patch('sql.engines.get_engine') + @patch("sql.engines.get_engine") def test_auto_review_false(self, _get_engine): """ 测试自动审批通过的判定条件, :return: """ # 开启自动审批设置 - self.sys_config.set('auto_review', 'true') - self.sys_config.set('auto_review_db_type', '') # 未配置auto_review_db_type需要审批 - self.sys_config.set('auto_review_regex', '^drop') # drop语句需要审批 - self.sys_config.set('auto_review_max_update_rows', '2') # update影响行数大于2需要审批 - self.sys_config.set('auto_review_tag', 'GA') # 仅GA开启自动审批 + self.sys_config.set("auto_review", "true") + self.sys_config.set("auto_review_db_type", "") # 未配置auto_review_db_type需要审批 + self.sys_config.set("auto_review_regex", "^drop") # drop语句需要审批 + self.sys_config.set("auto_review_max_update_rows", "2") # update影响行数大于2需要审批 + self.sys_config.set("auto_review_tag", "GA") # 仅GA开启自动审批 self.sys_config.get_all_config() # 修改工单为update,mock返回值,update影响行数=3 self.wfc1.sql_content = "update table users set email='';" - self.wfc1.review_content = json.dumps([ - {"id": 1, "stage": "CHECKED", "errlevel": 0, "stagestatus": "Audit completed", "errormessage": "None", - "sql": "use archer_test", "affected_rows": 0, "sequence": "'0_0_0'", "backup_dbname": "None", - "execute_time": "0", "sqlsha1": "", "actual_affected_rows": 'null'}, - {"id": 2, "stage": "CHECKED", "errlevel": 0, "stagestatus": "Audit completed", "errormessage": "None", - "sql": "update table users set email=''", "affected_rows": 1, "sequence": "'0_0_1'", - "backup_dbname": "mysql_3306_archer_test", "execute_time": "0", "sqlsha1": "", - "actual_affected_rows": 'null'}]) - self.wfc1.save(update_fields=('sql_content', 'review_content')) + self.wfc1.review_content = json.dumps( + [ + { + "id": 1, + "stage": "CHECKED", + "errlevel": 0, + "stagestatus": "Audit completed", + "errormessage": "None", + "sql": "use archer_test", + "affected_rows": 0, + "sequence": "'0_0_0'", + "backup_dbname": "None", + "execute_time": "0", + "sqlsha1": "", + "actual_affected_rows": "null", + }, + { + "id": 2, + "stage": "CHECKED", + "errlevel": 0, + "stagestatus": "Audit completed", + "errormessage": "None", + "sql": "update table users set email=''", + "affected_rows": 1, + "sequence": "'0_0_1'", + "backup_dbname": "mysql_3306_archer_test", + "execute_time": "0", + "sqlsha1": "", + "actual_affected_rows": "null", + }, + ] + ) + self.wfc1.save(update_fields=("sql_content", "review_content")) # 修改工单实例标签 - tag, is_created = InstanceTag.objects.get_or_create(tag_code='GA', - defaults={'tag_name': '生产环境', 'active': True}) + tag, is_created = InstanceTag.objects.get_or_create( + tag_code="GA", defaults={"tag_name": "生产环境", "active": True} + ) self.wf1.instance.instance_tag.add(tag) r = is_auto_review(self.wfc1.workflow_id) self.assertFalse(r) - def test_can_execute_for_resource_group(self, ): + def test_can_execute_for_resource_group( + self, + ): """ 测试是否能执行的判定条件,登录用户有资源组粒度执行权限,并且为组内用户 :return: """ # 修改工单为workflow_review_pass,登录用户有资源组粒度执行权限,并且为组内用户 - self.wf1.status = 'workflow_review_pass' - self.wf1.save(update_fields=('status',)) - sql_execute_for_resource_group = Permission.objects.get(codename='sql_execute_for_resource_group') + self.wf1.status = "workflow_review_pass" + self.wf1.save(update_fields=("status",)) + sql_execute_for_resource_group = Permission.objects.get( + codename="sql_execute_for_resource_group" + ) self.user.user_permissions.add(sql_execute_for_resource_group) self.user.resource_group.add(self.group) r = can_execute(user=self.user, workflow_id=self.wfc1.workflow_id) self.assertTrue(r) - def test_can_execute_true(self, ): + def test_can_execute_true( + self, + ): """ 测试是否能执行的判定条件,当前登录用户为提交人,并且有执行权限,工单状态为审核通过 :return: """ # 修改工单为workflow_review_pass,当前登录用户为提交人,并且有执行权限 - self.wf1.status = 'workflow_review_pass' + self.wf1.status = "workflow_review_pass" self.wf1.engineer = self.user.username - self.wf1.save(update_fields=('status', 'engineer')) - sql_execute = Permission.objects.get(codename='sql_execute') + self.wf1.save(update_fields=("status", "engineer")) + sql_execute = Permission.objects.get(codename="sql_execute") self.user.user_permissions.add(sql_execute) r = can_execute(user=self.user, workflow_id=self.wfc1.workflow_id) self.assertTrue(r) - def test_can_execute_workflow_timing_task(self, ): + def test_can_execute_workflow_timing_task( + self, + ): """ 测试是否能执行的判定条件,当前登录用户为提交人,并且有执行权限,工单状态为定时执行 :return: """ # 修改工单为workflow_review_pass,当前登录用户为提交人,并且有执行权限 - self.wf1.status = 'workflow_timingtask' + self.wf1.status = "workflow_timingtask" self.wf1.engineer = self.user.username - self.wf1.save(update_fields=('status', 'engineer')) - sql_execute = Permission.objects.get(codename='sql_execute') + self.wf1.save(update_fields=("status", "engineer")) + sql_execute = Permission.objects.get(codename="sql_execute") self.user.user_permissions.add(sql_execute) r = can_execute(user=self.user, workflow_id=self.wfc1.workflow_id) self.assertTrue(r) - def test_can_execute_false_no_permission(self, ): + def test_can_execute_false_no_permission( + self, + ): """ 当前登录用户为提交人,但是没有执行权限 :return: """ # 修改工单为workflow_review_pass,当前登录用户为提交人,并且有执行权限 - self.wf1.status = 'workflow_timingtask' + self.wf1.status = "workflow_timingtask" self.wf1.engineer = self.user.username - self.wf1.save(update_fields=('status', 'engineer')) + self.wf1.save(update_fields=("status", "engineer")) r = can_execute(user=self.user, workflow_id=self.wfc1.workflow_id) self.assertFalse(r) - def test_can_execute_false_not_in_group(self, ): + def test_can_execute_false_not_in_group( + self, + ): """ 当前登录用户为提交人,有资源组粒度执行权限,但是不是组内用户 :return: """ # 修改工单为workflow_review_pass,有资源组粒度执行权限,但是不是组内用户 - self.wf1.status = 'workflow_review_pass' - self.wf1.save(update_fields=('status',)) - sql_execute_for_resource_group = Permission.objects.get(codename='sql_execute_for_resource_group') + self.wf1.status = "workflow_review_pass" + self.wf1.save(update_fields=("status",)) + sql_execute_for_resource_group = Permission.objects.get( + codename="sql_execute_for_resource_group" + ) self.user.user_permissions.add(sql_execute_for_resource_group) r = can_execute(user=self.user, workflow_id=self.wfc1.workflow_id) self.assertFalse(r) - def test_can_execute_false_wrong_status(self, ): + def test_can_execute_false_wrong_status( + self, + ): """ 当前登录用户为提交人,前登录用户为提交人,并且有执行权限,但是工单状态为待审核 :return: """ # 修改工单为workflow_manreviewing,当前登录用户为提交人,并且有执行权限, 但是工单状态为待审核 - self.wf1.status = 'workflow_manreviewing' + self.wf1.status = "workflow_manreviewing" self.wf1.engineer = self.user.username - self.wf1.save(update_fields=('status', 'engineer')) - sql_execute = Permission.objects.get(codename='sql_execute') + self.wf1.save(update_fields=("status", "engineer")) + sql_execute = Permission.objects.get(codename="sql_execute") self.user.user_permissions.add(sql_execute) r = can_execute(user=self.user, workflow_id=self.wfc1.workflow_id) self.assertFalse(r) - def test_can_timingtask_true(self, ): + def test_can_timingtask_true( + self, + ): """ 测试是否能定时执行的判定条件,当前登录用户为提交人,并且有执行权限,工单状态为审核通过 :return: """ # 修改工单为workflow_review_pass,当前登录用户为提交人,并且有执行权限 - self.wf1.status = 'workflow_review_pass' + self.wf1.status = "workflow_review_pass" self.wf1.engineer = self.user.username - self.wf1.save(update_fields=('status', 'engineer')) - sql_execute = Permission.objects.get(codename='sql_execute') + self.wf1.save(update_fields=("status", "engineer")) + sql_execute = Permission.objects.get(codename="sql_execute") self.user.user_permissions.add(sql_execute) r = can_timingtask(user=self.user, workflow_id=self.wfc1.workflow_id) self.assertTrue(r) - def test_can_timingtask_false(self, ): + def test_can_timingtask_false( + self, + ): """ 测试是否能定时执行的判定条件,当前登录有执行权限,工单状态为审核通过,但用户不是提交人 :return: """ # 修改工单为workflow_review_pass,当前登录用户为提交人,并且有执行权限 - self.wf1.status = 'workflow_review_pass' + self.wf1.status = "workflow_review_pass" self.wf1.engineer = self.superuser.username - self.wf1.save(update_fields=('status', 'engineer')) - sql_execute = Permission.objects.get(codename='sql_execute') + self.wf1.save(update_fields=("status", "engineer")) + sql_execute = Permission.objects.get(codename="sql_execute") self.user.user_permissions.add(sql_execute) r = can_timingtask(user=self.user, workflow_id=self.wfc1.workflow_id) self.assertFalse(r) - @patch('sql.utils.workflow_audit.Audit.can_review') + @patch("sql.utils.workflow_audit.Audit.can_review") def test_can_cancel_true_for_apply_user(self, _can_review): """ 测试是否能取消,审核中的工单,提交人可终止 :return: """ # 修改工单为workflow_review_pass,当前登录用户为提交人 - self.wf1.status = 'workflow_manreviewing' + self.wf1.status = "workflow_manreviewing" self.wf1.engineer = self.user.username - self.wf1.save(update_fields=('status', 'engineer')) + self.wf1.save(update_fields=("status", "engineer")) _can_review.return_value = False r = can_cancel(user=self.user, workflow_id=self.wfc1.workflow_id) self.assertTrue(r) - @patch('sql.utils.workflow_audit.Audit.can_review') + @patch("sql.utils.workflow_audit.Audit.can_review") def test_can_cancel_true_for_audit_user(self, _can_review): """ 测试是否能取消,审核中的工单,审核人可终止 :return: """ # 修改工单为workflow_review_pass,当前登录用户为提交人 - self.wf1.status = 'workflow_manreviewing' + self.wf1.status = "workflow_manreviewing" self.wf1.engineer = self.superuser.username - self.wf1.save(update_fields=('status', 'engineer')) + self.wf1.save(update_fields=("status", "engineer")) _can_review.return_value = True r = can_cancel(user=self.user, workflow_id=self.wfc1.workflow_id) self.assertTrue(r) - @patch('sql.utils.sql_review.can_execute') + @patch("sql.utils.sql_review.can_execute") def test_can_cancel_true_for_execute_user(self, _can_execute): """ 测试是否能取消,审核通过但未执行的工单,有执行权限的用户终止 :return: """ # 修改工单为workflow_review_pass,当前登录用户为提交人 - self.wf1.status = 'workflow_review_pass' + self.wf1.status = "workflow_review_pass" self.wf1.engineer = self.user.username - self.wf1.save(update_fields=('status', 'engineer')) + self.wf1.save(update_fields=("status", "engineer")) _can_execute.return_value = True r = can_cancel(user=self.user, workflow_id=self.wfc1.workflow_id) self.assertTrue(r) - @patch('sql.utils.sql_review.can_execute') + @patch("sql.utils.sql_review.can_execute") def test_can_cancel_true_for_submit_user(self, _can_execute): """ 测试是否能取消,审核通过但未执行的工单,提交人可终止 :return: """ # 修改工单为workflow_review_pass,当前登录用户为提交人 - self.wf1.status = 'workflow_review_pass' + self.wf1.status = "workflow_review_pass" self.wf1.engineer = self.user.username - self.wf1.save(update_fields=('status', 'engineer')) + self.wf1.save(update_fields=("status", "engineer")) _can_execute.return_value = True r = can_cancel(user=self.user, workflow_id=self.wfc1.workflow_id) self.assertTrue(r) @@ -450,10 +582,12 @@ def test_on_correct_time_period(self): :return: """ # 修改工单为workflow_review_pass,当前登录用户为提交人 - self.wf1.run_date_start = '2019-06-15 11:10:00' - self.wf1.run_date_end = '2019-06-15 11:30:00' - self.wf1.save(update_fields=('run_date_start', 'run_date_end')) - run_date = datetime.datetime.strptime('2019-06-15 11:15:00', "%Y-%m-%d %H:%M:%S") + self.wf1.run_date_start = "2019-06-15 11:10:00" + self.wf1.run_date_end = "2019-06-15 11:30:00" + self.wf1.save(update_fields=("run_date_start", "run_date_end")) + run_date = datetime.datetime.strptime( + "2019-06-15 11:15:00", "%Y-%m-%d %H:%M:%S" + ) r = on_correct_time_period(self.wf1.id, run_date=run_date) self.assertTrue(r) @@ -463,77 +597,94 @@ def test_not_in_correct_time_period(self): :return: """ # 修改工单为workflow_review_pass,当前登录用户为提交人 - self.wf1.run_date_start = '2019-06-15 11:10:00' - self.wf1.run_date_end = '2019-06-15 11:30:00' - self.wf1.save(update_fields=('run_date_start', 'run_date_end')) - run_date = datetime.datetime.strptime('2019-06-15 11:45:00', "%Y-%m-%d %H:%M:%S") + self.wf1.run_date_start = "2019-06-15 11:10:00" + self.wf1.run_date_end = "2019-06-15 11:30:00" + self.wf1.save(update_fields=("run_date_start", "run_date_end")) + run_date = datetime.datetime.strptime( + "2019-06-15 11:45:00", "%Y-%m-%d %H:%M:%S" + ) r = on_correct_time_period(self.wf1.id, run_date=run_date) self.assertFalse(r) - @patch('sql.utils.sql_review.datetime') + @patch("sql.utils.sql_review.datetime") def test_now_on_correct_time_period(self, _datetime): """ 测试当前时间在可执行时间内 :return: """ # 修改工单为workflow_review_pass,当前登录用户为提交人 - self.wf1.run_date_start = '2019-06-15 11:10:00' - self.wf1.run_date_end = '2019-06-15 11:30:00' - self.wf1.save(update_fields=('run_date_start', 'run_date_end')) + self.wf1.run_date_start = "2019-06-15 11:10:00" + self.wf1.run_date_end = "2019-06-15 11:30:00" + self.wf1.save(update_fields=("run_date_start", "run_date_end")) _datetime.datetime.now.return_value = datetime.datetime.strptime( - '2019-06-15 11:15:00', "%Y-%m-%d %H:%M:%S") + "2019-06-15 11:15:00", "%Y-%m-%d %H:%M:%S" + ) r = on_correct_time_period(self.wf1.id) self.assertTrue(r) - @patch('sql.utils.sql_review.datetime') + @patch("sql.utils.sql_review.datetime") def test_now_not_in_correct_time_period(self, _datetime): """ 测试当前时间不在可执行时间内 :return: """ # 修改工单为workflow_review_pass,当前登录用户为提交人 - self.wf1.run_date_start = '2019-06-15 11:10:00' - self.wf1.run_date_end = '2019-06-15 11:30:00' - self.wf1.save(update_fields=('run_date_start', 'run_date_end')) + self.wf1.run_date_start = "2019-06-15 11:10:00" + self.wf1.run_date_end = "2019-06-15 11:30:00" + self.wf1.save(update_fields=("run_date_start", "run_date_end")) _datetime.datetime.now.return_value = datetime.datetime.strptime( - '2019-06-15 11:55:00', "%Y-%m-%d %H:%M:%S") + "2019-06-15 11:55:00", "%Y-%m-%d %H:%M:%S" + ) r = on_correct_time_period(self.wf1.id) self.assertFalse(r) class TestExecuteSql(TestCase): def setUp(self): - self.ins = Instance.objects.create(instance_name='some_ins', type='slave', db_type='mysql', - host='some_host', - port=3306, user='ins_user', password='some_str') + self.ins = Instance.objects.create( + instance_name="some_ins", + type="slave", + db_type="mysql", + host="some_host", + port=3306, + user="ins_user", + password="some_str", + ) self.wf = SqlWorkflow.objects.create( - workflow_name='some_name', + workflow_name="some_name", group_id=1, - group_name='g1', - engineer_display='', - audit_auth_groups='some_group', + group_name="g1", + engineer_display="", + audit_auth_groups="some_group", create_time=datetime.datetime.now(), - status='workflow_timingtask', + status="workflow_timingtask", is_backup=True, instance=self.ins, - db_name='some_db', - syntax_type=1 - ) - SqlWorkflowContent.objects.create(workflow=self.wf, - sql_content='some_sql', - execute_result=ReviewSet(rows=[ReviewResult( - id=0, - stage='Execute failed', - errlevel=2, - stagestatus='异常终止', - errormessage='', - sql='执行异常信息', - affected_rows=0, - actual_affected_rows=0, - sequence='0_0_0', - backup_dbname=None, - execute_time=0, - sqlsha1='')]).json()) + db_name="some_db", + syntax_type=1, + ) + SqlWorkflowContent.objects.create( + workflow=self.wf, + sql_content="some_sql", + execute_result=ReviewSet( + rows=[ + ReviewResult( + id=0, + stage="Execute failed", + errlevel=2, + stagestatus="异常终止", + errormessage="", + sql="执行异常信息", + affected_rows=0, + actual_affected_rows=0, + sequence="0_0_0", + backup_dbname=None, + execute_time=0, + sqlsha1="", + ) + ] + ).json(), + ) def tearDown(self): self.ins.delete() @@ -541,9 +692,9 @@ def tearDown(self): SqlWorkflowContent.objects.all().delete() WorkflowLog.objects.all().delete() - @patch('sql.utils.execute_sql.Audit') - @patch('sql.engines.mysql.MysqlEngine.execute_workflow') - @patch('sql.engines.get_engine') + @patch("sql.utils.execute_sql.Audit") + @patch("sql.engines.mysql.MysqlEngine.execute_workflow") + @patch("sql.engines.get_engine") def test_execute(self, _get_engine, _execute_workflow, _audit): _audit.detail_by_workflow_id.return_value.audit_id = 1 execute(self.wf.id) @@ -551,194 +702,216 @@ def test_execute(self, _get_engine, _execute_workflow, _audit): _audit.add_log.assert_called_with( audit_id=1, operation_type=5, - operation_type_desc='执行工单', - operation_info='系统定时执行工单', - operator='', - operator_display='系统', + operation_type_desc="执行工单", + operation_info="系统定时执行工单", + operator="", + operator_display="系统", ) - @patch('sql.utils.execute_sql.notify_for_execute') - @patch('sql.utils.execute_sql.Audit') + @patch("sql.utils.execute_sql.notify_for_execute") + @patch("sql.utils.execute_sql.Audit") def test_execute_callback_success(self, _audit, _notify): # 初始化工单执行返回对象 self.task_result = MagicMock() self.task_result.args = [self.wf.id] self.task_result.success = True self.task_result.stopped = datetime.datetime.now() - self.task_result.result.json.return_value = json.dumps([{'id': 1, 'sql': 'some_content'}]) - self.task_result.result.warning = '' - self.task_result.result.error = '' + self.task_result.result.json.return_value = json.dumps( + [{"id": 1, "sql": "some_content"}] + ) + self.task_result.result.warning = "" + self.task_result.result.error = "" _audit.detail_by_workflow_id.return_value.audit_id = 123 - _audit.add_log.return_value = 'any thing' + _audit.add_log.return_value = "any thing" # 先处理为执行中 - self.wf.status = 'workflow_executing' - self.wf.save(update_fields=['status']) + self.wf.status = "workflow_executing" + self.wf.save(update_fields=["status"]) execute_callback(self.task_result) - _audit.detail_by_workflow_id.assert_called_with(workflow_id=self.wf.id, workflow_type=2) + _audit.detail_by_workflow_id.assert_called_with( + workflow_id=self.wf.id, workflow_type=2 + ) _audit.add_log.assert_called_with( audit_id=123, operation_type=6, - operation_type_desc='执行结束', + operation_type_desc="执行结束", operation_info="执行结果:已正常结束", - operator='', - operator_display='系统', + operator="", + operator_display="系统", ) _notify.assert_called_once() - @patch('sql.utils.execute_sql.notify_for_execute') - @patch('sql.utils.execute_sql.Audit') + @patch("sql.utils.execute_sql.notify_for_execute") + @patch("sql.utils.execute_sql.Audit") def test_execute_callback_failure(self, _audit, _notify): # 初始化工单执行返回对象 self.task_result = MagicMock() self.task_result.args = [self.wf.id] self.task_result.success = False self.task_result.stopped = datetime.datetime.now() - self.task_result.result = '执行异常' + self.task_result.result = "执行异常" _audit.detail_by_workflow_id.return_value.audit_id = 123 - _audit.add_log.return_value = 'any thing' + _audit.add_log.return_value = "any thing" # 处理状态为执行中 - self.wf.status = 'workflow_executing' - self.wf.save(update_fields=['status']) + self.wf.status = "workflow_executing" + self.wf.save(update_fields=["status"]) execute_callback(self.task_result) - _audit.detail_by_workflow_id.assert_called_with(workflow_id=self.wf.id, workflow_type=2) + _audit.detail_by_workflow_id.assert_called_with( + workflow_id=self.wf.id, workflow_type=2 + ) _audit.add_log.assert_called_with( audit_id=123, operation_type=6, - operation_type_desc='执行结束', + operation_type_desc="执行结束", operation_info="执行结果:执行有异常", - operator='', - operator_display='系统', + operator="", + operator_display="系统", ) _notify.assert_called_once() - @patch('sql.utils.execute_sql.notify_for_execute') - @patch('sql.utils.execute_sql.Audit') + @patch("sql.utils.execute_sql.notify_for_execute") + @patch("sql.utils.execute_sql.Audit") def test_execute_callback_failure_no_execute_result(self, _audit, _notify): # 初始化工单执行返回对象 self.task_result = MagicMock() self.task_result.args = [self.wf.id] self.task_result.success = False self.task_result.stopped = datetime.datetime.now() - self.task_result.result = '执行异常' + self.task_result.result = "执行异常" _audit.detail_by_workflow_id.return_value.audit_id = 123 - _audit.add_log.return_value = 'any thing' + _audit.add_log.return_value = "any thing" # 删除execute_result、处理为执行中 - self.wf.sqlworkflowcontent.execute_result = '' + self.wf.sqlworkflowcontent.execute_result = "" self.wf.sqlworkflowcontent.save() - self.wf.status = 'workflow_executing' - self.wf.save(update_fields=['status']) + self.wf.status = "workflow_executing" + self.wf.save(update_fields=["status"]) execute_callback(self.task_result) - _audit.detail_by_workflow_id.assert_called_with(workflow_id=self.wf.id, workflow_type=2) + _audit.detail_by_workflow_id.assert_called_with( + workflow_id=self.wf.id, workflow_type=2 + ) _audit.add_log.assert_called_with( audit_id=123, operation_type=6, - operation_type_desc='执行结束', + operation_type_desc="执行结束", operation_info="执行结果:执行有异常", - operator='', - operator_display='系统', + operator="", + operator_display="系统", ) _notify.assert_called_once() class TestTasks(TestCase): def setUp(self): - self.Schedule = Schedule.objects.create(name='some_name') + self.Schedule = Schedule.objects.create(name="some_name") def tearDown(self): Schedule.objects.all().delete() - @patch('sql.utils.tasks.schedule') + @patch("sql.utils.tasks.schedule") def test_add_sql_schedule(self, _schedule): - add_sql_schedule('test', datetime.datetime.now(), 1) + add_sql_schedule("test", datetime.datetime.now(), 1) _schedule.assert_called_once() def test_del_schedule(self): - del_schedule('some_name') + del_schedule("some_name") with self.assertRaises(Schedule.DoesNotExist): - Schedule.objects.get(name='some_name') + Schedule.objects.get(name="some_name") def test_del_schedule_not_exists(self): - del_schedule('some_name1') + del_schedule("some_name1") def test_task_info(self): - task_info('some_name') + task_info("some_name") def test_task_info_not_exists(self): with self.assertRaises(Schedule.DoesNotExist): - Schedule.objects.get(name='some_name1') + Schedule.objects.get(name="some_name1") class TestAudit(TestCase): def setUp(self): self.sys_config = SysConfig() - self.user = User.objects.create(username='test_user', display='中文显示', is_active=True) - self.su = User.objects.create(username='s_user', display='中文显示', is_active=True, is_superuser=True) + self.user = User.objects.create( + username="test_user", display="中文显示", is_active=True + ) + self.su = User.objects.create( + username="s_user", display="中文显示", is_active=True, is_superuser=True + ) tomorrow = datetime.datetime.today() + datetime.timedelta(days=1) - self.ins = Instance.objects.create(instance_name='some_ins', type='slave', db_type='mysql', - host='some_host', - port=3306, user='ins_user', password='some_str') - self.res_group = ResourceGroup.objects.create(group_id=1, group_name='group_name') + self.ins = Instance.objects.create( + instance_name="some_ins", + type="slave", + db_type="mysql", + host="some_host", + port=3306, + user="ins_user", + password="some_str", + ) + self.res_group = ResourceGroup.objects.create( + group_id=1, group_name="group_name" + ) self.wf = SqlWorkflow.objects.create( - workflow_name='some_name', + workflow_name="some_name", group_id=1, - group_name='g1', - engineer_display='', - audit_auth_groups='some_audit_group', + group_name="g1", + engineer_display="", + audit_auth_groups="some_audit_group", create_time=datetime.datetime.now(), - status='workflow_timingtask', + status="workflow_timingtask", is_backup=True, instance=self.ins, - db_name='some_db', - syntax_type=1 + db_name="some_db", + syntax_type=1, + ) + SqlWorkflowContent.objects.create( + workflow=self.wf, sql_content="some_sql", execute_result="" ) - SqlWorkflowContent.objects.create(workflow=self.wf, - sql_content='some_sql', - execute_result='') self.query_apply_1 = QueryPrivilegesApply.objects.create( group_id=1, - group_name='some_name', - title='some_title1', - user_name='some_user', + group_name="some_name", + title="some_title1", + user_name="some_user", instance=self.ins, - db_list='some_db,some_db2', + db_list="some_db,some_db2", limit_num=100, valid_date=tomorrow, priv_type=1, status=0, - audit_auth_groups='some_audit_group' + audit_auth_groups="some_audit_group", ) self.archive_apply_1 = ArchiveConfig.objects.create( - title='title', + title="title", resource_group=self.res_group, - audit_auth_groups='some_audit_group', + audit_auth_groups="some_audit_group", src_instance=self.ins, - src_db_name='src_db_name', - src_table_name='src_table_name', + src_db_name="src_db_name", + src_table_name="src_table_name", dest_instance=self.ins, - dest_db_name='src_db_name', - dest_table_name='src_table_name', - condition='1=1', - mode='file', + dest_db_name="src_db_name", + dest_table_name="src_table_name", + condition="1=1", + mode="file", no_delete=True, sleep=1, - status=WorkflowDict.workflow_status['audit_wait'], + status=WorkflowDict.workflow_status["audit_wait"], state=False, - user_name='some_user', - user_display='display', + user_name="some_user", + user_display="display", ) self.audit = WorkflowAudit.objects.create( group_id=1, - group_name='some_group', + group_name="some_group", workflow_id=1, workflow_type=1, - workflow_title='申请标题', - workflow_remark='申请备注', - audit_auth_groups='1,2,3', - current_audit='1', - next_audit='2', - current_status=0) - self.wl = WorkflowLog.objects.create(audit_id=self.audit.audit_id, - operation_type=1) + workflow_title="申请标题", + workflow_remark="申请备注", + audit_auth_groups="1,2,3", + current_audit="1", + next_audit="2", + current_status=0, + ) + self.wl = WorkflowLog.objects.create( + audit_id=self.audit.audit_id, operation_type=1 + ) def tearDown(self): self.sys_config.purge() @@ -754,224 +927,247 @@ def tearDown(self): ArchiveConfig.objects.all().delete() def test_audit_add_query(self): - """ 测试添加查询审核工单""" + """测试添加查询审核工单""" result = Audit.add(1, self.query_apply_1.apply_id) - audit_id = result['data']['audit_id'] - workflow_status = result['data']['workflow_status'] - self.assertEqual(workflow_status, WorkflowDict.workflow_status['audit_wait']) + audit_id = result["data"]["audit_id"] + workflow_status = result["data"]["workflow_status"] + self.assertEqual(workflow_status, WorkflowDict.workflow_status["audit_wait"]) audit_detail = WorkflowAudit.objects.get(audit_id=audit_id) # 当前审批 - self.assertEqual(audit_detail.current_audit, 'some_audit_group') + self.assertEqual(audit_detail.current_audit, "some_audit_group") # 无下级审批 - self.assertEqual(audit_detail.next_audit, '-1') + self.assertEqual(audit_detail.next_audit, "-1") # 验证日志 log_info = WorkflowLog.objects.filter(audit_id=audit_id).first() self.assertEqual(log_info.operation_type, 0) - self.assertEqual(log_info.operation_type_desc, '提交') - self.assertIn('等待审批,审批流程:', log_info.operation_info) + self.assertEqual(log_info.operation_type_desc, "提交") + self.assertIn("等待审批,审批流程:", log_info.operation_info) def test_audit_add_sqlreview(self): - """ 测试添加上线审核工单""" + """测试添加上线审核工单""" result = Audit.add(2, self.wf.id) - audit_id = result['data']['audit_id'] - workflow_status = result['data']['workflow_status'] - self.assertEqual(workflow_status, WorkflowDict.workflow_status['audit_wait']) + audit_id = result["data"]["audit_id"] + workflow_status = result["data"]["workflow_status"] + self.assertEqual(workflow_status, WorkflowDict.workflow_status["audit_wait"]) audit_detail = WorkflowAudit.objects.get(audit_id=audit_id) # 当前审批 - self.assertEqual(audit_detail.current_audit, 'some_audit_group') + self.assertEqual(audit_detail.current_audit, "some_audit_group") # 无下级审批 - self.assertEqual(audit_detail.next_audit, '-1') + self.assertEqual(audit_detail.next_audit, "-1") # 验证日志 log_info = WorkflowLog.objects.filter(audit_id=audit_id).first() self.assertEqual(log_info.operation_type, 0) - self.assertEqual(log_info.operation_type_desc, '提交') - self.assertIn('等待审批,审批流程:', log_info.operation_info) + self.assertEqual(log_info.operation_type_desc, "提交") + self.assertIn("等待审批,审批流程:", log_info.operation_info) def test_audit_add_archive_review(self): - """ 测试添加数据归档工单""" + """测试添加数据归档工单""" result = Audit.add(3, self.archive_apply_1.id) - audit_id = result['data']['audit_id'] - workflow_status = result['data']['workflow_status'] - self.assertEqual(workflow_status, WorkflowDict.workflow_status['audit_wait']) + audit_id = result["data"]["audit_id"] + workflow_status = result["data"]["workflow_status"] + self.assertEqual(workflow_status, WorkflowDict.workflow_status["audit_wait"]) audit_detail = WorkflowAudit.objects.get(audit_id=audit_id) # 当前审批 - self.assertEqual(audit_detail.current_audit, 'some_audit_group') + self.assertEqual(audit_detail.current_audit, "some_audit_group") # 无下级审批 - self.assertEqual(audit_detail.next_audit, '-1') + self.assertEqual(audit_detail.next_audit, "-1") # 验证日志 log_info = WorkflowLog.objects.filter(audit_id=audit_id).first() self.assertEqual(log_info.operation_type, 0) - self.assertEqual(log_info.operation_type_desc, '提交') - self.assertIn('等待审批,审批流程:', log_info.operation_info) + self.assertEqual(log_info.operation_type_desc, "提交") + self.assertIn("等待审批,审批流程:", log_info.operation_info) def test_audit_add_wrong_type(self): - """ 测试添加不存在的类型""" - with self.assertRaisesMessage(Exception, '工单类型不存在'): + """测试添加不存在的类型""" + with self.assertRaisesMessage(Exception, "工单类型不存在"): Audit.add(4, 1) def test_audit_add_settings_not_exists(self): - """ 测试审批流程未配置""" - self.wf.audit_auth_groups = '' + """测试审批流程未配置""" + self.wf.audit_auth_groups = "" self.wf.save() - with self.assertRaisesMessage(Exception, '审批流程不能为空,请先配置审批流程'): + with self.assertRaisesMessage(Exception, "审批流程不能为空,请先配置审批流程"): Audit.add(2, self.wf.id) def test_audit_add_duplicate(self): """测试重复提交""" Audit.add(2, self.wf.id) - with self.assertRaisesMessage(Exception, '该工单当前状态为待审核,请勿重复提交'): + with self.assertRaisesMessage(Exception, "该工单当前状态为待审核,请勿重复提交"): Audit.add(2, self.wf.id) - @patch('sql.utils.workflow_audit.is_auto_review', return_value=True) + @patch("sql.utils.workflow_audit.is_auto_review", return_value=True) def test_audit_add_auto_review(self, _is_auto_review): """测试提交自动审核通过""" - self.sys_config.set('auto_review', 'true') + self.sys_config.set("auto_review", "true") result = Audit.add(2, self.wf.id) - audit_id = result['data']['audit_id'] - workflow_status = result['data']['workflow_status'] - self.assertEqual(workflow_status, WorkflowDict.workflow_status['audit_success']) + audit_id = result["data"]["audit_id"] + workflow_status = result["data"]["workflow_status"] + self.assertEqual(workflow_status, WorkflowDict.workflow_status["audit_success"]) audit_detail = WorkflowAudit.objects.get(audit_id=audit_id) # 无下级审批 - self.assertEqual(audit_detail.next_audit, '-1') + self.assertEqual(audit_detail.next_audit, "-1") # 验证日志 log_info = WorkflowLog.objects.filter(audit_id=audit_id).first() self.assertEqual(log_info.operation_type, 0) - self.assertEqual(log_info.operation_type_desc, '提交') - self.assertEqual(log_info.operation_info, '无需审批,系统直接审核通过') + self.assertEqual(log_info.operation_type_desc, "提交") + self.assertEqual(log_info.operation_info, "无需审批,系统直接审核通过") def test_audit_add_multiple_audit(self): """测试提交多级审核""" - self.wf.audit_auth_groups = '1,2,3' + self.wf.audit_auth_groups = "1,2,3" self.wf.save() result = Audit.add(2, self.wf.id) - audit_id = result['data']['audit_id'] - workflow_status = result['data']['workflow_status'] + audit_id = result["data"]["audit_id"] + workflow_status = result["data"]["workflow_status"] audit_detail = WorkflowAudit.objects.get(audit_id=audit_id) - self.assertEqual(workflow_status, WorkflowDict.workflow_status['audit_wait']) + self.assertEqual(workflow_status, WorkflowDict.workflow_status["audit_wait"]) # 存在下级审批 - self.assertEqual(audit_detail.current_audit, '1') - self.assertEqual(audit_detail.next_audit, '2') + self.assertEqual(audit_detail.current_audit, "1") + self.assertEqual(audit_detail.next_audit, "2") # 验证日志 log_info = WorkflowLog.objects.filter(audit_id=audit_id).first() self.assertEqual(log_info.operation_type, 0) - self.assertEqual(log_info.operation_type_desc, '提交') - self.assertIn('等待审批,审批流程:', log_info.operation_info) + self.assertEqual(log_info.operation_type_desc, "提交") + self.assertIn("等待审批,审批流程:", log_info.operation_info) def test_audit_success_not_exists_next(self): """测试审核通过、无下一级""" - self.audit.current_audit = '3' - self.audit.next_audit = '-1' + self.audit.current_audit = "3" + self.audit.next_audit = "-1" self.audit.save() - result = Audit.audit(self.audit.audit_id, - WorkflowDict.workflow_status['audit_success'], - self.user.username, - '通过') + result = Audit.audit( + self.audit.audit_id, + WorkflowDict.workflow_status["audit_success"], + self.user.username, + "通过", + ) audit_id = self.audit.audit_id - workflow_status = result['data']['workflow_status'] + workflow_status = result["data"]["workflow_status"] audit_detail = WorkflowAudit.objects.get(audit_id=audit_id) - self.assertEqual(workflow_status, WorkflowDict.workflow_status['audit_success']) + self.assertEqual(workflow_status, WorkflowDict.workflow_status["audit_success"]) # 不存在下级审批 - self.assertEqual(audit_detail.next_audit, '-1') + self.assertEqual(audit_detail.next_audit, "-1") # 验证日志 - log_info = WorkflowLog.objects.filter(audit_id=audit_id).order_by('-id').first() + log_info = WorkflowLog.objects.filter(audit_id=audit_id).order_by("-id").first() self.assertEqual(log_info.operator, self.user.username) self.assertEqual(log_info.operator_display, self.user.display) self.assertEqual(log_info.operation_type, 1) - self.assertEqual(log_info.operation_type_desc, '审批通过') - self.assertEqual(log_info.operation_info, f'审批备注:通过,下级审批:None') + self.assertEqual(log_info.operation_type_desc, "审批通过") + self.assertEqual(log_info.operation_info, f"审批备注:通过,下级审批:None") def test_audit_success_exists_next(self): """测试审核通过、存在下一级""" - self.audit.current_audit = '1' - self.audit.next_audit = '2' + self.audit.current_audit = "1" + self.audit.next_audit = "2" self.audit.save() - result = Audit.audit(self.audit.audit_id, - WorkflowDict.workflow_status['audit_success'], - self.user.username, - '通过') + result = Audit.audit( + self.audit.audit_id, + WorkflowDict.workflow_status["audit_success"], + self.user.username, + "通过", + ) audit_id = self.audit.audit_id - workflow_status = result['data']['workflow_status'] + workflow_status = result["data"]["workflow_status"] audit_detail = WorkflowAudit.objects.get(audit_id=audit_id) - self.assertEqual(workflow_status, WorkflowDict.workflow_status['audit_wait']) + self.assertEqual(workflow_status, WorkflowDict.workflow_status["audit_wait"]) # 存在下级审批 - self.assertEqual(audit_detail.next_audit, '3') + self.assertEqual(audit_detail.next_audit, "3") # 验证日志 - log_info = WorkflowLog.objects.filter(audit_id=audit_id).order_by('-id').first() + log_info = WorkflowLog.objects.filter(audit_id=audit_id).order_by("-id").first() self.assertEqual(log_info.operator, self.user.username) self.assertEqual(log_info.operator_display, self.user.display) self.assertEqual(log_info.operation_type, 1) - self.assertEqual(log_info.operation_type_desc, '审批通过') - self.assertEqual(log_info.operation_info, f'审批备注:通过,下级审批:2') + self.assertEqual(log_info.operation_type_desc, "审批通过") + self.assertEqual(log_info.operation_info, f"审批备注:通过,下级审批:2") def test_audit_reject(self): """测试审核不通过""" - result = Audit.audit(self.audit.audit_id, - WorkflowDict.workflow_status['audit_reject'], - self.user.username, - '不通过') + result = Audit.audit( + self.audit.audit_id, + WorkflowDict.workflow_status["audit_reject"], + self.user.username, + "不通过", + ) audit_id = self.audit.audit_id - workflow_status = result['data']['workflow_status'] + workflow_status = result["data"]["workflow_status"] audit_detail = WorkflowAudit.objects.get(audit_id=audit_id) - self.assertEqual(workflow_status, WorkflowDict.workflow_status['audit_reject']) + self.assertEqual(workflow_status, WorkflowDict.workflow_status["audit_reject"]) # 不存在下级审批 - self.assertEqual(audit_detail.next_audit, '-1') + self.assertEqual(audit_detail.next_audit, "-1") # 验证日志 - log_info = WorkflowLog.objects.filter(audit_id=audit_id).order_by('-id').first() + log_info = WorkflowLog.objects.filter(audit_id=audit_id).order_by("-id").first() self.assertEqual(log_info.operator, self.user.username) self.assertEqual(log_info.operator_display, self.user.display) self.assertEqual(log_info.operation_type, 2) - self.assertEqual(log_info.operation_type_desc, '审批不通过') - self.assertEqual(log_info.operation_info, f'审批备注:不通过') + self.assertEqual(log_info.operation_type_desc, "审批不通过") + self.assertEqual(log_info.operation_info, f"审批备注:不通过") def test_audit_abort(self): """测试取消审批""" self.audit.create_user = self.user.username self.audit.save() - result = Audit.audit(self.audit.audit_id, - WorkflowDict.workflow_status['audit_abort'], - self.user.username, - '取消') + result = Audit.audit( + self.audit.audit_id, + WorkflowDict.workflow_status["audit_abort"], + self.user.username, + "取消", + ) audit_id = self.audit.audit_id - workflow_status = result['data']['workflow_status'] + workflow_status = result["data"]["workflow_status"] audit_detail = WorkflowAudit.objects.get(audit_id=audit_id) - self.assertEqual(workflow_status, WorkflowDict.workflow_status['audit_abort']) + self.assertEqual(workflow_status, WorkflowDict.workflow_status["audit_abort"]) # 不存在下级审批 - self.assertEqual(audit_detail.next_audit, '-1') + self.assertEqual(audit_detail.next_audit, "-1") # 验证日志 - log_info = WorkflowLog.objects.filter(audit_id=audit_id).order_by('-id').first() + log_info = WorkflowLog.objects.filter(audit_id=audit_id).order_by("-id").first() self.assertEqual(log_info.operator, self.user.username) self.assertEqual(log_info.operator_display, self.user.display) self.assertEqual(log_info.operation_type, 3) - self.assertEqual(log_info.operation_type_desc, '审批取消') - self.assertEqual(log_info.operation_info, f'取消原因:取消') + self.assertEqual(log_info.operation_type_desc, "审批取消") + self.assertEqual(log_info.operation_info, f"取消原因:取消") def test_audit_wrong_exception(self): """测试审核异常的状态""" - with self.assertRaisesMessage(Exception, '审核异常'): - Audit.audit(self.audit.audit_id, 10, self.user.username, '') + with self.assertRaisesMessage(Exception, "审核异常"): + Audit.audit(self.audit.audit_id, 10, self.user.username, "") def test_audit_success_wrong_status(self): """测试审核通过,当前状态不是待审核""" self.audit.current_status = 1 self.audit.save() - with self.assertRaisesMessage(Exception, '工单不是待审核状态,请返回刷新'): - Audit.audit(self.audit.audit_id, WorkflowDict.workflow_status['audit_success'], self.user.username, '') + with self.assertRaisesMessage(Exception, "工单不是待审核状态,请返回刷新"): + Audit.audit( + self.audit.audit_id, + WorkflowDict.workflow_status["audit_success"], + self.user.username, + "", + ) def test_audit_reject_wrong_status(self): """测试审核不通过,当前状态不是待审核""" self.audit.current_status = 1 self.audit.save() - with self.assertRaisesMessage(Exception, '工单不是待审核状态,请返回刷新'): - Audit.audit(self.audit.audit_id, WorkflowDict.workflow_status['audit_reject'], self.user.username, '') + with self.assertRaisesMessage(Exception, "工单不是待审核状态,请返回刷新"): + Audit.audit( + self.audit.audit_id, + WorkflowDict.workflow_status["audit_reject"], + self.user.username, + "", + ) def test_audit_abort_wrong_status(self): """测试审核不通过,当前状态不是待审核""" self.audit.current_status = 2 self.audit.save() - with self.assertRaisesMessage(Exception, '工单不是待审核态/审核通过状态,请返回刷新'): - Audit.audit(self.audit.audit_id, WorkflowDict.workflow_status['audit_abort'], self.user.username, '') - - @patch('sql.utils.workflow_audit.user_groups', return_value=[]) + with self.assertRaisesMessage(Exception, "工单不是待审核态/审核通过状态,请返回刷新"): + Audit.audit( + self.audit.audit_id, + WorkflowDict.workflow_status["audit_abort"], + self.user.username, + "", + ) + + @patch("sql.utils.workflow_audit.user_groups", return_value=[]) def test_todo(self, _user_groups): """TODO 测试todo数量,未断言""" Audit.todo(self.user) @@ -986,140 +1182,164 @@ def test_detail(self): def test_detail_by_workflow_id(self): """测试通过业务id获取审核信息""" - self.audit.workflow_type = WorkflowDict.workflow_type['sqlreview'] + self.audit.workflow_type = WorkflowDict.workflow_type["sqlreview"] self.audit.workflow_id = self.wf.id self.audit.save() - result = Audit.detail_by_workflow_id(self.wf.id, WorkflowDict.workflow_type['sqlreview']) + result = Audit.detail_by_workflow_id( + self.wf.id, WorkflowDict.workflow_type["sqlreview"] + ) self.assertEqual(result, self.audit) result = Audit.detail_by_workflow_id(0, 0) self.assertEqual(result, None) def test_settings(self): """测试通过组和审核类型,获取审核配置信息""" - WorkflowAuditSetting.objects.create(workflow_type=1, group_id=1, audit_auth_groups='1,2,3') + WorkflowAuditSetting.objects.create( + workflow_type=1, group_id=1, audit_auth_groups="1,2,3" + ) result = Audit.settings(workflow_type=1, group_id=1) - self.assertEqual(result, '1,2,3') + self.assertEqual(result, "1,2,3") result = Audit.settings(0, 0) self.assertEqual(result, None) def test_change_settings_edit(self): """修改配置信息""" - ws = WorkflowAuditSetting.objects.create(workflow_type=1, group_id=1, audit_auth_groups='1,2,3') - Audit.change_settings(workflow_type=1, group_id=1, audit_auth_groups='1,2') + ws = WorkflowAuditSetting.objects.create( + workflow_type=1, group_id=1, audit_auth_groups="1,2,3" + ) + Audit.change_settings(workflow_type=1, group_id=1, audit_auth_groups="1,2") ws = WorkflowAuditSetting.objects.get(audit_setting_id=ws.audit_setting_id) - self.assertEqual(ws.audit_auth_groups, '1,2') + self.assertEqual(ws.audit_auth_groups, "1,2") def test_change_settings_add(self): """添加配置信息""" - Audit.change_settings(workflow_type=1, group_id=1, audit_auth_groups='1,2') + Audit.change_settings(workflow_type=1, group_id=1, audit_auth_groups="1,2") ws = WorkflowAuditSetting.objects.get(workflow_type=1, group_id=1) - self.assertEqual(ws.audit_auth_groups, '1,2') + self.assertEqual(ws.audit_auth_groups, "1,2") - @patch('sql.utils.workflow_audit.auth_group_users') - @patch('sql.utils.workflow_audit.Audit.detail_by_workflow_id') + @patch("sql.utils.workflow_audit.auth_group_users") + @patch("sql.utils.workflow_audit.Audit.detail_by_workflow_id") def test_can_review_sql_review(self, _detail_by_workflow_id, _auth_group_users): """测试判断用户当前是否是可审核上线工单,非管理员拥有权限""" - sql_review = Permission.objects.get(codename='sql_review') + sql_review = Permission.objects.get(codename="sql_review") self.user.user_permissions.add(sql_review) - aug = Group.objects.create(name='auth_group') + aug = Group.objects.create(name="auth_group") _detail_by_workflow_id.return_value.current_audit = aug.id _auth_group_users.return_value.filter.exists = True - self.audit.workflow_type = WorkflowDict.workflow_type['sqlreview'] + self.audit.workflow_type = WorkflowDict.workflow_type["sqlreview"] self.audit.workflow_id = self.wf.id self.audit.save() - r = Audit.can_review(self.user, self.audit.workflow_id, self.audit.workflow_type) + r = Audit.can_review( + self.user, self.audit.workflow_id, self.audit.workflow_type + ) self.assertEqual(r, True) - @patch('sql.utils.workflow_audit.auth_group_users') - @patch('sql.utils.workflow_audit.Audit.detail_by_workflow_id') + @patch("sql.utils.workflow_audit.auth_group_users") + @patch("sql.utils.workflow_audit.Audit.detail_by_workflow_id") def test_can_review_query_review(self, _detail_by_workflow_id, _auth_group_users): """测试判断用户当前是否是可审核查询工单,非管理员拥有权限""" - query_review = Permission.objects.get(codename='query_review') + query_review = Permission.objects.get(codename="query_review") self.user.user_permissions.add(query_review) - aug = Group.objects.create(name='auth_group') + aug = Group.objects.create(name="auth_group") _detail_by_workflow_id.return_value.current_audit = aug.id _auth_group_users.return_value.filter.exists = True - self.audit.workflow_type = WorkflowDict.workflow_type['query'] + self.audit.workflow_type = WorkflowDict.workflow_type["query"] self.audit.workflow_id = self.query_apply_1.apply_id self.audit.save() - r = Audit.can_review(self.user, self.audit.workflow_id, self.audit.workflow_type) + r = Audit.can_review( + self.user, self.audit.workflow_id, self.audit.workflow_type + ) self.assertEqual(r, True) - @patch('sql.utils.workflow_audit.auth_group_users') - @patch('sql.utils.workflow_audit.Audit.detail_by_workflow_id') - def test_can_review_sql_review_super(self, _detail_by_workflow_id, _auth_group_users): + @patch("sql.utils.workflow_audit.auth_group_users") + @patch("sql.utils.workflow_audit.Audit.detail_by_workflow_id") + def test_can_review_sql_review_super( + self, _detail_by_workflow_id, _auth_group_users + ): """测试判断用户当前是否是可审核查询工单,用户是管理员""" - aug = Group.objects.create(name='auth_group') + aug = Group.objects.create(name="auth_group") _detail_by_workflow_id.return_value.current_audit = aug.id _auth_group_users.return_value.filter.exists = True - self.audit.workflow_type = WorkflowDict.workflow_type['sqlreview'] + self.audit.workflow_type = WorkflowDict.workflow_type["sqlreview"] self.audit.workflow_id = self.wf.id self.audit.save() r = Audit.can_review(self.su, self.audit.workflow_id, self.audit.workflow_type) self.assertEqual(r, True) - @patch('sql.utils.workflow_audit.auth_group_users') - @patch('sql.utils.workflow_audit.Audit.detail_by_workflow_id') + @patch("sql.utils.workflow_audit.auth_group_users") + @patch("sql.utils.workflow_audit.Audit.detail_by_workflow_id") def test_can_review_wrong_status(self, _detail_by_workflow_id, _auth_group_users): """测试判断用户当前是否是可审核,非待审核工单""" - aug = Group.objects.create(name='auth_group') + aug = Group.objects.create(name="auth_group") _detail_by_workflow_id.return_value.current_audit = aug.id _auth_group_users.return_value.filter.exists = True - self.audit.workflow_type = WorkflowDict.workflow_type['sqlreview'] + self.audit.workflow_type = WorkflowDict.workflow_type["sqlreview"] self.audit.workflow_id = self.wf.id - self.audit.current_status = WorkflowDict.workflow_status['audit_success'] + self.audit.current_status = WorkflowDict.workflow_status["audit_success"] self.audit.save() - r = Audit.can_review(self.user, self.audit.workflow_id, self.audit.workflow_type) + r = Audit.can_review( + self.user, self.audit.workflow_id, self.audit.workflow_type + ) self.assertEqual(r, False) - @patch('sql.utils.workflow_audit.auth_group_users') - @patch('sql.utils.workflow_audit.Audit.detail_by_workflow_id') + @patch("sql.utils.workflow_audit.auth_group_users") + @patch("sql.utils.workflow_audit.Audit.detail_by_workflow_id") def test_can_review_no_prem(self, _detail_by_workflow_id, _auth_group_users): """测试判断用户当前是否是可审核,普通用户无权限""" - aug = Group.objects.create(name='auth_group') + aug = Group.objects.create(name="auth_group") _detail_by_workflow_id.return_value.current_audit = aug.id _auth_group_users.return_value.filter.exists = True - self.audit.workflow_type = WorkflowDict.workflow_type['sqlreview'] + self.audit.workflow_type = WorkflowDict.workflow_type["sqlreview"] self.audit.workflow_id = self.wf.id self.audit.save() - r = Audit.can_review(self.user, self.audit.workflow_id, self.audit.workflow_type) + r = Audit.can_review( + self.user, self.audit.workflow_id, self.audit.workflow_type + ) self.assertEqual(r, False) - @patch('sql.utils.workflow_audit.auth_group_users') - @patch('sql.utils.workflow_audit.Audit.detail_by_workflow_id') - def test_can_review_no_prem_exception(self, _detail_by_workflow_id, _auth_group_users): + @patch("sql.utils.workflow_audit.auth_group_users") + @patch("sql.utils.workflow_audit.Audit.detail_by_workflow_id") + def test_can_review_no_prem_exception( + self, _detail_by_workflow_id, _auth_group_users + ): """测试判断用户当前是否是可审核,权限组不存在""" - Group.objects.create(name='auth_group') + Group.objects.create(name="auth_group") _detail_by_workflow_id.side_effect = RuntimeError() _auth_group_users.return_value.filter.exists = True - self.audit.workflow_type = WorkflowDict.workflow_type['sqlreview'] + self.audit.workflow_type = WorkflowDict.workflow_type["sqlreview"] self.audit.workflow_id = self.wf.id self.audit.save() - with self.assertRaisesMessage(Exception, '当前审批auth_group_id不存在,请检查并清洗历史数据'): - Audit.can_review(self.user, self.audit.workflow_id, self.audit.workflow_type) + with self.assertRaisesMessage(Exception, "当前审批auth_group_id不存在,请检查并清洗历史数据"): + Audit.can_review( + self.user, self.audit.workflow_id, self.audit.workflow_type + ) def test_review_info_no_review(self): """测试获取当前工单审批流程和当前审核组,无需审批""" - self.audit.workflow_type = WorkflowDict.workflow_type['sqlreview'] + self.audit.workflow_type = WorkflowDict.workflow_type["sqlreview"] self.audit.workflow_id = self.wf.id - self.audit.audit_auth_groups = '' - self.audit.current_audit = '-1' + self.audit.audit_auth_groups = "" + self.audit.current_audit = "-1" self.audit.save() - audit_auth_group, current_audit_auth_group = Audit.review_info(self.audit.workflow_id, self.audit.workflow_type) - self.assertEqual(audit_auth_group, '无需审批') + audit_auth_group, current_audit_auth_group = Audit.review_info( + self.audit.workflow_id, self.audit.workflow_type + ) + self.assertEqual(audit_auth_group, "无需审批") self.assertEqual(current_audit_auth_group, None) def test_review_info(self): """测试获取当前工单审批流程和当前审核组,无需审批""" - aug = Group.objects.create(name='DBA') - self.audit.workflow_type = WorkflowDict.workflow_type['sqlreview'] + aug = Group.objects.create(name="DBA") + self.audit.workflow_type = WorkflowDict.workflow_type["sqlreview"] self.audit.workflow_id = self.wf.id self.audit.audit_auth_groups = str(aug.id) self.audit.current_audit = str(aug.id) self.audit.save() - audit_auth_group, current_audit_auth_group = Audit.review_info(self.audit.workflow_id, self.audit.workflow_type) - self.assertEqual(audit_auth_group, 'DBA') - self.assertEqual(current_audit_auth_group, 'DBA') + audit_auth_group, current_audit_auth_group = Audit.review_info( + self.audit.workflow_id, self.audit.workflow_type + ) + self.assertEqual(audit_auth_group, "DBA") + self.assertEqual(current_audit_auth_group, "DBA") def test_logs(self): """测试获取工单日志""" @@ -1129,37 +1349,43 @@ def test_logs(self): class TestDataMasking(TestCase): def setUp(self): - self.superuser = User.objects.create(username='super', is_superuser=True) - self.user = User.objects.create(username='user') - self.ins = Instance.objects.create(instance_name='some_ins', type='slave', db_type='mysql', - host='some_host', - port=3306, user='ins_user', password='some_str') + self.superuser = User.objects.create(username="super", is_superuser=True) + self.user = User.objects.create(username="user") + self.ins = Instance.objects.create( + instance_name="some_ins", + type="slave", + db_type="mysql", + host="some_host", + port=3306, + user="ins_user", + password="some_str", + ) self.sys_config = SysConfig() self.wf1 = SqlWorkflow.objects.create( - workflow_name='workflow_name', + workflow_name="workflow_name", group_id=1, - group_name='group_name', + group_name="group_name", engineer=self.superuser.username, engineer_display=self.superuser.display, - audit_auth_groups='audit_auth_groups', + audit_auth_groups="audit_auth_groups", create_time=datetime.datetime.now(), - status='workflow_review_pass', + status="workflow_review_pass", is_backup=True, instance=self.ins, - db_name='db_name', + db_name="db_name", syntax_type=1, ) DataMaskingRules.objects.create( - rule_type=1, - rule_regex='(.{3})(.*)(.{4})', - hide_group=2) + rule_type=1, rule_regex="(.{3})(.*)(.{4})", hide_group=2 + ) DataMaskingColumns.objects.create( rule_type=1, active=True, instance=self.ins, - table_schema='archer_test', - table_name='users', - column_name='phone') + table_schema="archer_test", + table_name="users", + column_name="phone", + ) def tearDown(self): User.objects.all().delete() @@ -1168,213 +1394,464 @@ def tearDown(self): DataMaskingColumns.objects.all().delete() DataMaskingRules.objects.all().delete() - @patch('sql.utils.data_masking.GoInceptionEngine') + @patch("sql.utils.data_masking.GoInceptionEngine") def test_data_masking_not_hit_rules(self, _inception): DataMaskingColumns.objects.all().delete() DataMaskingRules.objects.all().delete() _inception.return_value.query_data_masking.return_value = [ - {"index": 0, "field": "phone", "type": "varchar(80)", "table": "users", "schema": "archer_test", - "alias": "phone"} + { + "index": 0, + "field": "phone", + "type": "varchar(80)", + "table": "users", + "schema": "archer_test", + "alias": "phone", + } ] sql = """select phone from users;""" - rows = (('18888888888',), ('18888888889',), ('18888888810',)) - query_result = ReviewSet(column_list=['phone'], rows=rows, full_sql=sql) - r = data_masking(self.ins, 'archery', sql, query_result) + rows = (("18888888888",), ("18888888889",), ("18888888810",)) + query_result = ReviewSet(column_list=["phone"], rows=rows, full_sql=sql) + r = data_masking(self.ins, "archery", sql, query_result) print("test_data_masking_not_hit_rules:", r.rows) self.assertEqual(r, query_result) - @patch('sql.utils.data_masking.GoInceptionEngine') + @patch("sql.utils.data_masking.GoInceptionEngine") def test_data_masking_hit_rules_not_exists_star(self, _inception): _inception.return_value.query_data_masking.return_value = [ - {"index": 0, "field": "phone", "type": "varchar(80)", "table": "users", "schema": "archer_test", - "alias": "phone"} + { + "index": 0, + "field": "phone", + "type": "varchar(80)", + "table": "users", + "schema": "archer_test", + "alias": "phone", + } ] sql = """select phone from users;""" - rows = (('18888888888',), ('18888888889',), ('18888888810',)) - query_result = ReviewSet(column_list=['phone'], rows=rows, full_sql=sql) - r = data_masking(self.ins, 'archery', sql, query_result) + rows = (("18888888888",), ("18888888889",), ("18888888810",)) + query_result = ReviewSet(column_list=["phone"], rows=rows, full_sql=sql) + r = data_masking(self.ins, "archery", sql, query_result) print("test_data_masking_hit_rules_not_exists_star:", r.rows) - mask_result_rows = [['188****8888', ], ['188****8889', ], ['188****8810', ]] + mask_result_rows = [ + [ + "188****8888", + ], + [ + "188****8889", + ], + [ + "188****8810", + ], + ] self.assertEqual(r.rows, mask_result_rows) - @patch('sql.utils.data_masking.GoInceptionEngine') + @patch("sql.utils.data_masking.GoInceptionEngine") def test_data_masking_hit_rules_exists_star(self, _inception): """[*]""" _inception.return_value.query_data_masking.return_value = [ - {"index": 0, "field": "phone", "type": "varchar(80)", "table": "users", "schema": "archer_test", - "alias": "phone"} + { + "index": 0, + "field": "phone", + "type": "varchar(80)", + "table": "users", + "schema": "archer_test", + "alias": "phone", + } ] sql = """select * from users;""" - rows = (('18888888888',), ('18888888889',), ('18888888810',)) - query_result = ReviewSet(column_list=['phone'], rows=rows, full_sql=sql) - r = data_masking(self.ins, 'archery', sql, query_result) + rows = (("18888888888",), ("18888888889",), ("18888888810",)) + query_result = ReviewSet(column_list=["phone"], rows=rows, full_sql=sql) + r = data_masking(self.ins, "archery", sql, query_result) print("test_data_masking_hit_rules_exists_star:", r.rows) - mask_result_rows = [['188****8888', ], ['188****8889', ], ['188****8810', ]] + mask_result_rows = [ + [ + "188****8888", + ], + [ + "188****8889", + ], + [ + "188****8810", + ], + ] self.assertEqual(r.rows, mask_result_rows) - @patch('sql.utils.data_masking.GoInceptionEngine') + @patch("sql.utils.data_masking.GoInceptionEngine") def test_data_masking_hit_rules_star_and_column(self, _inception): """[*,column_a]""" _inception.return_value.query_data_masking.return_value = [ - {"index": 0, "field": "phone", "type": "varchar(80)", "table": "users", "schema": "archer_test", - "alias": "phone"}, - {"index": 1, "field": "phone", "type": "varchar(80)", "table": "users", "schema": "archer_test", - "alias": "phone"}, + { + "index": 0, + "field": "phone", + "type": "varchar(80)", + "table": "users", + "schema": "archer_test", + "alias": "phone", + }, + { + "index": 1, + "field": "phone", + "type": "varchar(80)", + "table": "users", + "schema": "archer_test", + "alias": "phone", + }, ] sql = """select *,phone from users;""" - rows = (('18888888888', '18888888888',), - ('18888888889', '18888888889',),) - query_result = ReviewSet(column_list=['phone', 'phone'], rows=rows, full_sql=sql) - r = data_masking(self.ins, 'archery', sql, query_result) + rows = ( + ( + "18888888888", + "18888888888", + ), + ( + "18888888889", + "18888888889", + ), + ) + query_result = ReviewSet( + column_list=["phone", "phone"], rows=rows, full_sql=sql + ) + r = data_masking(self.ins, "archery", sql, query_result) print("test_data_masking_hit_rules_star_and_column", r.rows) - mask_result_rows = [['188****8888', '188****8888', ], - ['188****8889', '188****8889', ]] + mask_result_rows = [ + [ + "188****8888", + "188****8888", + ], + [ + "188****8889", + "188****8889", + ], + ] self.assertEqual(r.rows, mask_result_rows) - @patch('sql.utils.data_masking.GoInceptionEngine') + @patch("sql.utils.data_masking.GoInceptionEngine") def test_data_masking_hit_rules_column_and_star(self, _inception): """[column_a, *]""" _inception.return_value.query_data_masking.return_value = [ - {"index": 0, "field": "phone", "type": "varchar(80)", "table": "users", "schema": "archer_test", - "alias": "phone"}, - {"index": 1, "field": "phone", "type": "varchar(80)", "table": "users", "schema": "archer_test", - "alias": "phone"} + { + "index": 0, + "field": "phone", + "type": "varchar(80)", + "table": "users", + "schema": "archer_test", + "alias": "phone", + }, + { + "index": 1, + "field": "phone", + "type": "varchar(80)", + "table": "users", + "schema": "archer_test", + "alias": "phone", + }, ] sql = """select phone,* from users;""" - rows = (('18888888888', '18888888888',), - ('18888888889', '18888888889',)) - query_result = ReviewSet(column_list=['phone', 'phone'], rows=rows, full_sql=sql) - r = data_masking(self.ins, 'archery', sql, query_result) + rows = ( + ( + "18888888888", + "18888888888", + ), + ( + "18888888889", + "18888888889", + ), + ) + query_result = ReviewSet( + column_list=["phone", "phone"], rows=rows, full_sql=sql + ) + r = data_masking(self.ins, "archery", sql, query_result) print("test_data_masking_hit_rules_column_and_star", r.rows) - mask_result_rows = [['188****8888', '188****8888', ], - ['188****8889', '188****8889', ]] + mask_result_rows = [ + [ + "188****8888", + "188****8888", + ], + [ + "188****8889", + "188****8889", + ], + ] self.assertEqual(r.rows, mask_result_rows) - @patch('sql.utils.data_masking.GoInceptionEngine') + @patch("sql.utils.data_masking.GoInceptionEngine") def test_data_masking_hit_rules_column_and_star_and_column(self, _inception): """[column_a,a.*,column_b]""" _inception.return_value.query_data_masking.return_value = [ - {"index": 0, "field": "phone", "type": "varchar(80)", "table": "users", "schema": "archer_test", - "alias": "phone"}, - {"index": 1, "field": "phone", "type": "varchar(80)", "table": "users", "schema": "archer_test", - "alias": "phone"}, - {"index": 2, "field": "phone", "type": "varchar(80)", "table": "users", "schema": "archer_test", - "alias": "phone"} + { + "index": 0, + "field": "phone", + "type": "varchar(80)", + "table": "users", + "schema": "archer_test", + "alias": "phone", + }, + { + "index": 1, + "field": "phone", + "type": "varchar(80)", + "table": "users", + "schema": "archer_test", + "alias": "phone", + }, + { + "index": 2, + "field": "phone", + "type": "varchar(80)", + "table": "users", + "schema": "archer_test", + "alias": "phone", + }, ] sql = """select phone,*,phone from users;""" - rows = (('18888888888', '18888888888', '18888888888',), - ('18888888889', '18888888889', '18888888889',)) - query_result = ReviewSet(column_list=['phone', 'phone', 'phone'], rows=rows, full_sql=sql) - r = data_masking(self.ins, 'archery', sql, query_result) + rows = ( + ( + "18888888888", + "18888888888", + "18888888888", + ), + ( + "18888888889", + "18888888889", + "18888888889", + ), + ) + query_result = ReviewSet( + column_list=["phone", "phone", "phone"], rows=rows, full_sql=sql + ) + r = data_masking(self.ins, "archery", sql, query_result) print("test_data_masking_hit_rules_column_and_star_and_column", r.rows) - mask_result_rows = [['188****8888', '188****8888', '188****8888', ], - ['188****8889', '188****8889', '188****8889', ]] + mask_result_rows = [ + [ + "188****8888", + "188****8888", + "188****8888", + ], + [ + "188****8889", + "188****8889", + "188****8889", + ], + ] self.assertEqual(r.rows, mask_result_rows) - @patch('sql.utils.data_masking.GoInceptionEngine') + @patch("sql.utils.data_masking.GoInceptionEngine") def test_data_masking_hit_rules_star_and_column_and_star(self, _inception): """[a.*, column_a, b.*]""" _inception.return_value.query_data_masking.return_value = [ - {"index": 0, "field": "phone", "type": "varchar(80)", "table": "users", "schema": "archer_test", - "alias": "phone"}, - {"index": 1, "field": "phone", "type": "varchar(80)", "table": "users", "schema": "archer_test", - "alias": "phone"}, - {"index": 2, "field": "phone", "type": "varchar(80)", "table": "users", "schema": "archer_test", - "alias": "phone"} + { + "index": 0, + "field": "phone", + "type": "varchar(80)", + "table": "users", + "schema": "archer_test", + "alias": "phone", + }, + { + "index": 1, + "field": "phone", + "type": "varchar(80)", + "table": "users", + "schema": "archer_test", + "alias": "phone", + }, + { + "index": 2, + "field": "phone", + "type": "varchar(80)", + "table": "users", + "schema": "archer_test", + "alias": "phone", + }, ] sql = """select a.*,phone,a.* from users a;""" - rows = (('18888888888', '18888888888', '18888888888',), - ('18888888889', '18888888889', '18888888889',)) - query_result = ReviewSet(column_list=['phone', 'phone', 'phone'], rows=rows, full_sql=sql) - r = data_masking(self.ins, 'archery', sql, query_result) + rows = ( + ( + "18888888888", + "18888888888", + "18888888888", + ), + ( + "18888888889", + "18888888889", + "18888888889", + ), + ) + query_result = ReviewSet( + column_list=["phone", "phone", "phone"], rows=rows, full_sql=sql + ) + r = data_masking(self.ins, "archery", sql, query_result) print("test_data_masking_hit_rules_star_and_column_and_star", r.rows) - mask_result_rows = [['188****8888', '188****8888', '188****8888', ], - ['188****8889', '188****8889', '188****8889', ]] + mask_result_rows = [ + [ + "188****8888", + "188****8888", + "188****8888", + ], + [ + "188****8889", + "188****8889", + "188****8889", + ], + ] self.assertEqual(r.rows, mask_result_rows) - @patch('sql.utils.data_masking.GoInceptionEngine') + @patch("sql.utils.data_masking.GoInceptionEngine") def test_data_masking_concat_function_support(self, _inception): """concat_函数支持""" _inception.return_value.query_data_masking.return_value = [ - {"index": 0, "field": "phone", "type": "varchar(80)", "table": "users", "schema": "archer_test", - "alias": "concat(phone,1)"} + { + "index": 0, + "field": "phone", + "type": "varchar(80)", + "table": "users", + "schema": "archer_test", + "alias": "concat(phone,1)", + } ] sql = """select concat(phone,1) from users;""" - rows = (('18888888888',), ('18888888889',), ('18888888810',)) - query_result = ReviewSet(column_list=['concat(phone,1)'], rows=rows, full_sql=sql) - r = data_masking(self.ins, 'archery', sql, query_result) - mask_result_rows = [['188****8888', ], ['188****8889', ], ['188****8810', ]] + rows = (("18888888888",), ("18888888889",), ("18888888810",)) + query_result = ReviewSet( + column_list=["concat(phone,1)"], rows=rows, full_sql=sql + ) + r = data_masking(self.ins, "archery", sql, query_result) + mask_result_rows = [ + [ + "188****8888", + ], + [ + "188****8889", + ], + [ + "188****8810", + ], + ] print("test_data_masking_concat_function_support", r.rows) self.assertEqual(r.rows, mask_result_rows) - @patch('sql.utils.data_masking.GoInceptionEngine') + @patch("sql.utils.data_masking.GoInceptionEngine") def test_data_masking_max_function_support(self, _inception): """max_函数支持""" _inception.return_value.query_data_masking.return_value = [ - {"index": 0, "field": "phone", "type": "varchar(80)", "table": "users", "schema": "archer_test", - "alias": "max(phone+1)"} + { + "index": 0, + "field": "phone", + "type": "varchar(80)", + "table": "users", + "schema": "archer_test", + "alias": "max(phone+1)", + } ] sql = """select max(phone+1) from users;""" - rows = (('18888888888',), ('18888888889',), ('18888888810',)) - query_result = ReviewSet(column_list=['max(phone+1)'], rows=rows, full_sql=sql) - mask_result_rows = [['188****8888', ], ['188****8889', ], ['188****8810', ]] - r = data_masking(self.ins, 'archery', sql, query_result) + rows = (("18888888888",), ("18888888889",), ("18888888810",)) + query_result = ReviewSet(column_list=["max(phone+1)"], rows=rows, full_sql=sql) + mask_result_rows = [ + [ + "188****8888", + ], + [ + "188****8889", + ], + [ + "188****8810", + ], + ] + r = data_masking(self.ins, "archery", sql, query_result) print("test_data_masking_max_function_support", r.rows) self.assertEqual(r.rows, mask_result_rows) - @patch('sql.utils.data_masking.GoInceptionEngine') + @patch("sql.utils.data_masking.GoInceptionEngine") def test_data_masking_union_support_keyword(self, _inception): """union关键字""" - self.sys_config.set('query_check', 'true') + self.sys_config.set("query_check", "true") self.sys_config.get_all_config() _inception.return_value.query_data_masking.return_value = [ - {'index': 0, 'field': 'phone', 'type': 'varchar(80)', 'table': 'users', 'schema': 'archer_test', - 'alias': 'phone'}, - {'index': 1, 'field': 'phone', 'type': 'varchar(80)', 'table': 'users', 'schema': 'archer_test', - 'alias': 'phone'} - + { + "index": 0, + "field": "phone", + "type": "varchar(80)", + "table": "users", + "schema": "archer_test", + "alias": "phone", + }, + { + "index": 1, + "field": "phone", + "type": "varchar(80)", + "table": "users", + "schema": "archer_test", + "alias": "phone", + }, + ] + sqls = [ + "select phone from users union select phone from users;", + "select phone from users union all select phone from users;", + ] + rows = (("18888888888",), ("18888888889",), ("18888888810",)) + mask_result_rows = [ + [ + "188****8888", + ], + [ + "188****8889", + ], + [ + "188****8810", + ], ] - sqls = ["select phone from users union select phone from users;", - "select phone from users union all select phone from users;"] - rows = (('18888888888',), ('18888888889',), ('18888888810',)) - mask_result_rows = [['188****8888', ], ['188****8889', ], ['188****8810', ]] for sql in sqls: - query_result = ReviewSet(column_list=['phone'], rows=rows, full_sql=sql) - r = data_masking(self.ins, 'archery', sql, query_result) + query_result = ReviewSet(column_list=["phone"], rows=rows, full_sql=sql) + r = data_masking(self.ins, "archery", sql, query_result) print("test_data_masking_union_support_keyword", r.rows) self.assertEqual(r.rows, mask_result_rows) def test_brute_mask(self): sql = """select * from users;""" - rows = (('18888888888',), ('18888888889',), ('18888888810',)) - query_result = ReviewSet(column_list=['phone'], rows=rows, full_sql=sql) + rows = (("18888888888",), ("18888888889",), ("18888888810",)) + query_result = ReviewSet(column_list=["phone"], rows=rows, full_sql=sql) r = brute_mask(self.ins, query_result) - mask_result_rows = [('188****8888',), ('188****8889',), ('188****8810',)] + mask_result_rows = [("188****8888",), ("188****8889",), ("188****8810",)] self.assertEqual(r.rows, mask_result_rows) def test_simple_column_mask(self): sql = """select * from users;""" - rows = (('18888888888',), ('18888888889',), ('18888888810',)) - query_result = ReviewSet(column_list=['phone'], rows=rows, full_sql=sql) + rows = (("18888888888",), ("18888888889",), ("18888888810",)) + query_result = ReviewSet(column_list=["phone"], rows=rows, full_sql=sql) r = simple_column_mask(self.ins, query_result) - mask_result_rows = [('188****8888',), ('188****8889',), ('188****8810',)] + mask_result_rows = [("188****8888",), ("188****8889",), ("188****8810",)] self.assertEqual(r.rows, mask_result_rows) class TestResourceGroup(TestCase): def setUp(self): self.sys_config = SysConfig() - self.user = User.objects.create(username='test_user', display='中文显示', is_active=True) - self.su = User.objects.create(username='s_user', display='中文显示', is_active=True, is_superuser=True) - self.ins1 = Instance.objects.create(instance_name='some_ins1', type='slave', db_type='mysql', - host='some_host', - port=3306, user='ins_user', password='some_str') - self.ins2 = Instance.objects.create(instance_name='some_ins2', type='slave', db_type='mysql', - host='some_host', - port=3306, user='ins_user', password='some_str') - self.rgp1 = ResourceGroup.objects.create(group_name='group1') - self.rgp2 = ResourceGroup.objects.create(group_name='group2') - self.agp = Group.objects.create(name='auth_group') + self.user = User.objects.create( + username="test_user", display="中文显示", is_active=True + ) + self.su = User.objects.create( + username="s_user", display="中文显示", is_active=True, is_superuser=True + ) + self.ins1 = Instance.objects.create( + instance_name="some_ins1", + type="slave", + db_type="mysql", + host="some_host", + port=3306, + user="ins_user", + password="some_str", + ) + self.ins2 = Instance.objects.create( + instance_name="some_ins2", + type="slave", + db_type="mysql", + host="some_host", + port=3306, + user="ins_user", + password="some_str", + ) + self.rgp1 = ResourceGroup.objects.create(group_name="group1") + self.rgp2 = ResourceGroup.objects.create(group_name="group2") + self.agp = Group.objects.create(name="auth_group") def tearDown(self): self.sys_config.purge() @@ -1426,5 +1903,7 @@ def test_auth_group_users(self): # 用户关联资源组 self.user.resource_group.add(self.rgp1) # 获取资源组内关联指定权限组的用户 - users = auth_group_users(auth_group_names=[self.agp.name], group_id=self.rgp1.group_id) + users = auth_group_users( + auth_group_names=[self.agp.name], group_id=self.rgp1.group_id + ) self.assertIn(self.user, users) diff --git a/sql/utils/workflow_audit.py b/sql/utils/workflow_audit.py index 2dbdfec5e5..b684328218 100644 --- a/sql/utils/workflow_audit.py +++ b/sql/utils/workflow_audit.py @@ -5,8 +5,17 @@ from sql.utils.resource_group import user_groups, auth_group_users from sql.utils.sql_review import is_auto_review from common.utils.const import WorkflowDict -from sql.models import WorkflowAudit, WorkflowAuditDetail, WorkflowAuditSetting, WorkflowLog, ResourceGroup, \ - SqlWorkflow, QueryPrivilegesApply, Users, ArchiveConfig +from sql.models import ( + WorkflowAudit, + WorkflowAuditDetail, + WorkflowAuditSetting, + WorkflowLog, + ResourceGroup, + SqlWorkflow, + QueryPrivilegesApply, + Users, + ArchiveConfig, +) from common.config import SysConfig @@ -14,17 +23,20 @@ class Audit(object): # 新增工单审核 @staticmethod def add(workflow_type, workflow_id): - result = {'status': 0, 'msg': '', 'data': []} + result = {"status": 0, "msg": "", "data": []} # 检查是否已存在待审核数据 - workflow_info = WorkflowAudit.objects.filter(workflow_type=workflow_type, workflow_id=workflow_id, - current_status=WorkflowDict.workflow_status['audit_wait']) + workflow_info = WorkflowAudit.objects.filter( + workflow_type=workflow_type, + workflow_id=workflow_id, + current_status=WorkflowDict.workflow_status["audit_wait"], + ) if len(workflow_info) >= 1: - result['msg'] = '该工单当前状态为待审核,请勿重复提交' - raise Exception(result['msg']) + result["msg"] = "该工单当前状态为待审核,请勿重复提交" + raise Exception(result["msg"]) # 获取工单信息 - if workflow_type == WorkflowDict.workflow_type['query']: + if workflow_type == WorkflowDict.workflow_type["query"]: workflow_detail = QueryPrivilegesApply.objects.get(apply_id=workflow_id) workflow_title = workflow_detail.title group_id = workflow_detail.group_id @@ -32,8 +44,8 @@ def add(workflow_type, workflow_id): create_user = workflow_detail.user_name create_user_display = workflow_detail.user_display audit_auth_groups = workflow_detail.audit_auth_groups - workflow_remark = '' - elif workflow_type == WorkflowDict.workflow_type['sqlreview']: + workflow_remark = "" + elif workflow_type == WorkflowDict.workflow_type["sqlreview"]: workflow_detail = SqlWorkflow.objects.get(pk=workflow_id) workflow_title = workflow_detail.workflow_name group_id = workflow_detail.group_id @@ -41,8 +53,8 @@ def add(workflow_type, workflow_id): create_user = workflow_detail.engineer create_user_display = workflow_detail.engineer_display audit_auth_groups = workflow_detail.audit_auth_groups - workflow_remark = '' - elif workflow_type == WorkflowDict.workflow_type['archive']: + workflow_remark = "" + elif workflow_type == WorkflowDict.workflow_type["archive"]: workflow_detail = ArchiveConfig.objects.get(pk=workflow_id) workflow_title = workflow_detail.title group_id = workflow_detail.resource_group.group_id @@ -50,25 +62,25 @@ def add(workflow_type, workflow_id): create_user = workflow_detail.user_name create_user_display = workflow_detail.user_display audit_auth_groups = workflow_detail.audit_auth_groups - workflow_remark = '' + workflow_remark = "" else: - result['msg'] = '工单类型不存在' - raise Exception(result['msg']) + result["msg"] = "工单类型不存在" + raise Exception(result["msg"]) # 校验是否配置审批流程 - if audit_auth_groups == '': - result['msg'] = '审批流程不能为空,请先配置审批流程' - raise Exception(result['msg']) + if audit_auth_groups == "": + result["msg"] = "审批流程不能为空,请先配置审批流程" + raise Exception(result["msg"]) else: - audit_auth_groups_list = audit_auth_groups.split(',') + audit_auth_groups_list = audit_auth_groups.split(",") # 判断是否无需审核,并且修改审批人为空 - if SysConfig().get('auto_review', False): - if workflow_type == WorkflowDict.workflow_type['sqlreview']: + if SysConfig().get("auto_review", False): + if workflow_type == WorkflowDict.workflow_type["sqlreview"]: if is_auto_review(workflow_id): sql_workflow = SqlWorkflow.objects.get(id=int(workflow_id)) - sql_workflow.audit_auth_groups = '无需审批' - sql_workflow.status = 'workflow_review_pass' + sql_workflow.audit_auth_groups = "无需审批" + sql_workflow.status = "workflow_review_pass" sql_workflow.save() audit_auth_groups_list = None @@ -82,23 +94,28 @@ def add(workflow_type, workflow_id): audit_detail.workflow_type = workflow_type audit_detail.workflow_title = workflow_title audit_detail.workflow_remark = workflow_remark - audit_detail.audit_auth_groups = '' - audit_detail.current_audit = '-1' - audit_detail.next_audit = '-1' - audit_detail.current_status = WorkflowDict.workflow_status['audit_success'] # 审核通过 + audit_detail.audit_auth_groups = "" + audit_detail.current_audit = "-1" + audit_detail.next_audit = "-1" + audit_detail.current_status = WorkflowDict.workflow_status[ + "audit_success" + ] # 审核通过 audit_detail.create_user = create_user audit_detail.create_user_display = create_user_display audit_detail.save() - result['data'] = {'workflow_status': WorkflowDict.workflow_status['audit_success']} - result['msg'] = '无审核配置,直接审核通过' + result["data"] = { + "workflow_status": WorkflowDict.workflow_status["audit_success"] + } + result["msg"] = "无审核配置,直接审核通过" # 增加工单日志 - Audit.add_log(audit_id=audit_detail.audit_id, - operation_type=0, - operation_type_desc='提交', - operation_info='无需审批,系统直接审核通过', - operator=audit_detail.create_user, - operator_display=audit_detail.create_user_display - ) + Audit.add_log( + audit_id=audit_detail.audit_id, + operation_type=0, + operation_type_desc="提交", + operation_info="无需审批,系统直接审核通过", + operator=audit_detail.create_user, + operator_display=audit_detail.create_user_display, + ) else: # 向审核主表插入待审核数据 audit_detail = WorkflowAudit() @@ -108,158 +125,191 @@ def add(workflow_type, workflow_id): audit_detail.workflow_type = workflow_type audit_detail.workflow_title = workflow_title audit_detail.workflow_remark = workflow_remark - audit_detail.audit_auth_groups = ','.join(audit_auth_groups_list) + audit_detail.audit_auth_groups = ",".join(audit_auth_groups_list) audit_detail.current_audit = audit_auth_groups_list[0] # 判断有无下级审核 if len(audit_auth_groups_list) == 1: - audit_detail.next_audit = '-1' + audit_detail.next_audit = "-1" else: audit_detail.next_audit = audit_auth_groups_list[1] - audit_detail.current_status = WorkflowDict.workflow_status['audit_wait'] + audit_detail.current_status = WorkflowDict.workflow_status["audit_wait"] audit_detail.create_user = create_user audit_detail.create_user_display = create_user_display audit_detail.save() - result['data'] = {'workflow_status': WorkflowDict.workflow_status['audit_wait']} + result["data"] = { + "workflow_status": WorkflowDict.workflow_status["audit_wait"] + } # 增加工单日志 - audit_auth_group, current_audit_auth_group = Audit.review_info(workflow_id, workflow_type) - Audit.add_log(audit_id=audit_detail.audit_id, - operation_type=0, - operation_type_desc='提交', - operation_info='等待审批,审批流程:{}'.format(audit_auth_group), - operator=audit_detail.create_user, - operator_display=audit_detail.create_user_display - ) + audit_auth_group, current_audit_auth_group = Audit.review_info( + workflow_id, workflow_type + ) + Audit.add_log( + audit_id=audit_detail.audit_id, + operation_type=0, + operation_type_desc="提交", + operation_info="等待审批,审批流程:{}".format(audit_auth_group), + operator=audit_detail.create_user, + operator_display=audit_detail.create_user_display, + ) # 增加审核id - result['data']['audit_id'] = audit_detail.audit_id + result["data"]["audit_id"] = audit_detail.audit_id # 返回添加结果 return result # 工单审核 @staticmethod def audit(audit_id, audit_status, audit_user, audit_remark): - result = {'status': 0, 'msg': 'ok', 'data': 0} + result = {"status": 0, "msg": "ok", "data": 0} audit_detail = WorkflowAudit.objects.get(audit_id=audit_id) # 不同审核状态 - if audit_status == WorkflowDict.workflow_status['audit_success']: + if audit_status == WorkflowDict.workflow_status["audit_success"]: # 判断当前工单是否为待审核状态 - if audit_detail.current_status != WorkflowDict.workflow_status['audit_wait']: - result['msg'] = '工单不是待审核状态,请返回刷新' - raise Exception(result['msg']) + if ( + audit_detail.current_status + != WorkflowDict.workflow_status["audit_wait"] + ): + result["msg"] = "工单不是待审核状态,请返回刷新" + raise Exception(result["msg"]) # 判断是否还有下一级审核 - if audit_detail.next_audit == '-1': + if audit_detail.next_audit == "-1": # 更新主表审核状态为审核通过 audit_result = WorkflowAudit() audit_result.audit_id = audit_id - audit_result.current_audit = '-1' - audit_result.current_status = WorkflowDict.workflow_status['audit_success'] - audit_result.save(update_fields=['current_audit', 'current_status']) + audit_result.current_audit = "-1" + audit_result.current_status = WorkflowDict.workflow_status[ + "audit_success" + ] + audit_result.save(update_fields=["current_audit", "current_status"]) else: # 更新主表审核下级审核组和当前审核组 audit_result = WorkflowAudit() audit_result.audit_id = audit_id - audit_result.current_status = WorkflowDict.workflow_status['audit_wait'] + audit_result.current_status = WorkflowDict.workflow_status["audit_wait"] audit_result.current_audit = audit_detail.next_audit # 判断后续是否还有下下一级审核组 - audit_auth_groups_list = audit_detail.audit_auth_groups.split(',') + audit_auth_groups_list = audit_detail.audit_auth_groups.split(",") for index, auth_group in enumerate(audit_auth_groups_list): if auth_group == audit_detail.next_audit: # 无下下级审核组 if index == len(audit_auth_groups_list) - 1: - audit_result.next_audit = '-1' + audit_result.next_audit = "-1" break # 存在下下级审核组 else: audit_result.next_audit = audit_auth_groups_list[index + 1] - audit_result.save(update_fields=['current_audit', 'next_audit', 'current_status']) + audit_result.save( + update_fields=["current_audit", "next_audit", "current_status"] + ) # 插入审核明细数据 audit_detail_result = WorkflowAuditDetail() audit_detail_result.audit_id = audit_id audit_detail_result.audit_user = audit_user - audit_detail_result.audit_status = WorkflowDict.workflow_status['audit_success'] + audit_detail_result.audit_status = WorkflowDict.workflow_status[ + "audit_success" + ] audit_detail_result.audit_time = timezone.now() audit_detail_result.remark = audit_remark audit_detail_result.save() # 增加工单日志 - audit_auth_group, current_audit_auth_group = Audit.review_info(audit_detail.workflow_id, - audit_detail.workflow_type) - Audit.add_log(audit_id=audit_id, - operation_type=1, - operation_type_desc='审批通过', - operation_info="审批备注:{},下级审批:{}".format(audit_remark, current_audit_auth_group), - operator=audit_user, - operator_display=Users.objects.get(username=audit_user).display - ) - elif audit_status == WorkflowDict.workflow_status['audit_reject']: + audit_auth_group, current_audit_auth_group = Audit.review_info( + audit_detail.workflow_id, audit_detail.workflow_type + ) + Audit.add_log( + audit_id=audit_id, + operation_type=1, + operation_type_desc="审批通过", + operation_info="审批备注:{},下级审批:{}".format( + audit_remark, current_audit_auth_group + ), + operator=audit_user, + operator_display=Users.objects.get(username=audit_user).display, + ) + elif audit_status == WorkflowDict.workflow_status["audit_reject"]: # 判断当前工单是否为待审核状态 - if audit_detail.current_status != WorkflowDict.workflow_status['audit_wait']: - result['msg'] = '工单不是待审核状态,请返回刷新' - raise Exception(result['msg']) + if ( + audit_detail.current_status + != WorkflowDict.workflow_status["audit_wait"] + ): + result["msg"] = "工单不是待审核状态,请返回刷新" + raise Exception(result["msg"]) # 更新主表审核状态 audit_result = WorkflowAudit() audit_result.audit_id = audit_id - audit_result.current_audit = '-1' - audit_result.next_audit = '-1' - audit_result.current_status = WorkflowDict.workflow_status['audit_reject'] - audit_result.save(update_fields=['current_audit', 'next_audit', 'current_status']) + audit_result.current_audit = "-1" + audit_result.next_audit = "-1" + audit_result.current_status = WorkflowDict.workflow_status["audit_reject"] + audit_result.save( + update_fields=["current_audit", "next_audit", "current_status"] + ) # 插入审核明细数据 audit_detail_result = WorkflowAuditDetail() audit_detail_result.audit_id = audit_id audit_detail_result.audit_user = audit_user - audit_detail_result.audit_status = WorkflowDict.workflow_status['audit_reject'] + audit_detail_result.audit_status = WorkflowDict.workflow_status[ + "audit_reject" + ] audit_detail_result.audit_time = timezone.now() audit_detail_result.remark = audit_remark audit_detail_result.save() # 增加工单日志 - Audit.add_log(audit_id=audit_id, - operation_type=2, - operation_type_desc='审批不通过', - operation_info="审批备注:{}".format(audit_remark), - operator=audit_user, - operator_display=Users.objects.get(username=audit_user).display - ) - elif audit_status == WorkflowDict.workflow_status['audit_abort']: + Audit.add_log( + audit_id=audit_id, + operation_type=2, + operation_type_desc="审批不通过", + operation_info="审批备注:{}".format(audit_remark), + operator=audit_user, + operator_display=Users.objects.get(username=audit_user).display, + ) + elif audit_status == WorkflowDict.workflow_status["audit_abort"]: # 判断当前工单是否为待审核/审核通过状态 - if audit_detail.current_status != WorkflowDict.workflow_status['audit_wait'] and \ - audit_detail.current_status != WorkflowDict.workflow_status['audit_success']: - result['msg'] = '工单不是待审核态/审核通过状态,请返回刷新' - raise Exception(result['msg']) + if ( + audit_detail.current_status + != WorkflowDict.workflow_status["audit_wait"] + and audit_detail.current_status + != WorkflowDict.workflow_status["audit_success"] + ): + result["msg"] = "工单不是待审核态/审核通过状态,请返回刷新" + raise Exception(result["msg"]) # 更新主表审核状态 audit_result = WorkflowAudit() audit_result.audit_id = audit_id - audit_result.next_audit = '-1' - audit_result.current_status = WorkflowDict.workflow_status['audit_abort'] - audit_result.save(update_fields=['current_status', 'next_audit']) + audit_result.next_audit = "-1" + audit_result.current_status = WorkflowDict.workflow_status["audit_abort"] + audit_result.save(update_fields=["current_status", "next_audit"]) # 插入审核明细数据 audit_detail_result = WorkflowAuditDetail() audit_detail_result.audit_id = audit_id audit_detail_result.audit_user = audit_user - audit_detail_result.audit_status = WorkflowDict.workflow_status['audit_abort'] + audit_detail_result.audit_status = WorkflowDict.workflow_status[ + "audit_abort" + ] audit_detail_result.audit_time = timezone.now() audit_detail_result.remark = audit_remark audit_detail_result.save() # 增加工单日志 - Audit.add_log(audit_id=audit_id, - operation_type=3, - operation_type_desc='审批取消', - operation_info="取消原因:{}".format(audit_remark), - operator=audit_user, - operator_display=Users.objects.get(username=audit_user).display - ) + Audit.add_log( + audit_id=audit_id, + operation_type=3, + operation_type_desc="审批取消", + operation_info="取消原因:{}".format(audit_remark), + operator=audit_user, + operator_display=Users.objects.get(username=audit_user).display, + ) else: - result['msg'] = '审核异常' - raise Exception(result['msg']) + result["msg"] = "审核异常" + raise Exception(result["msg"]) # 返回审核结果 - result['data'] = {'workflow_status': audit_result.current_status} + result["data"] = {"workflow_status": audit_result.current_status} return result # 获取用户待办工单数量 @@ -275,9 +325,10 @@ def todo(user): auth_group_ids = [group.id for group in Group.objects.filter(user=user)] return WorkflowAudit.objects.filter( - current_status=WorkflowDict.workflow_status['audit_wait'], + current_status=WorkflowDict.workflow_status["audit_wait"], group_id__in=group_ids, - current_audit__in=auth_group_ids).count() + current_audit__in=auth_group_ids, + ).count() # 通过审核id获取审核信息 @staticmethod @@ -291,7 +342,9 @@ def detail(audit_id): @staticmethod def detail_by_workflow_id(workflow_id, workflow_type): try: - return WorkflowAudit.objects.get(workflow_id=workflow_id, workflow_type=workflow_type) + return WorkflowAudit.objects.get( + workflow_id=workflow_id, workflow_type=workflow_type + ) except Exception: return None @@ -299,7 +352,9 @@ def detail_by_workflow_id(workflow_id, workflow_type): @staticmethod def settings(group_id, workflow_type): try: - return WorkflowAuditSetting.objects.get(workflow_type=workflow_type, group_id=group_id).audit_auth_groups + return WorkflowAuditSetting.objects.get( + workflow_type=workflow_type, group_id=group_id + ).audit_auth_groups except Exception: return None @@ -307,10 +362,12 @@ def settings(group_id, workflow_type): @staticmethod def change_settings(group_id, workflow_type, audit_auth_groups): try: - WorkflowAuditSetting.objects.get(workflow_type=workflow_type, group_id=group_id) - WorkflowAuditSetting.objects.filter(workflow_type=workflow_type, - group_id=group_id - ).update(audit_auth_groups=audit_auth_groups) + WorkflowAuditSetting.objects.get( + workflow_type=workflow_type, group_id=group_id + ) + WorkflowAuditSetting.objects.filter( + workflow_type=workflow_type, group_id=group_id + ).update(audit_auth_groups=audit_auth_groups) except Exception: inset = WorkflowAuditSetting() inset.group_id = group_id @@ -322,59 +379,84 @@ def change_settings(group_id, workflow_type, audit_auth_groups): # 判断用户当前是否是可审核 @staticmethod def can_review(user, workflow_id, workflow_type): - audit_info = WorkflowAudit.objects.get(workflow_id=workflow_id, workflow_type=workflow_type) + audit_info = WorkflowAudit.objects.get( + workflow_id=workflow_id, workflow_type=workflow_type + ) group_id = audit_info.group_id result = False # 只有待审核状态数据才可以审核 - if audit_info.current_status == WorkflowDict.workflow_status['audit_wait']: + if audit_info.current_status == WorkflowDict.workflow_status["audit_wait"]: try: - auth_group_id = Audit.detail_by_workflow_id(workflow_id, workflow_type).current_audit + auth_group_id = Audit.detail_by_workflow_id( + workflow_id, workflow_type + ).current_audit audit_auth_group = Group.objects.get(id=auth_group_id).name except Exception: - raise Exception('当前审批auth_group_id不存在,请检查并清洗历史数据') - if auth_group_users([audit_auth_group], group_id).filter(id=user.id).exists() or user.is_superuser == 1: + raise Exception("当前审批auth_group_id不存在,请检查并清洗历史数据") + if ( + auth_group_users([audit_auth_group], group_id) + .filter(id=user.id) + .exists() + or user.is_superuser == 1 + ): if workflow_type == 1: - if user.has_perm('sql.query_review'): + if user.has_perm("sql.query_review"): result = True elif workflow_type == 2: - if user.has_perm('sql.sql_review'): + if user.has_perm("sql.sql_review"): result = True elif workflow_type == 3: - if user.has_perm('sql.archive_review'): + if user.has_perm("sql.archive_review"): result = True return result # 获取当前工单审批流程和当前审核组 @staticmethod def review_info(workflow_id, workflow_type): - audit_info = WorkflowAudit.objects.get(workflow_id=workflow_id, workflow_type=workflow_type) - if audit_info.audit_auth_groups == '': - audit_auth_group = '无需审批' + audit_info = WorkflowAudit.objects.get( + workflow_id=workflow_id, workflow_type=workflow_type + ) + if audit_info.audit_auth_groups == "": + audit_auth_group = "无需审批" else: try: - audit_auth_group = '->'.join([Group.objects.get(id=auth_group_id).name for auth_group_id in - audit_info.audit_auth_groups.split(',')]) + audit_auth_group = "->".join( + [ + Group.objects.get(id=auth_group_id).name + for auth_group_id in audit_info.audit_auth_groups.split(",") + ] + ) except Exception: audit_auth_group = audit_info.audit_auth_groups - if audit_info.current_audit == '-1': + if audit_info.current_audit == "-1": current_audit_auth_group = None else: try: - current_audit_auth_group = Group.objects.get(id=audit_info.current_audit).name + current_audit_auth_group = Group.objects.get( + id=audit_info.current_audit + ).name except Exception: current_audit_auth_group = audit_info.current_audit return audit_auth_group, current_audit_auth_group # 新增工单日志 @staticmethod - def add_log(audit_id, operation_type, operation_type_desc, operation_info, operator, operator_display): - WorkflowLog(audit_id=audit_id, - operation_type=operation_type, - operation_type_desc=operation_type_desc, - operation_info=operation_info, - operator=operator, - operator_display=operator_display - ).save() + def add_log( + audit_id, + operation_type, + operation_type_desc, + operation_info, + operator, + operator_display, + ): + WorkflowLog( + audit_id=audit_id, + operation_type=operation_type, + operation_type_desc=operation_type_desc, + operation_info=operation_info, + operator=operator, + operator_display=operator_display, + ).save() # 获取工单日志 @staticmethod diff --git a/sql/views.py b/sql/views.py index 8ea4b6df95..15bda95d3b 100644 --- a/sql/views.py +++ b/sql/views.py @@ -20,67 +20,99 @@ from sql.engines.models import ReviewResult, ReviewSet from sql.utils.tasks import task_info -from .models import Users, SqlWorkflow, QueryPrivileges, ResourceGroup, \ - QueryPrivilegesApply, Config, SQL_WORKFLOW_CHOICES, InstanceTag, Instance, \ - QueryLog, ArchiveConfig, AuditEntry, TwoFactorAuthConfig +from .models import ( + Users, + SqlWorkflow, + QueryPrivileges, + ResourceGroup, + QueryPrivilegesApply, + Config, + SQL_WORKFLOW_CHOICES, + InstanceTag, + Instance, + QueryLog, + ArchiveConfig, + AuditEntry, + TwoFactorAuthConfig, +) from sql.utils.workflow_audit import Audit -from sql.utils.sql_review import can_execute, can_timingtask, can_cancel, can_view, can_rollback +from sql.utils.sql_review import ( + can_execute, + can_timingtask, + can_cancel, + can_view, + can_rollback, +) from common.utils.const import Const, WorkflowDict from sql.utils.resource_group import user_groups, user_instances, auth_group_users import logging -logger = logging.getLogger('default') +logger = logging.getLogger("default") def index(request): - index_path_url = SysConfig().get('index_path_url', 'sqlworkflow') + index_path_url = SysConfig().get("index_path_url", "sqlworkflow") return HttpResponseRedirect(f"/{index_path_url.strip('/')}/") def login(request): """登录页面""" if request.user and request.user.is_authenticated: - return HttpResponseRedirect('/') + return HttpResponseRedirect("/") - return render(request, 'login.html', context={'sign_up_enabled': SysConfig().get('sign_up_enabled')}) + return render( + request, + "login.html", + context={"sign_up_enabled": SysConfig().get("sign_up_enabled")}, + ) def twofa(request): """2fa认证页面""" if request.user.is_authenticated: - return HttpResponseRedirect('/') + return HttpResponseRedirect("/") - username = request.session.get('user') + username = request.session.get("user") if username: - verify_mode = request.session.get('verify_mode') + verify_mode = request.session.get("verify_mode") twofa_enabled = TwoFactorAuthConfig.objects.filter(username=username) user_auth_types = [twofa.auth_type for twofa in twofa_enabled] auth_types = [] for user_auth_type in user_auth_types: auth_type = {} - auth_type['code'] = user_auth_type - if user_auth_type == 'totp': - auth_type['display'] = 'Google身份验证器' - elif user_auth_type == 'sms': - auth_type['display'] = '短信验证码' + auth_type["code"] = user_auth_type + if user_auth_type == "totp": + auth_type["display"] = "Google身份验证器" + elif user_auth_type == "sms": + auth_type["display"] = "短信验证码" auth_types.append(auth_type) - if 'sms' in user_auth_types: - phone = TwoFactorAuthConfig.objects.get(username=username, auth_type='sms').phone + if "sms" in user_auth_types: + phone = TwoFactorAuthConfig.objects.get( + username=username, auth_type="sms" + ).phone else: phone = 0 else: - return HttpResponseRedirect('/login/') + return HttpResponseRedirect("/login/") - return render(request, '2fa.html', context={'verify_mode': verify_mode, 'auth_types': auth_types, - 'username': username, 'phone': phone}) + return render( + request, + "2fa.html", + context={ + "verify_mode": verify_mode, + "auth_types": auth_types, + "username": username, + "phone": phone, + }, + ) -@permission_required('sql.menu_dashboard', raise_exception=True) +@permission_required("sql.menu_dashboard", raise_exception=True) def dashboard(request): """dashboard页面""" - return render(request, 'dashboard.html') + return render(request, "dashboard.html") def sqlworkflow(request): @@ -89,28 +121,42 @@ def sqlworkflow(request): # 过滤筛选项的数据 filter_dict = dict() # 管理员,可查看所有工单 - if user.is_superuser or user.has_perm('sql.audit_user'): + if user.is_superuser or user.has_perm("sql.audit_user"): pass # 非管理员,拥有审核权限、资源组粒度执行权限的,可以查看组内所有工单 - elif user.has_perm('sql.sql_review') or user.has_perm('sql.sql_execute_for_resource_group'): + elif user.has_perm("sql.sql_review") or user.has_perm( + "sql.sql_execute_for_resource_group" + ): # 先获取用户所在资源组列表 group_list = user_groups(user) group_ids = [group.group_id for group in group_list] - filter_dict['group_id__in'] = group_ids + filter_dict["group_id__in"] = group_ids # 其他人只能查看自己提交的工单 else: - filter_dict['engineer'] = user.username - instance_id = SqlWorkflow.objects.filter(**filter_dict).values('instance_id').distinct() - instance = Instance.objects.filter(pk__in=instance_id).order_by(Convert('instance_name', 'gbk').asc()) - resource_group_id = SqlWorkflow.objects.filter(**filter_dict).values('group_id').distinct() + filter_dict["engineer"] = user.username + instance_id = ( + SqlWorkflow.objects.filter(**filter_dict).values("instance_id").distinct() + ) + instance = Instance.objects.filter(pk__in=instance_id).order_by( + Convert("instance_name", "gbk").asc() + ) + resource_group_id = ( + SqlWorkflow.objects.filter(**filter_dict).values("group_id").distinct() + ) resource_group = ResourceGroup.objects.filter(group_id__in=resource_group_id) - return render(request, 'sqlworkflow.html', - {'status_list': SQL_WORKFLOW_CHOICES, - 'instance': instance, 'resource_group': resource_group}) + return render( + request, + "sqlworkflow.html", + { + "status_list": SQL_WORKFLOW_CHOICES, + "instance": instance, + "resource_group": resource_group, + }, + ) -@permission_required('sql.sql_submit', raise_exception=True) +@permission_required("sql.sql_submit", raise_exception=True) def submit_sql(request): """提交SQL的页面""" user = request.user @@ -124,11 +170,16 @@ def submit_sql(request): archer_config = SysConfig() # 主动创建标签 - InstanceTag.objects.get_or_create(tag_code='can_write', defaults={'tag_name': '支持上线', 'active': True}) + InstanceTag.objects.get_or_create( + tag_code="can_write", defaults={"tag_name": "支持上线", "active": True} + ) - context = {'active_user': active_user, 'group_list': group_list, - 'enable_backup_switch': archer_config.get('enable_backup_switch')} - return render(request, 'sqlsubmit.html', context) + context = { + "active_user": active_user, + "group_list": group_list, + "enable_backup_switch": archer_config.get("enable_backup_switch"), + } + return render(request, "sqlsubmit.html", context) def detail(request, workflow_id): @@ -137,7 +188,7 @@ def detail(request, workflow_id): if not can_view(request.user, workflow_id): raise PermissionDenied # 自动审批不通过的不需要获取下列信息 - if workflow_detail.status != 'workflow_autoreviewwrong': + if workflow_detail.status != "workflow_autoreviewwrong": # 获取当前审批和审批流程 audit_auth_group, current_audit_auth_group = Audit.review_info(workflow_id, 2) @@ -154,22 +205,30 @@ def detail(request, workflow_id): # 获取审核日志 try: - audit_detail = Audit.detail_by_workflow_id(workflow_id=workflow_id, - workflow_type=WorkflowDict.workflow_type['sqlreview']) + audit_detail = Audit.detail_by_workflow_id( + workflow_id=workflow_id, + workflow_type=WorkflowDict.workflow_type["sqlreview"], + ) audit_id = audit_detail.audit_id - last_operation_info = Audit.logs(audit_id=audit_id).latest('id').operation_info + last_operation_info = ( + Audit.logs(audit_id=audit_id).latest("id").operation_info + ) # 等待审批的展示当前全部审批人 - if workflow_detail.status == 'workflow_manreviewing': + if workflow_detail.status == "workflow_manreviewing": auth_group_name = Group.objects.get(id=audit_detail.current_audit).name - current_audit_users = auth_group_users([auth_group_name], audit_detail.group_id) - current_audit_users_display = [user.display for user in current_audit_users] - last_operation_info += ',当前审批人:' + ','.join(current_audit_users_display) + current_audit_users = auth_group_users( + [auth_group_name], audit_detail.group_id + ) + current_audit_users_display = [ + user.display for user in current_audit_users + ] + last_operation_info += ",当前审批人:" + ",".join(current_audit_users_display) except Exception as e: - logger.debug(f'无审核日志记录,错误信息{e}') - last_operation_info = '' + logger.debug(f"无审核日志记录,错误信息{e}") + last_operation_info = "" else: - audit_auth_group = '系统自动驳回' - current_audit_auth_group = '系统自动驳回' + audit_auth_group = "系统自动驳回" + current_audit_auth_group = "系统自动驳回" is_can_review = False is_can_execute = False is_can_timingtask = False @@ -178,35 +237,44 @@ def detail(request, workflow_id): last_operation_info = None # 获取定时执行任务信息 - if workflow_detail.status == 'workflow_timingtask': - job_id = Const.workflowJobprefix['sqlreview'] + '-' + str(workflow_id) + if workflow_detail.status == "workflow_timingtask": + job_id = Const.workflowJobprefix["sqlreview"] + "-" + str(workflow_id) job = task_info(job_id) if job: run_date = job.next_run else: - run_date = '' + run_date = "" else: - run_date = '' + run_date = "" # 获取是否开启手工执行确认 - manual = SysConfig().get('manual') - - context = {'workflow_detail': workflow_detail, 'last_operation_info': last_operation_info, - 'is_can_review': is_can_review, 'is_can_execute': is_can_execute, 'is_can_timingtask': is_can_timingtask, - 'is_can_cancel': is_can_cancel, 'is_can_rollback': is_can_rollback, 'audit_auth_group': audit_auth_group, - 'manual': manual, 'current_audit_auth_group': current_audit_auth_group, 'run_date': run_date} - return render(request, 'detail.html', context) + manual = SysConfig().get("manual") + + context = { + "workflow_detail": workflow_detail, + "last_operation_info": last_operation_info, + "is_can_review": is_can_review, + "is_can_execute": is_can_execute, + "is_can_timingtask": is_can_timingtask, + "is_can_cancel": is_can_cancel, + "is_can_rollback": is_can_rollback, + "audit_auth_group": audit_auth_group, + "manual": manual, + "current_audit_auth_group": current_audit_auth_group, + "run_date": run_date, + } + return render(request, "detail.html", context) def rollback(request): """展示回滚的SQL页面""" - workflow_id = request.GET.get('workflow_id') + workflow_id = request.GET.get("workflow_id") if not can_rollback(request.user, workflow_id): raise PermissionDenied - download = request.GET.get('download') - if workflow_id == '' or workflow_id is None: - context = {'errMsg': 'workflow_id参数为空.'} - return render(request, 'error.html', context) + download = request.GET.get("download") + if workflow_id == "" or workflow_id is None: + context = {"errMsg": "workflow_id参数为空."} + return render(request, "error.html", context) workflow = SqlWorkflow.objects.get(id=int(workflow_id)) # 直接下载回滚语句 @@ -216,55 +284,66 @@ def rollback(request): list_backup_sql = query_engine.get_rollback(workflow=workflow) except Exception as msg: logger.error(traceback.format_exc()) - context = {'errMsg': msg} - return render(request, 'error.html', context) + context = {"errMsg": msg} + return render(request, "error.html", context) # 获取数据,存入目录 - path = os.path.join(settings.BASE_DIR, 'downloads/rollback') + path = os.path.join(settings.BASE_DIR, "downloads/rollback") os.makedirs(path, exist_ok=True) - file_name = f'{path}/rollback_{workflow_id}.sql' - with open(file_name, 'w') as f: + file_name = f"{path}/rollback_{workflow_id}.sql" + with open(file_name, "w") as f: for sql in list_backup_sql: - f.write(f'/*{sql[0]}*/\n{sql[1]}\n') + f.write(f"/*{sql[0]}*/\n{sql[1]}\n") # 返回 - response = FileResponse(open(file_name, 'rb')) - response['Content-Type'] = 'application/octet-stream' - response['Content-Disposition'] = f'attachment;filename="rollback_{workflow_id}.sql"' + response = FileResponse(open(file_name, "rb")) + response["Content-Type"] = "application/octet-stream" + response[ + "Content-Disposition" + ] = f'attachment;filename="rollback_{workflow_id}.sql"' return response # 异步获取,并在页面展示,如果数据量大加载会缓慢 else: rollback_workflow_name = f"【回滚工单】原工单Id:{workflow_id} ,{workflow.workflow_name}" - context = {'workflow_detail': workflow, 'rollback_workflow_name': rollback_workflow_name} - return render(request, 'rollback.html', context) + context = { + "workflow_detail": workflow, + "rollback_workflow_name": rollback_workflow_name, + } + return render(request, "rollback.html", context) -@permission_required('sql.menu_sqlanalyze', raise_exception=True) +@permission_required("sql.menu_sqlanalyze", raise_exception=True) def sqlanalyze(request): """SQL分析页面""" - return render(request, 'sqlanalyze.html') + return render(request, "sqlanalyze.html") -@permission_required('sql.menu_query', raise_exception=True) +@permission_required("sql.menu_query", raise_exception=True) def sqlquery(request): """SQL在线查询页面""" # 主动创建标签 - InstanceTag.objects.get_or_create(tag_code='can_read', defaults={'tag_name': '支持查询', 'active': True}) + InstanceTag.objects.get_or_create( + tag_code="can_read", defaults={"tag_name": "支持查询", "active": True} + ) # 收藏语句 user = request.user - favorites = QueryLog.objects.filter(username=user.username, favorite=True).values('id', 'alias') - can_download = 1 if user.has_perm('sql.query_download') or user.is_superuser else 0 - return render(request, 'sqlquery.html', {'favorites': favorites, 'can_download':can_download}) + favorites = QueryLog.objects.filter(username=user.username, favorite=True).values( + "id", "alias" + ) + can_download = 1 if user.has_perm("sql.query_download") or user.is_superuser else 0 + return render( + request, "sqlquery.html", {"favorites": favorites, "can_download": can_download} + ) -@permission_required('sql.menu_queryapplylist', raise_exception=True) +@permission_required("sql.menu_queryapplylist", raise_exception=True) def queryapplylist(request): """查询权限申请列表页面""" user = request.user # 获取资源组 group_list = user_groups(user) - context = {'group_list': group_list} - return render(request, 'queryapplylist.html', context) + context = {"group_list": group_list} + return render(request, "queryapplylist.html", context) def queryapplydetail(request, apply_id): @@ -278,100 +357,114 @@ def queryapplydetail(request, apply_id): # 获取审核日志 if workflow_detail.status == 2: try: - audit_id = Audit.detail_by_workflow_id(workflow_id=apply_id, workflow_type=1).audit_id - last_operation_info = Audit.logs(audit_id=audit_id).latest('id').operation_info + audit_id = Audit.detail_by_workflow_id( + workflow_id=apply_id, workflow_type=1 + ).audit_id + last_operation_info = ( + Audit.logs(audit_id=audit_id).latest("id").operation_info + ) except Exception as e: - logger.debug(f'无审核日志记录,错误信息{e}') - last_operation_info = '' + logger.debug(f"无审核日志记录,错误信息{e}") + last_operation_info = "" else: - last_operation_info = '' + last_operation_info = "" - context = {'workflow_detail': workflow_detail, 'audit_auth_group': audit_auth_group, - 'last_operation_info': last_operation_info, 'current_audit_auth_group': current_audit_auth_group, - 'is_can_review': is_can_review} - return render(request, 'queryapplydetail.html', context) + context = { + "workflow_detail": workflow_detail, + "audit_auth_group": audit_auth_group, + "last_operation_info": last_operation_info, + "current_audit_auth_group": current_audit_auth_group, + "is_can_review": is_can_review, + } + return render(request, "queryapplydetail.html", context) def queryuserprivileges(request): """查询权限管理页面""" # 获取所有用户 - user_list = QueryPrivileges.objects.filter(is_deleted=0).values('user_display').distinct() - context = {'user_list': user_list} - return render(request, 'queryuserprivileges.html', context) + user_list = ( + QueryPrivileges.objects.filter(is_deleted=0).values("user_display").distinct() + ) + context = {"user_list": user_list} + return render(request, "queryuserprivileges.html", context) -@permission_required('sql.menu_sqladvisor', raise_exception=True) +@permission_required("sql.menu_sqladvisor", raise_exception=True) def sqladvisor(request): """SQL优化工具页面""" - return render(request, 'sqladvisor.html') + return render(request, "sqladvisor.html") -@permission_required('sql.menu_slowquery', raise_exception=True) +@permission_required("sql.menu_slowquery", raise_exception=True) def slowquery(request): """SQL慢日志页面""" - return render(request, 'slowquery.html') + return render(request, "slowquery.html") -@permission_required('sql.menu_instance', raise_exception=True) +@permission_required("sql.menu_instance", raise_exception=True) def instance(request): """实例管理页面""" # 获取实例标签 tags = InstanceTag.objects.filter(active=True) - return render(request, 'instance.html', {'tags': tags}) + return render(request, "instance.html", {"tags": tags}) -@permission_required('sql.menu_instance_account', raise_exception=True) +@permission_required("sql.menu_instance_account", raise_exception=True) def instanceaccount(request): """实例账号管理页面""" - return render(request, 'instanceaccount.html') + return render(request, "instanceaccount.html") -@permission_required('sql.menu_database', raise_exception=True) +@permission_required("sql.menu_database", raise_exception=True) def database(request): """实例数据库管理页面""" # 获取所有有效用户,通知对象 active_user = Users.objects.filter(is_active=1) - return render(request, 'database.html', {"active_user": active_user}) + return render(request, "database.html", {"active_user": active_user}) -@permission_required('sql.menu_dbdiagnostic', raise_exception=True) +@permission_required("sql.menu_dbdiagnostic", raise_exception=True) def dbdiagnostic(request): """会话管理页面""" - return render(request, 'dbdiagnostic.html') + return render(request, "dbdiagnostic.html") -@permission_required('sql.menu_data_dictionary', raise_exception=True) +@permission_required("sql.menu_data_dictionary", raise_exception=True) def data_dictionary(request): """数据字典页面""" - return render(request, 'data_dictionary.html', locals()) + return render(request, "data_dictionary.html", locals()) -@permission_required('sql.menu_param', raise_exception=True) +@permission_required("sql.menu_param", raise_exception=True) def instance_param(request): """实例参数管理页面""" - return render(request, 'param.html') + return render(request, "param.html") -@permission_required('sql.menu_my2sql', raise_exception=True) +@permission_required("sql.menu_my2sql", raise_exception=True) def my2sql(request): """my2sql页面""" - return render(request, 'my2sql.html') + return render(request, "my2sql.html") -@permission_required('sql.menu_schemasync', raise_exception=True) +@permission_required("sql.menu_schemasync", raise_exception=True) def schemasync(request): """数据库差异对比页面""" - return render(request, 'schemasync.html') + return render(request, "schemasync.html") -@permission_required('sql.menu_archive', raise_exception=True) +@permission_required("sql.menu_archive", raise_exception=True) def archive(request): """归档列表页面""" # 获取资源组 group_list = user_groups(request.user) - ins_list = user_instances(request.user, db_type=['mysql']).order_by(Convert('instance_name', 'gbk').asc()) - return render(request, 'archive.html', {'group_list': group_list, 'ins_list': ins_list}) + ins_list = user_instances(request.user, db_type=["mysql"]).order_by( + Convert("instance_name", "gbk").asc() + ) + return render( + request, "archive.html", {"group_list": group_list, "ins_list": ins_list} + ) def archive_detail(request, id): @@ -382,24 +475,32 @@ def archive_detail(request, id): audit_auth_group, current_audit_auth_group = Audit.review_info(id, 3) is_can_review = Audit.can_review(request.user, id, 3) except Exception as e: - logger.debug(f'归档配置{id}无审核信息,{e}') + logger.debug(f"归档配置{id}无审核信息,{e}") audit_auth_group, current_audit_auth_group = None, None is_can_review = False # 获取审核日志 if archive_config.status == 2: try: - audit_id = Audit.detail_by_workflow_id(workflow_id=id, workflow_type=3).audit_id - last_operation_info = Audit.logs(audit_id=audit_id).latest('id').operation_info + audit_id = Audit.detail_by_workflow_id( + workflow_id=id, workflow_type=3 + ).audit_id + last_operation_info = ( + Audit.logs(audit_id=audit_id).latest("id").operation_info + ) except Exception as e: - logger.debug(f'归档配置{id}无审核日志记录,错误信息{e}') - last_operation_info = '' + logger.debug(f"归档配置{id}无审核日志记录,错误信息{e}") + last_operation_info = "" else: - last_operation_info = '' + last_operation_info = "" - context = {'archive_config': archive_config, 'audit_auth_group': audit_auth_group, - 'last_operation_info': last_operation_info, 'current_audit_auth_group': current_audit_auth_group, - 'is_can_review': is_can_review} - return render(request, 'archivedetail.html', context) + context = { + "archive_config": archive_config, + "audit_auth_group": audit_auth_group, + "last_operation_info": last_operation_info, + "current_audit_auth_group": current_audit_auth_group, + "is_can_review": is_can_review, + } + return render(request, "archivedetail.html", context) @superuser_required @@ -412,29 +513,35 @@ def config(request): # 获取所有实例标签 instance_tags = InstanceTag.objects.all() # 支持自动审核的数据库类型 - db_type = ['mysql', 'oracle', 'mongo', 'clickhouse'] + db_type = ["mysql", "oracle", "mongo", "clickhouse"] # 获取所有配置项 - all_config = Config.objects.all().values('item', 'value') + all_config = Config.objects.all().values("item", "value") sys_config = {} for items in all_config: - sys_config[items['item']] = items['value'] + sys_config[items["item"]] = items["value"] - context = {'group_list': group_list, 'auth_group_list': auth_group_list, 'instance_tags': instance_tags, - 'db_type': db_type, 'config': sys_config, 'WorkflowDict': WorkflowDict} - return render(request, 'config.html', context) + context = { + "group_list": group_list, + "auth_group_list": auth_group_list, + "instance_tags": instance_tags, + "db_type": db_type, + "config": sys_config, + "WorkflowDict": WorkflowDict, + } + return render(request, "config.html", context) @superuser_required def group(request): """资源组管理页面""" - return render(request, 'group.html') + return render(request, "group.html") @superuser_required def groupmgmt(request, group_id): """资源组组关系管理页面""" group = ResourceGroup.objects.get(group_id=group_id) - return render(request, 'groupmgmt.html', {'group': group}) + return render(request, "groupmgmt.html", {"group": group}) def workflows(request): @@ -448,38 +555,46 @@ def workflowsdetail(request, audit_id): audit_detail = Audit.detail(audit_id) if not audit_detail: raise Http404("不存在对应的工单记录") - if audit_detail.workflow_type == WorkflowDict.workflow_type['query']: - return HttpResponseRedirect(reverse('sql:queryapplydetail', args=(audit_detail.workflow_id,))) - elif audit_detail.workflow_type == WorkflowDict.workflow_type['sqlreview']: - return HttpResponseRedirect(reverse('sql:detail', args=(audit_detail.workflow_id,))) - elif audit_detail.workflow_type == WorkflowDict.workflow_type['archive']: - return HttpResponseRedirect(reverse('sql:archive_detail', args=(audit_detail.workflow_id,))) - - -@permission_required('sql.menu_document', raise_exception=True) + if audit_detail.workflow_type == WorkflowDict.workflow_type["query"]: + return HttpResponseRedirect( + reverse("sql:queryapplydetail", args=(audit_detail.workflow_id,)) + ) + elif audit_detail.workflow_type == WorkflowDict.workflow_type["sqlreview"]: + return HttpResponseRedirect( + reverse("sql:detail", args=(audit_detail.workflow_id,)) + ) + elif audit_detail.workflow_type == WorkflowDict.workflow_type["archive"]: + return HttpResponseRedirect( + reverse("sql:archive_detail", args=(audit_detail.workflow_id,)) + ) + + +@permission_required("sql.menu_document", raise_exception=True) def dbaprinciples(request): """SQL文档页面""" # 读取MD文件 - file = os.path.join(settings.BASE_DIR, 'docs/docs.md') - with open(file, 'r', encoding="utf-8") as f: - md = f.read().replace('\n', '\\n') - return render(request, 'dbaprinciples.html', {'md': md}) + file = os.path.join(settings.BASE_DIR, "docs/docs.md") + with open(file, "r", encoding="utf-8") as f: + md = f.read().replace("\n", "\\n") + return render(request, "dbaprinciples.html", {"md": md}) -@permission_required('sql.audit_user', raise_exception=True) +@permission_required("sql.audit_user", raise_exception=True) def audit(request): """通用审计日志页面""" - _action_types = AuditEntry.objects.values_list('action').distinct() - action_types = [ i[0] for i in _action_types ] - return render(request, 'audit.html', {'action_types': action_types}) + _action_types = AuditEntry.objects.values_list("action").distinct() + action_types = [i[0] for i in _action_types] + return render(request, "audit.html", {"action_types": action_types}) -@permission_required('sql.audit_user', raise_exception=True) +@permission_required("sql.audit_user", raise_exception=True) def audit_sqlquery(request): """SQL在线查询页面审计""" user = request.user - favorites = QueryLog.objects.filter(username=user.username, favorite=True).values('id', 'alias') - return render(request, 'audit_sqlquery.html', {'favorites': favorites}) + favorites = QueryLog.objects.filter(username=user.username, favorite=True).values( + "id", "alias" + ) + return render(request, "audit_sqlquery.html", {"favorites": favorites}) def audit_sqlworkflow(request): @@ -488,22 +603,34 @@ def audit_sqlworkflow(request): # 过滤筛选项的数据 filter_dict = dict() # 管理员,可查看所有工单 - if user.is_superuser or user.has_perm('sql.audit_user'): + if user.is_superuser or user.has_perm("sql.audit_user"): pass # 非管理员,拥有审核权限、资源组粒度执行权限的,可以查看组内所有工单 - elif user.has_perm('sql.sql_review') or user.has_perm('sql.sql_execute_for_resource_group'): + elif user.has_perm("sql.sql_review") or user.has_perm( + "sql.sql_execute_for_resource_group" + ): # 先获取用户所在资源组列表 group_list = user_groups(user) group_ids = [group.group_id for group in group_list] - filter_dict['group_id__in'] = group_ids + filter_dict["group_id__in"] = group_ids # 其他人只能查看自己提交的工单 else: - filter_dict['engineer'] = user.username - instance_id = SqlWorkflow.objects.filter(**filter_dict).values('instance_id').distinct() + filter_dict["engineer"] = user.username + instance_id = ( + SqlWorkflow.objects.filter(**filter_dict).values("instance_id").distinct() + ) instance = Instance.objects.filter(pk__in=instance_id) - resource_group_id = SqlWorkflow.objects.filter(**filter_dict).values('group_id').distinct() + resource_group_id = ( + SqlWorkflow.objects.filter(**filter_dict).values("group_id").distinct() + ) resource_group = ResourceGroup.objects.filter(group_id__in=resource_group_id) - return render(request, 'audit_sqlworkflow.html', - {'status_list': SQL_WORKFLOW_CHOICES, - 'instance': instance, 'resource_group': resource_group}) + return render( + request, + "audit_sqlworkflow.html", + { + "status_list": SQL_WORKFLOW_CHOICES, + "instance": instance, + "resource_group": resource_group, + }, + ) diff --git a/sql_api/api_instance.py b/sql_api/api_instance.py index 0229c19483..4cb50b51ba 100644 --- a/sql_api/api_instance.py +++ b/sql_api/api_instance.py @@ -1,8 +1,14 @@ from rest_framework import views, generics, status, serializers from rest_framework.response import Response from drf_spectacular.utils import extend_schema -from .serializers import InstanceSerializer, InstanceDetailSerializer, TunnelSerializer, \ - AliyunRdsSerializer, InstanceResourceSerializer, InstanceResourceListSerializer +from .serializers import ( + InstanceSerializer, + InstanceDetailSerializer, + TunnelSerializer, + AliyunRdsSerializer, + InstanceResourceSerializer, + InstanceResourceListSerializer, +) from .pagination import CustomizedPagination from .filters import InstanceFilter from sql.models import Instance, Tunnel, AliyunRdsConfig @@ -15,28 +21,31 @@ class InstanceList(generics.ListAPIView): """ 列出所有的instance或者创建一个新的instance配置 """ + filterset_class = InstanceFilter pagination_class = CustomizedPagination serializer_class = InstanceSerializer - queryset = Instance.objects.all().order_by('id') - - @extend_schema(summary="实例清单", - request=InstanceSerializer, - responses={200: InstanceSerializer}, - description="列出所有实例(过滤,分页)") + queryset = Instance.objects.all().order_by("id") + + @extend_schema( + summary="实例清单", + request=InstanceSerializer, + responses={200: InstanceSerializer}, + description="列出所有实例(过滤,分页)", + ) def get(self, request): instances = self.filter_queryset(self.queryset) page_ins = self.paginate_queryset(queryset=instances) serializer_obj = self.get_serializer(page_ins, many=True) - data = { - 'data': serializer_obj.data - } + data = {"data": serializer_obj.data} return self.get_paginated_response(data) - @extend_schema(summary="创建实例", - request=InstanceSerializer, - responses={201: InstanceSerializer}, - description="创建一个实例配置") + @extend_schema( + summary="创建实例", + request=InstanceSerializer, + responses={201: InstanceSerializer}, + description="创建一个实例配置", + ) def post(self, request): serializer = InstanceSerializer(data=request.data) if serializer.is_valid(): @@ -49,6 +58,7 @@ class InstanceDetail(views.APIView): """ 实例操作 """ + serializer_class = InstanceDetailSerializer def get_object(self, pk): @@ -57,10 +67,12 @@ def get_object(self, pk): except Instance.DoesNotExist: raise Http404 - @extend_schema(summary="更新实例", - request=InstanceDetailSerializer, - responses={200: InstanceDetailSerializer}, - description="更新一个实例配置") + @extend_schema( + summary="更新实例", + request=InstanceDetailSerializer, + responses={200: InstanceDetailSerializer}, + description="更新一个实例配置", + ) def put(self, request, pk): instance = self.get_object(pk) serializer = InstanceDetailSerializer(instance, data=request.data) @@ -69,8 +81,7 @@ def put(self, request, pk): return Response(serializer.data) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - @extend_schema(summary="删除实例", - description="删除一个实例配置") + @extend_schema(summary="删除实例", description="删除一个实例配置") def delete(self, request, pk): instance = self.get_object(pk) instance.delete() @@ -81,27 +92,30 @@ class TunnelList(generics.ListAPIView): """ 列出所有的tunnel或者创建一个新的tunnel配置 """ + pagination_class = CustomizedPagination serializer_class = TunnelSerializer - queryset = Tunnel.objects.all().order_by('id') - - @extend_schema(summary="隧道清单", - request=TunnelSerializer, - responses={200: TunnelSerializer}, - description="列出所有隧道(过滤,分页)") + queryset = Tunnel.objects.all().order_by("id") + + @extend_schema( + summary="隧道清单", + request=TunnelSerializer, + responses={200: TunnelSerializer}, + description="列出所有隧道(过滤,分页)", + ) def get(self, request): tunnels = self.filter_queryset(self.queryset) page_tunnels = self.paginate_queryset(queryset=tunnels) serializer_obj = self.get_serializer(page_tunnels, many=True) - data = { - 'data': serializer_obj.data - } + data = {"data": serializer_obj.data} return self.get_paginated_response(data) - @extend_schema(summary="创建隧道", - request=TunnelSerializer, - responses={201: TunnelSerializer}, - description="创建一个隧道配置") + @extend_schema( + summary="创建隧道", + request=TunnelSerializer, + responses={201: TunnelSerializer}, + description="创建一个隧道配置", + ) def post(self, request): serializer = TunnelSerializer(data=request.data) if serializer.is_valid(): @@ -114,27 +128,30 @@ class AliyunRdsList(generics.ListAPIView): """ 列出所有的AliyunRDS或者创建一个新的AliyunRDS配置 """ + pagination_class = CustomizedPagination serializer_class = AliyunRdsSerializer - queryset = AliyunRdsConfig.objects.all().select_related('ak').order_by('id') - - @extend_schema(summary="AliyunRDS清单", - request=AliyunRdsSerializer, - responses={200: AliyunRdsSerializer}, - description="列出所有AliyunRDS(过滤,分页)") + queryset = AliyunRdsConfig.objects.all().select_related("ak").order_by("id") + + @extend_schema( + summary="AliyunRDS清单", + request=AliyunRdsSerializer, + responses={200: AliyunRdsSerializer}, + description="列出所有AliyunRDS(过滤,分页)", + ) def get(self, request): aliyunrds = self.filter_queryset(self.queryset) page_rds = self.paginate_queryset(queryset=aliyunrds) serializer_obj = self.get_serializer(page_rds, many=True) - data = { - 'data': serializer_obj.data - } + data = {"data": serializer_obj.data} return self.get_paginated_response(data) - @extend_schema(summary="创建AliyunRDS", - request=AliyunRdsSerializer, - responses={201: AliyunRdsSerializer}, - description="创建一个AliyunRDS配置(包含一个CloudAccessKey)") + @extend_schema( + summary="创建AliyunRDS", + request=AliyunRdsSerializer, + responses={201: AliyunRdsSerializer}, + description="创建一个AliyunRDS配置(包含一个CloudAccessKey)", + ) def post(self, request): serializer = AliyunRdsSerializer(data=request.data) if serializer.is_valid(): @@ -148,47 +165,54 @@ class InstanceResource(views.APIView): 获取实例内的资源信息,database、schema、table、column """ - @extend_schema(summary="实例资源", - request=InstanceResourceSerializer, - responses={200: InstanceResourceListSerializer}, - description="获取实例内的资源信息") + @extend_schema( + summary="实例资源", + request=InstanceResourceSerializer, + responses={200: InstanceResourceListSerializer}, + description="获取实例内的资源信息", + ) def post(self, request): # 参数验证 serializer = InstanceResourceSerializer(data=request.data) if not serializer.is_valid(): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - instance_id = request.data['instance_id'] - resource_type = request.data['resource_type'] - db_name = request.data['db_name'] if 'db_name' in request.data.keys() else '' - schema_name = request.data['schema_name'] if 'schema_name' in request.data.keys() else '' - tb_name = request.data['tb_name'] if 'tb_name' in request.data.keys() else '' + instance_id = request.data["instance_id"] + resource_type = request.data["resource_type"] + db_name = request.data["db_name"] if "db_name" in request.data.keys() else "" + schema_name = ( + request.data["schema_name"] if "schema_name" in request.data.keys() else "" + ) + tb_name = request.data["tb_name"] if "tb_name" in request.data.keys() else "" instance = Instance.objects.get(pk=instance_id) try: # escape - db_name = MySQLdb.escape_string(db_name).decode('utf-8') - schema_name = MySQLdb.escape_string(schema_name).decode('utf-8') - tb_name = MySQLdb.escape_string(tb_name).decode('utf-8') + db_name = MySQLdb.escape_string(db_name).decode("utf-8") + schema_name = MySQLdb.escape_string(schema_name).decode("utf-8") + tb_name = MySQLdb.escape_string(tb_name).decode("utf-8") query_engine = get_engine(instance=instance) - if resource_type == 'database': + if resource_type == "database": resource = query_engine.get_all_databases() - elif resource_type == 'schema' and db_name: + elif resource_type == "schema" and db_name: resource = query_engine.get_all_schemas(db_name=db_name) - elif resource_type == 'table' and db_name: - resource = query_engine.get_all_tables(db_name=db_name, schema_name=schema_name) - elif resource_type == 'column' and db_name and tb_name: - resource = query_engine.get_all_columns_by_tb(db_name=db_name, tb_name=tb_name, schema_name=schema_name) + elif resource_type == "table" and db_name: + resource = query_engine.get_all_tables( + db_name=db_name, schema_name=schema_name + ) + elif resource_type == "column" and db_name and tb_name: + resource = query_engine.get_all_columns_by_tb( + db_name=db_name, tb_name=tb_name, schema_name=schema_name + ) else: - raise serializers.ValidationError({'errors': '不支持的资源类型或者参数不完整!'}) + raise serializers.ValidationError({"errors": "不支持的资源类型或者参数不完整!"}) except Exception as msg: - raise serializers.ValidationError({'errors': msg}) + raise serializers.ValidationError({"errors": msg}) else: if resource.error: - raise serializers.ValidationError({'errors': resource.error}) + raise serializers.ValidationError({"errors": resource.error}) else: - resource = {'count': len(resource.rows), - 'result': resource.rows} + resource = {"count": len(resource.rows), "result": resource.rows} serializer_obj = InstanceResourceListSerializer(resource) return Response(serializer_obj.data) diff --git a/sql_api/api_user.py b/sql_api/api_user.py index 5aec0a4151..d307066cc6 100644 --- a/sql_api/api_user.py +++ b/sql_api/api_user.py @@ -1,9 +1,17 @@ from rest_framework import views, generics, status, permissions from rest_framework.response import Response from drf_spectacular.utils import extend_schema -from .serializers import UserSerializer, UserDetailSerializer, GroupSerializer, \ - ResourceGroupSerializer, TwoFASerializer, UserAuthSerializer, TwoFAVerifySerializer, \ - TwoFASaveSerializer, TwoFAStateSerializer +from .serializers import ( + UserSerializer, + UserDetailSerializer, + GroupSerializer, + ResourceGroupSerializer, + TwoFASerializer, + UserAuthSerializer, + TwoFAVerifySerializer, + TwoFASaveSerializer, + TwoFAStateSerializer, +) from .pagination import CustomizedPagination from .permissions import IsOwner from .filters import UserFilter @@ -23,28 +31,31 @@ class UserList(generics.ListAPIView): """ 列出所有的user或者创建一个新的user """ + filterset_class = UserFilter pagination_class = CustomizedPagination serializer_class = UserSerializer - queryset = Users.objects.all().order_by('id') - - @extend_schema(summary="用户清单", - request=UserSerializer, - responses={200: UserSerializer}, - description="列出所有用户(过滤,分页)") + queryset = Users.objects.all().order_by("id") + + @extend_schema( + summary="用户清单", + request=UserSerializer, + responses={200: UserSerializer}, + description="列出所有用户(过滤,分页)", + ) def get(self, request): users = self.filter_queryset(self.queryset) page_user = self.paginate_queryset(queryset=users) serializer_obj = self.get_serializer(page_user, many=True) - data = { - 'data': serializer_obj.data - } + data = {"data": serializer_obj.data} return self.get_paginated_response(data) - @extend_schema(summary="创建用户", - request=UserSerializer, - responses={201: UserSerializer}, - description="创建一个用户") + @extend_schema( + summary="创建用户", + request=UserSerializer, + responses={201: UserSerializer}, + description="创建一个用户", + ) def post(self, request): serializer = UserSerializer(data=request.data) if serializer.is_valid(): @@ -57,6 +68,7 @@ class UserDetail(views.APIView): """ 用户操作 """ + serializer_class = UserDetailSerializer def get_object(self, pk): @@ -65,10 +77,12 @@ def get_object(self, pk): except Users.DoesNotExist: raise Http404 - @extend_schema(summary="更新用户", - request=UserDetailSerializer, - responses={200: UserDetailSerializer}, - description="更新一个用户") + @extend_schema( + summary="更新用户", + request=UserDetailSerializer, + responses={200: UserDetailSerializer}, + description="更新一个用户", + ) def put(self, request, pk): user = self.get_object(pk) serializer = UserDetailSerializer(user, data=request.data) @@ -77,8 +91,7 @@ def put(self, request, pk): return Response(serializer.data) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - @extend_schema(summary="删除用户", - description="删除一个用户") + @extend_schema(summary="删除用户", description="删除一个用户") def delete(self, request, pk): user = self.get_object(pk) user.delete() @@ -89,27 +102,30 @@ class GroupList(generics.ListAPIView): """ 列出所有的group或者创建一个新的group """ + pagination_class = CustomizedPagination serializer_class = GroupSerializer - queryset = Group.objects.all().order_by('id') - - @extend_schema(summary="用户组清单", - request=GroupSerializer, - responses={200: GroupSerializer}, - description="列出所有用户组(过滤,分页)") + queryset = Group.objects.all().order_by("id") + + @extend_schema( + summary="用户组清单", + request=GroupSerializer, + responses={200: GroupSerializer}, + description="列出所有用户组(过滤,分页)", + ) def get(self, request): groups = self.filter_queryset(self.queryset) page_groups = self.paginate_queryset(queryset=groups) serializer_obj = self.get_serializer(page_groups, many=True) - data = { - 'data': serializer_obj.data - } + data = {"data": serializer_obj.data} return self.get_paginated_response(data) - @extend_schema(summary="创建用户组", - request=GroupSerializer, - responses={201: GroupSerializer}, - description="创建一个用户组") + @extend_schema( + summary="创建用户组", + request=GroupSerializer, + responses={201: GroupSerializer}, + description="创建一个用户组", + ) def post(self, request): serializer = GroupSerializer(data=request.data) if serializer.is_valid(): @@ -122,6 +138,7 @@ class GroupDetail(views.APIView): """ 用户组操作 """ + serializer_class = GroupSerializer def get_object(self, pk): @@ -130,10 +147,12 @@ def get_object(self, pk): except Group.DoesNotExist: raise Http404 - @extend_schema(summary="更新用户组", - request=GroupSerializer, - responses={200: GroupSerializer}, - description="更新一个用户组") + @extend_schema( + summary="更新用户组", + request=GroupSerializer, + responses={200: GroupSerializer}, + description="更新一个用户组", + ) def put(self, request, pk): group = self.get_object(pk) serializer = GroupSerializer(group, data=request.data) @@ -142,8 +161,7 @@ def put(self, request, pk): return Response(serializer.data) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - @extend_schema(summary="删除用户组", - description="删除一个用户组") + @extend_schema(summary="删除用户组", description="删除一个用户组") def delete(self, request, pk): group = self.get_object(pk) group.delete() @@ -154,27 +172,30 @@ class ResourceGroupList(generics.ListAPIView): """ 列出所有的resourcegroup或者创建一个新的resourcegroup """ + pagination_class = CustomizedPagination serializer_class = ResourceGroupSerializer - queryset = ResourceGroup.objects.all().order_by('group_id') - - @extend_schema(summary="资源组清单", - request=ResourceGroupSerializer, - responses={200: ResourceGroupSerializer}, - description="列出所有资源组(过滤,分页)") + queryset = ResourceGroup.objects.all().order_by("group_id") + + @extend_schema( + summary="资源组清单", + request=ResourceGroupSerializer, + responses={200: ResourceGroupSerializer}, + description="列出所有资源组(过滤,分页)", + ) def get(self, request): groups = self.filter_queryset(self.queryset) page_groups = self.paginate_queryset(queryset=groups) serializer_obj = self.get_serializer(page_groups, many=True) - data = { - 'data': serializer_obj.data - } + data = {"data": serializer_obj.data} return self.get_paginated_response(data) - @extend_schema(summary="创建资源组", - request=ResourceGroupSerializer, - responses={201: ResourceGroupSerializer}, - description="创建一个资源组") + @extend_schema( + summary="创建资源组", + request=ResourceGroupSerializer, + responses={201: ResourceGroupSerializer}, + description="创建一个资源组", + ) def post(self, request): serializer = ResourceGroupSerializer(data=request.data) if serializer.is_valid(): @@ -187,6 +208,7 @@ class ResourceGroupDetail(views.APIView): """ 资源组操作 """ + serializer_class = ResourceGroupSerializer def get_object(self, pk): @@ -195,10 +217,12 @@ def get_object(self, pk): except ResourceGroup.DoesNotExist: raise Http404 - @extend_schema(summary="更新资源组", - request=ResourceGroupSerializer, - responses={200: ResourceGroupSerializer}, - description="更新一个资源组") + @extend_schema( + summary="更新资源组", + request=ResourceGroupSerializer, + responses={200: ResourceGroupSerializer}, + description="更新一个资源组", + ) def put(self, request, pk): group = self.get_object(pk) serializer = ResourceGroupSerializer(group, data=request.data) @@ -207,8 +231,7 @@ def put(self, request, pk): return Response(serializer.data) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - @extend_schema(summary="删除资源组", - description="删除一个资源组") + @extend_schema(summary="删除资源组", description="删除一个资源组") def delete(self, request, pk): group = self.get_object(pk) group.delete() @@ -219,24 +242,23 @@ class UserAuth(views.APIView): """ 用户认证校验 """ + permission_classes = [IsOwner] - @extend_schema(summary="用户认证校验", - request=UserAuthSerializer, - description="用户认证校验") + @extend_schema(summary="用户认证校验", request=UserAuthSerializer, description="用户认证校验") def post(self, request): # 参数验证 serializer = UserAuthSerializer(data=request.data) if not serializer.is_valid(): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - result = {'status': 0, 'msg': '认证成功'} - engineer = request.data['engineer'] - password = request.data['password'] + result = {"status": 0, "msg": "认证成功"} + engineer = request.data["engineer"] + password = request.data["password"] user = authenticate(username=engineer, password=password) if not user: - result = {'status': 1, 'msg': '用户名或密码错误!'} + result = {"status": 1, "msg": "用户名或密码错误!"} return Response(result) @@ -245,44 +267,43 @@ class TwoFA(views.APIView): """ 配置2fa """ + permission_classes = [permissions.AllowAny] - @extend_schema(summary="配置2fa", - request=TwoFASerializer, - description="配置2fa") + @extend_schema(summary="配置2fa", request=TwoFASerializer, description="配置2fa") def post(self, request): # 参数验证 serializer = TwoFASerializer(data=request.data) if not serializer.is_valid(): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - engineer = request.data['engineer'] - enable = request.data['enable'] - auth_type = request.data['auth_type'] + engineer = request.data["engineer"] + enable = request.data["enable"] + auth_type = request.data["auth_type"] user = Users.objects.get(username=engineer) - request_user = request.session.get('user') + request_user = request.session.get("user") if not request.user.is_authenticated: if request_user: if request_user != engineer: - return Response({'status': 1, 'msg': '登录用户与校验用户不一致!'}) + return Response({"status": 1, "msg": "登录用户与校验用户不一致!"}) else: - return Response({'status': 1, 'msg': '需先校验用户密码!'}) + return Response({"status": 1, "msg": "需先校验用户密码!"}) authenticator = get_authenticator(user=user, auth_type=auth_type) - if enable == 'true': - if auth_type == 'totp': + if enable == "true": + if auth_type == "totp": # 启用2fa - 先生成secret key result = authenticator.generate_key() - elif auth_type == 'sms': + elif auth_type == "sms": # 启用2fa - 先发送短信验证码 - phone = request.data['phone'] - otp = '{:06d}'.format(random.randint(0, 999999)) + phone = request.data["phone"] + otp = "{:06d}".format(random.randint(0, 999999)) result = authenticator.get_captcha(phone=phone, otp=otp) - if result['status'] == 0: - r = get_redis_connection('default') - data = {'otp': otp, 'update_time': int(time.time())} - r.set(f'captcha-{phone}', json.dumps(data), 300) + if result["status"] == 0: + r = get_redis_connection("default") + data = {"otp": otp, "update_time": int(time.time())} + r.set(f"captcha-{phone}", json.dumps(data), 300) else: # 启用2fa result = authenticator.enable() @@ -296,23 +317,28 @@ class TwoFAState(views.APIView): """ 查询用户2fa配置情况 """ + permission_classes = [IsOwner] - @extend_schema(summary="查询2fa配置情况", - request=TwoFAStateSerializer, - description="查询2fa配置情况") + @extend_schema( + summary="查询2fa配置情况", request=TwoFAStateSerializer, description="查询2fa配置情况" + ) def post(self, request): # 参数验证 serializer = TwoFAStateSerializer(data=request.data) if not serializer.is_valid(): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - result = {'status': 0, 'msg': 'ok', 'data': {}} - engineer = request.data['engineer'] + result = {"status": 0, "msg": "ok", "data": {}} + engineer = request.data["engineer"] user = Users.objects.get(username=engineer) configs = TwoFactorAuthConfig.objects.filter(user=user) - result['data']['totp'] = 'enabled' if configs.filter(auth_type='totp') else 'disabled' - result['data']['sms'] = 'enabled' if configs.filter(auth_type='sms') else 'disabled' + result["data"]["totp"] = ( + "enabled" if configs.filter(auth_type="totp") else "disabled" + ) + result["data"]["sms"] = ( + "enabled" if configs.filter(auth_type="sms") else "disabled" + ) return Response(result) @@ -321,25 +347,26 @@ class TwoFASave(views.APIView): """ 保存2fa配置(TOTP) """ + permission_classes = [IsOwner] - @extend_schema(summary="保存2fa配置", - request=TwoFASaveSerializer, - description="保存2fa配置") + @extend_schema( + summary="保存2fa配置", request=TwoFASaveSerializer, description="保存2fa配置" + ) def post(self, request): # 参数验证 serializer = TwoFASaveSerializer(data=request.data) if not serializer.is_valid(): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - engineer = request.data['engineer'] - auth_type = request.data['auth_type'] - key = request.data['key'] if 'key' in request.data.keys() else None - phone = request.data['phone'] if 'phone' in request.data.keys() else None + engineer = request.data["engineer"] + auth_type = request.data["auth_type"] + key = request.data["key"] if "key" in request.data.keys() else None + phone = request.data["phone"] if "phone" in request.data.keys() else None user = Users.objects.get(username=engineer) authenticator = get_authenticator(user=user, auth_type=auth_type) - if auth_type == 'sms': + if auth_type == "sms": result = authenticator.save(phone) else: result = authenticator.save(key) @@ -351,46 +378,47 @@ class TwoFAVerify(views.APIView): """ 检验2fa密码 """ + permission_classes = [permissions.AllowAny] - @extend_schema(summary="检验2fa密码", - request=TwoFAVerifySerializer, - description="检验2fa密码") + @extend_schema( + summary="检验2fa密码", request=TwoFAVerifySerializer, description="检验2fa密码" + ) def post(self, request): # 参数验证 serializer = TwoFAVerifySerializer(data=request.data) if not serializer.is_valid(): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - engineer = request.data['engineer'] - otp = request.data['otp'] - key = request.data['key'] if 'key' in request.data.keys() else None - phone = request.data['phone'] if 'phone' in request.data.keys() else None + engineer = request.data["engineer"] + otp = request.data["otp"] + key = request.data["key"] if "key" in request.data.keys() else None + phone = request.data["phone"] if "phone" in request.data.keys() else None user = Users.objects.get(username=engineer) - request_user = request.session.get('user') + request_user = request.session.get("user") if not request.user.is_authenticated: if request_user: if request_user != engineer: - return Response({'status': 1, 'msg': '登录用户与校验用户不一致!'}) + return Response({"status": 1, "msg": "登录用户与校验用户不一致!"}) else: - return Response({'status': 1, 'msg': '需先校验用户密码!'}) + return Response({"status": 1, "msg": "需先校验用户密码!"}) twofa_config = TwoFactorAuthConfig.objects.filter(user=user) if not twofa_config: if not key: - return Response({'status': 1, 'msg': '用户未配置2FA!'}) + return Response({"status": 1, "msg": "用户未配置2FA!"}) - auth_type = request.data['auth_type'] + auth_type = request.data["auth_type"] authenticator = get_authenticator(user=user, auth_type=auth_type) - if auth_type == 'sms': + if auth_type == "sms": result = authenticator.verify(otp, phone) else: result = authenticator.verify(otp, key) # 校验通过后自动登录,刷新expire_date - if result['status'] == 0 and not request.user.is_authenticated: - login(request, user, backend='django.contrib.auth.backends.ModelBackend') + if result["status"] == 0 and not request.user.is_authenticated: + login(request, user, backend="django.contrib.auth.backends.ModelBackend") request.session.set_expiry(settings.SESSION_COOKIE_AGE) return Response(result) diff --git a/sql_api/api_workflow.py b/sql_api/api_workflow.py index 3575fcc6e6..7bff6b5798 100644 --- a/sql_api/api_workflow.py +++ b/sql_api/api_workflow.py @@ -1,12 +1,28 @@ from rest_framework import views, generics, status, serializers from rest_framework.response import Response from drf_spectacular.utils import extend_schema -from .serializers import WorkflowContentSerializer, ExecuteCheckSerializer, \ - ExecuteCheckResultSerializer, WorkflowAuditSerializer, WorkflowAuditListSerializer, \ - WorkflowLogSerializer, WorkflowLogListSerializer, AuditWorkflowSerializer, ExecuteWorkflowSerializer +from .serializers import ( + WorkflowContentSerializer, + ExecuteCheckSerializer, + ExecuteCheckResultSerializer, + WorkflowAuditSerializer, + WorkflowAuditListSerializer, + WorkflowLogSerializer, + WorkflowLogListSerializer, + AuditWorkflowSerializer, + ExecuteWorkflowSerializer, +) from .pagination import CustomizedPagination from .filters import WorkflowFilter, WorkflowAuditFilter -from sql.models import SqlWorkflow, SqlWorkflowContent, Instance, WorkflowAudit, Users, WorkflowLog, ArchiveConfig +from sql.models import ( + SqlWorkflow, + SqlWorkflowContent, + Instance, + WorkflowAudit, + Users, + WorkflowLog, + ArchiveConfig, +) from sql.utils.sql_review import can_cancel, can_execute, on_correct_time_period from sql.utils.resource_group import user_groups from sql.utils.workflow_audit import Audit @@ -23,24 +39,27 @@ import datetime import logging -logger = logging.getLogger('default') +logger = logging.getLogger("default") class ExecuteCheck(views.APIView): - @extend_schema(summary="SQL检查", - request=ExecuteCheckSerializer, - responses={200: ExecuteCheckResultSerializer}, - description="对提供的SQL进行语法检查") + @extend_schema( + summary="SQL检查", + request=ExecuteCheckSerializer, + responses={200: ExecuteCheckResultSerializer}, + description="对提供的SQL进行语法检查", + ) def post(self, request): # 参数验证 serializer = ExecuteCheckSerializer(data=request.data) if not serializer.is_valid(): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - instance = Instance.objects.get(pk=request.data['instance_id']) + instance = Instance.objects.get(pk=request.data["instance_id"]) check_engine = get_engine(instance=instance) - check_result = check_engine.execute_check(db_name=request.data['db_name'], - sql=request.data['full_sql'].strip()) + check_result = check_engine.execute_check( + db_name=request.data["db_name"], sql=request.data["full_sql"].strip() + ) review_result_list = [] for r in check_result.rows: review_result_list += [r.__dict__] @@ -53,28 +72,33 @@ class WorkflowList(generics.ListAPIView): """ 列出所有的workflow或者提交一条新的workflow """ + filterset_class = WorkflowFilter pagination_class = CustomizedPagination serializer_class = WorkflowContentSerializer - queryset = SqlWorkflowContent.objects.all().select_related('workflow').order_by('-id') - - @extend_schema(summary="SQL上线工单清单", - request=WorkflowContentSerializer, - responses={200: WorkflowContentSerializer}, - description="列出所有SQL上线工单(过滤,分页)") + queryset = ( + SqlWorkflowContent.objects.all().select_related("workflow").order_by("-id") + ) + + @extend_schema( + summary="SQL上线工单清单", + request=WorkflowContentSerializer, + responses={200: WorkflowContentSerializer}, + description="列出所有SQL上线工单(过滤,分页)", + ) def get(self, request): workflows = self.filter_queryset(self.queryset) page_wf = self.paginate_queryset(queryset=workflows) serializer_obj = self.get_serializer(page_wf, many=True) - data = { - 'data': serializer_obj.data - } + data = {"data": serializer_obj.data} return self.get_paginated_response(data) - @extend_schema(summary="提交SQL上线工单", - request=WorkflowContentSerializer, - responses={201: WorkflowContentSerializer}, - description="提交一条SQL上线工单") + @extend_schema( + summary="提交SQL上线工单", + request=WorkflowContentSerializer, + responses={201: WorkflowContentSerializer}, + description="提交一条SQL上线工单", + ) def post(self, request): serializer = WorkflowContentSerializer(data=request.data) if serializer.is_valid(): @@ -87,20 +111,24 @@ class WorkflowAuditList(generics.ListAPIView): """ 列出指定用户当前待自己审核的工单 """ + filterset_class = WorkflowAuditFilter pagination_class = CustomizedPagination serializer_class = WorkflowAuditListSerializer queryset = WorkflowAudit.objects.filter( - current_status=WorkflowDict.workflow_status['audit_wait']).order_by('-audit_id') + current_status=WorkflowDict.workflow_status["audit_wait"] + ).order_by("-audit_id") @extend_schema(exclude=True) def get(self, request): return Response({"detail": "方法 “GET” 不被允许。"}) - @extend_schema(summary="待审核清单", - request=WorkflowAuditSerializer, - responses={200: WorkflowAuditListSerializer}, - description="列出指定用户待审核清单(过滤,分页)") + @extend_schema( + summary="待审核清单", + request=WorkflowAuditSerializer, + responses={200: WorkflowAuditListSerializer}, + description="列出指定用户待审核清单(过滤,分页)", + ) def post(self, request): # 参数验证 serializer = WorkflowAuditSerializer(data=request.data) @@ -108,7 +136,7 @@ def post(self, request): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) # 先获取用户所在资源组列表 - user = Users.objects.get(username=request.data['engineer']) + user = Users.objects.get(username=request.data["engineer"]) group_list = user_groups(user) group_ids = [group.group_id for group in group_list] @@ -118,15 +146,15 @@ def post(self, request): else: auth_group_ids = [group.id for group in Group.objects.filter(user=user)] - self.queryset = self.queryset.filter(current_status=WorkflowDict.workflow_status['audit_wait'], - group_id__in=group_ids, - current_audit__in=auth_group_ids) + self.queryset = self.queryset.filter( + current_status=WorkflowDict.workflow_status["audit_wait"], + group_id__in=group_ids, + current_audit__in=auth_group_ids, + ) audit = self.filter_queryset(self.queryset) page_audit = self.paginate_queryset(queryset=audit) serializer_obj = self.get_serializer(page_audit, many=True) - data = { - 'data': serializer_obj.data - } + data = {"data": serializer_obj.data} return self.get_paginated_response(data) @@ -135,28 +163,28 @@ class AuditWorkflow(views.APIView): 审核workflow,包括查询权限申请、SQL上线申请、数据归档申请 """ - @extend_schema(summary="审核工单", - request=AuditWorkflowSerializer, - description="审核一条工单(通过或终止)") + @extend_schema( + summary="审核工单", request=AuditWorkflowSerializer, description="审核一条工单(通过或终止)" + ) def post(self, request): # 参数验证 serializer = AuditWorkflowSerializer(data=request.data) if not serializer.is_valid(): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - audit_type = request.data['audit_type'] - workflow_type = request.data['workflow_type'] - workflow_id = request.data['workflow_id'] - audit_remark = request.data['audit_remark'] - engineer = request.data['engineer'] + audit_type = request.data["audit_type"] + workflow_type = request.data["workflow_type"] + workflow_id = request.data["workflow_id"] + audit_remark = request.data["audit_remark"] + engineer = request.data["engineer"] user = Users.objects.get(username=engineer) # 审核查询权限申请 if workflow_type == 1: - audit_status = 1 if audit_type == 'pass' else 2 + audit_status = 1 if audit_type == "pass" else 2 if audit_remark is None: - audit_remark = '' + audit_remark = "" if Audit.can_review(user, workflow_id, workflow_type) is False: raise serializers.ValidationError({"errors": "你无权操作当前工单!"}) @@ -164,30 +192,49 @@ def post(self, request): # 使用事务保持数据一致性 try: with transaction.atomic(): - audit_id = Audit.detail_by_workflow_id(workflow_id=workflow_id, - workflow_type=WorkflowDict.workflow_type['query']).audit_id + audit_id = Audit.detail_by_workflow_id( + workflow_id=workflow_id, + workflow_type=WorkflowDict.workflow_type["query"], + ).audit_id # 调用工作流接口审核 - audit_result = Audit.audit(audit_id, audit_status, user.username, audit_remark) + audit_result = Audit.audit( + audit_id, audit_status, user.username, audit_remark + ) # 按照审核结果更新业务表审核状态 audit_detail = Audit.detail(audit_id) - if audit_detail.workflow_type == WorkflowDict.workflow_type['query']: + if ( + audit_detail.workflow_type + == WorkflowDict.workflow_type["query"] + ): # 更新业务表审核状态,插入权限信息 - _query_apply_audit_call_back(audit_detail.workflow_id, audit_result['data']['workflow_status']) + _query_apply_audit_call_back( + audit_detail.workflow_id, + audit_result["data"]["workflow_status"], + ) except Exception as msg: logger.error(traceback.format_exc()) - raise serializers.ValidationError({'errors': msg}) + raise serializers.ValidationError({"errors": msg}) else: # 消息通知 - async_task(notify_for_audit, audit_id=audit_id, audit_remark=audit_remark, timeout=60, - task_name=f'query-priv-audit-{workflow_id}') - return Response({'msg': 'passed'}) if audit_type == 'pass' else Response({'msg': 'canceled'}) + async_task( + notify_for_audit, + audit_id=audit_id, + audit_remark=audit_remark, + timeout=60, + task_name=f"query-priv-audit-{workflow_id}", + ) + return ( + Response({"msg": "passed"}) + if audit_type == "pass" + else Response({"msg": "canceled"}) + ) # 审核SQL上线申请 elif workflow_type == 2: # SQL上线申请通过 - if audit_type == 'pass': + if audit_type == "pass": # 权限验证 if Audit.can_review(user, workflow_id, workflow_type) is False: raise serializers.ValidationError({"errors": "你无权操作当前工单!"}) @@ -196,30 +243,48 @@ def post(self, request): try: with transaction.atomic(): # 调用工作流接口审核 - audit_id = Audit.detail_by_workflow_id(workflow_id=workflow_id, - workflow_type=WorkflowDict.workflow_type[ - 'sqlreview']).audit_id - audit_result = Audit.audit(audit_id, WorkflowDict.workflow_status['audit_success'], - user.username, audit_remark) + audit_id = Audit.detail_by_workflow_id( + workflow_id=workflow_id, + workflow_type=WorkflowDict.workflow_type["sqlreview"], + ).audit_id + audit_result = Audit.audit( + audit_id, + WorkflowDict.workflow_status["audit_success"], + user.username, + audit_remark, + ) # 按照审核结果更新业务表审核状态 - if audit_result['data']['workflow_status'] == WorkflowDict.workflow_status['audit_success']: + if ( + audit_result["data"]["workflow_status"] + == WorkflowDict.workflow_status["audit_success"] + ): # 将流程状态修改为审核通过 - SqlWorkflow(id=workflow_id, status='workflow_review_pass').save(update_fields=['status']) + SqlWorkflow( + id=workflow_id, status="workflow_review_pass" + ).save(update_fields=["status"]) except Exception as msg: logger.error(traceback.format_exc()) - raise serializers.ValidationError({'errors': msg}) + raise serializers.ValidationError({"errors": msg}) else: # 开启了Pass阶段通知参数才发送消息通知 sys_config = SysConfig() - is_notified = 'Pass' in sys_config.get('notify_phase_control').split(',') \ - if sys_config.get('notify_phase_control') else True + is_notified = ( + "Pass" in sys_config.get("notify_phase_control").split(",") + if sys_config.get("notify_phase_control") + else True + ) if is_notified: - async_task(notify_for_audit, audit_id=audit_id, audit_remark=audit_remark, timeout=60, - task_name=f'sqlreview-pass-{workflow_id}') - return Response({'msg': 'passed'}) + async_task( + notify_for_audit, + audit_id=audit_id, + audit_remark=audit_remark, + timeout=60, + task_name=f"sqlreview-pass-{workflow_id}", + ) + return Response({"msg": "passed"}) # SQL上线申请驳回/取消 - elif audit_type == 'cancel': + elif audit_type == "cancel": workflow_detail = SqlWorkflow.objects.get(id=workflow_id) if audit_remark is None: @@ -232,73 +297,93 @@ def post(self, request): try: with transaction.atomic(): # 调用工作流接口取消或者驳回 - audit_id = Audit.detail_by_workflow_id(workflow_id=workflow_id, - workflow_type=WorkflowDict.workflow_type[ - 'sqlreview']).audit_id + audit_id = Audit.detail_by_workflow_id( + workflow_id=workflow_id, + workflow_type=WorkflowDict.workflow_type["sqlreview"], + ).audit_id # 仅待审核的需要调用工作流,审核通过的不需要 - if workflow_detail.status != 'workflow_manreviewing': + if workflow_detail.status != "workflow_manreviewing": # 增加工单日志 if user.username == workflow_detail.engineer: - Audit.add_log(audit_id=audit_id, - operation_type=3, - operation_type_desc='取消执行', - operation_info="取消原因:{}".format(audit_remark), - operator=user.username, - operator_display=user.display - ) + Audit.add_log( + audit_id=audit_id, + operation_type=3, + operation_type_desc="取消执行", + operation_info="取消原因:{}".format(audit_remark), + operator=user.username, + operator_display=user.display, + ) else: - Audit.add_log(audit_id=audit_id, - operation_type=2, - operation_type_desc='审批不通过', - operation_info="审批备注:{}".format(audit_remark), - operator=user.username, - operator_display=user.display - ) + Audit.add_log( + audit_id=audit_id, + operation_type=2, + operation_type_desc="审批不通过", + operation_info="审批备注:{}".format(audit_remark), + operator=user.username, + operator_display=user.display, + ) else: if user.username == workflow_detail.engineer: - Audit.audit(audit_id, - WorkflowDict.workflow_status['audit_abort'], - user.username, audit_remark) + Audit.audit( + audit_id, + WorkflowDict.workflow_status["audit_abort"], + user.username, + audit_remark, + ) # 非提交人需要校验审核权限 - elif user.has_perm('sql.sql_review'): - Audit.audit(audit_id, - WorkflowDict.workflow_status['audit_reject'], - user.username, audit_remark) + elif user.has_perm("sql.sql_review"): + Audit.audit( + audit_id, + WorkflowDict.workflow_status["audit_reject"], + user.username, + audit_remark, + ) else: - raise serializers.ValidationError({"errors": "Permission Denied"}) + raise serializers.ValidationError( + {"errors": "Permission Denied"} + ) # 删除定时执行task - if workflow_detail.status == 'workflow_timingtask': + if workflow_detail.status == "workflow_timingtask": schedule_name = f"sqlreview-timing-{workflow_id}" del_schedule(schedule_name) # 将流程状态修改为人工终止流程 - workflow_detail.status = 'workflow_abort' + workflow_detail.status = "workflow_abort" workflow_detail.save() except Exception as msg: logger.error(f"取消工单报错,错误信息:{traceback.format_exc()}") - raise serializers.ValidationError({'errors': msg}) + raise serializers.ValidationError({"errors": msg}) else: # 发送取消、驳回通知,开启了Cancel阶段通知参数才发送消息通知 sys_config = SysConfig() - is_notified = 'Cancel' in sys_config.get('notify_phase_control').split(',') \ - if sys_config.get('notify_phase_control') else True + is_notified = ( + "Cancel" in sys_config.get("notify_phase_control").split(",") + if sys_config.get("notify_phase_control") + else True + ) if is_notified: - audit_detail = Audit.detail_by_workflow_id(workflow_id=workflow_id, - workflow_type=WorkflowDict.workflow_type[ - 'sqlreview']) + audit_detail = Audit.detail_by_workflow_id( + workflow_id=workflow_id, + workflow_type=WorkflowDict.workflow_type["sqlreview"], + ) if audit_detail.current_status in ( - WorkflowDict.workflow_status['audit_abort'], - WorkflowDict.workflow_status['audit_reject']): - async_task(notify_for_audit, audit_id=audit_detail.audit_id, audit_remark=audit_remark, - timeout=60, - task_name=f'sqlreview-cancel-{workflow_id}') - return Response({'msg': 'canceled'}) + WorkflowDict.workflow_status["audit_abort"], + WorkflowDict.workflow_status["audit_reject"], + ): + async_task( + notify_for_audit, + audit_id=audit_detail.audit_id, + audit_remark=audit_remark, + timeout=60, + task_name=f"sqlreview-cancel-{workflow_id}", + ) + return Response({"msg": "canceled"}) # 审核数据归档申请 elif workflow_type == 3: - audit_status = 1 if audit_type == 'pass' else 2 + audit_status = 1 if audit_type == "pass" else 2 if audit_remark is None: - audit_remark = '' + audit_remark = "" if Audit.can_review(user, workflow_id, workflow_type) is False: raise serializers.ValidationError({"errors": "你无权操作当前工单!"}) @@ -306,24 +391,39 @@ def post(self, request): # 使用事务保持数据一致性 try: with transaction.atomic(): - audit_id = Audit.detail_by_workflow_id(workflow_id=workflow_id, - workflow_type=WorkflowDict.workflow_type['archive']).audit_id + audit_id = Audit.detail_by_workflow_id( + workflow_id=workflow_id, + workflow_type=WorkflowDict.workflow_type["archive"], + ).audit_id # 调用工作流插入审核信息,更新业务表审核状态 - audit_status = Audit.audit(audit_id, audit_status, user.username, audit_remark)['data'][ - 'workflow_status'] - ArchiveConfig(id=workflow_id, - status=audit_status, - state=True if audit_status == WorkflowDict.workflow_status['audit_success'] else False - ).save(update_fields=['status', 'state']) + audit_status = Audit.audit( + audit_id, audit_status, user.username, audit_remark + )["data"]["workflow_status"] + ArchiveConfig( + id=workflow_id, + status=audit_status, + state=True + if audit_status == WorkflowDict.workflow_status["audit_success"] + else False, + ).save(update_fields=["status", "state"]) except Exception as msg: logger.error(traceback.format_exc()) - raise serializers.ValidationError({'errors': msg}) + raise serializers.ValidationError({"errors": msg}) else: # 消息通知 - async_task(notify_for_audit, audit_id=audit_id, audit_remark=audit_remark, timeout=60, - task_name=f'archive-audit-{workflow_id}') - return Response({'msg': 'passed'}) if audit_type == 'pass' else Response({'msg': 'canceled'}) + async_task( + notify_for_audit, + audit_id=audit_id, + audit_remark=audit_remark, + timeout=60, + task_name=f"archive-audit-{workflow_id}", + ) + return ( + Response({"msg": "passed"}) + if audit_type == "pass" + else Response({"msg": "canceled"}) + ) class ExecuteWorkflow(views.APIView): @@ -331,86 +431,116 @@ class ExecuteWorkflow(views.APIView): 执行workflow,包括SQL上线工单、数据归档工单 """ - @extend_schema(summary="执行工单", - request=ExecuteWorkflowSerializer, - description="执行一条工单") + @extend_schema( + summary="执行工单", request=ExecuteWorkflowSerializer, description="执行一条工单" + ) def post(self, request): # 参数验证 serializer = ExecuteWorkflowSerializer(data=request.data) if not serializer.is_valid(): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - workflow_type = request.data['workflow_type'] - workflow_id = request.data['workflow_id'] + workflow_type = request.data["workflow_type"] + workflow_id = request.data["workflow_id"] # 执行SQL上线工单 if workflow_type == 2: - mode = request.data['mode'] - engineer = request.data['engineer'] + mode = request.data["mode"] + engineer = request.data["engineer"] user = Users.objects.get(username=engineer) # 校验多个权限 - if not (user.has_perm('sql.sql_execute') or user.has_perm('sql.sql_execute_for_resource_group')): + if not ( + user.has_perm("sql.sql_execute") + or user.has_perm("sql.sql_execute_for_resource_group") + ): raise serializers.ValidationError({"errors": "你无权执行当前工单!"}) if can_execute(user, workflow_id) is False: raise serializers.ValidationError({"errors": "你无权执行当前工单!"}) if on_correct_time_period(workflow_id) is False: - raise serializers.ValidationError({"errors": "不在可执行时间范围内,如果需要修改执行时间请重新提交工单!"}) + raise serializers.ValidationError( + {"errors": "不在可执行时间范围内,如果需要修改执行时间请重新提交工单!"} + ) # 获取审核信息 - audit_id = Audit.detail_by_workflow_id(workflow_id=workflow_id, - workflow_type=WorkflowDict.workflow_type['sqlreview']).audit_id + audit_id = Audit.detail_by_workflow_id( + workflow_id=workflow_id, + workflow_type=WorkflowDict.workflow_type["sqlreview"], + ).audit_id # 交由系统执行 if mode == "auto": # 修改工单状态为排队中 - SqlWorkflow(id=workflow_id, status="workflow_queuing").save(update_fields=['status']) + SqlWorkflow(id=workflow_id, status="workflow_queuing").save( + update_fields=["status"] + ) # 删除定时执行任务 schedule_name = f"sqlreview-timing-{workflow_id}" del_schedule(schedule_name) # 加入执行队列 - async_task('sql.utils.execute_sql.execute', workflow_id, user, - hook='sql.utils.execute_sql.execute_callback', - timeout=-1, task_name=f'sqlreview-execute-{workflow_id}') + async_task( + "sql.utils.execute_sql.execute", + workflow_id, + user, + hook="sql.utils.execute_sql.execute_callback", + timeout=-1, + task_name=f"sqlreview-execute-{workflow_id}", + ) # 增加工单日志 - Audit.add_log(audit_id=audit_id, - operation_type=5, - operation_type_desc='执行工单', - operation_info='工单执行排队中', - operator=user.username, - operator_display=user.display) + Audit.add_log( + audit_id=audit_id, + operation_type=5, + operation_type_desc="执行工单", + operation_info="工单执行排队中", + operator=user.username, + operator_display=user.display, + ) # 线下手工执行 elif mode == "manual": # 将流程状态修改为执行结束 - SqlWorkflow(id=workflow_id, status="workflow_finish", finish_time=datetime.datetime.now() - ).save(update_fields=['status', 'finish_time']) + SqlWorkflow( + id=workflow_id, + status="workflow_finish", + finish_time=datetime.datetime.now(), + ).save(update_fields=["status", "finish_time"]) # 增加工单日志 - Audit.add_log(audit_id=audit_id, - operation_type=6, - operation_type_desc='手工工单', - operation_info='确认手工执行结束', - operator=user.username, - operator_display=user.display) + Audit.add_log( + audit_id=audit_id, + operation_type=6, + operation_type_desc="手工工单", + operation_info="确认手工执行结束", + operator=user.username, + operator_display=user.display, + ) # 开启了Execute阶段通知参数才发送消息通知 sys_config = SysConfig() - is_notified = 'Execute' in sys_config.get('notify_phase_control').split(',') \ - if sys_config.get('notify_phase_control') else True + is_notified = ( + "Execute" in sys_config.get("notify_phase_control").split(",") + if sys_config.get("notify_phase_control") + else True + ) if is_notified: notify_for_execute(SqlWorkflow.objects.get(id=workflow_id)) # 执行数据归档工单 elif workflow_type == 3: - async_task('sql.archiver.archive', workflow_id, timeout=-1, task_name=f'archive-{workflow_id}') + async_task( + "sql.archiver.archive", + workflow_id, + timeout=-1, + task_name=f"archive-{workflow_id}", + ) - return Response({'msg': '开始执行,执行结果请到工单详情页查看'}) + return Response({"msg": "开始执行,执行结果请到工单详情页查看"}) class WorkflowLogList(generics.ListAPIView): """ 获取某个工单的日志 """ + pagination_class = CustomizedPagination serializer_class = WorkflowLogListSerializer queryset = WorkflowLog.objects.all() @@ -419,22 +549,24 @@ class WorkflowLogList(generics.ListAPIView): def get(self, request): return Response({"detail": "方法 “GET” 不被允许。"}) - @extend_schema(summary="工单日志", - request=WorkflowLogSerializer, - responses={200: WorkflowLogListSerializer}, - description="获取某个工单的日志(分页)") + @extend_schema( + summary="工单日志", + request=WorkflowLogSerializer, + responses={200: WorkflowLogListSerializer}, + description="获取某个工单的日志(分页)", + ) def post(self, request): # 参数验证 serializer = WorkflowLogSerializer(data=request.data) if not serializer.is_valid(): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - audit_id = WorkflowAudit.objects.get(workflow_id=request.data['workflow_id'], - workflow_type=request.data['workflow_type']).audit_id - workflow_logs = self.queryset.filter(audit_id=audit_id).order_by('-id') + audit_id = WorkflowAudit.objects.get( + workflow_id=request.data["workflow_id"], + workflow_type=request.data["workflow_type"], + ).audit_id + workflow_logs = self.queryset.filter(audit_id=audit_id).order_by("-id") page_log = self.paginate_queryset(queryset=workflow_logs) serializer_obj = self.get_serializer(page_log, many=True) - data = { - 'data': serializer_obj.data - } + data = {"data": serializer_obj.data} return self.get_paginated_response(data) diff --git a/sql_api/apps.py b/sql_api/apps.py index 8cec48388f..e3bdbe1fee 100644 --- a/sql_api/apps.py +++ b/sql_api/apps.py @@ -2,4 +2,4 @@ class SqlApi2Config(AppConfig): - name = 'sql_api' + name = "sql_api" diff --git a/sql_api/filters.py b/sql_api/filters.py index db71b0f630..11e4bbc123 100644 --- a/sql_api/filters.py +++ b/sql_api/filters.py @@ -3,48 +3,44 @@ class UserFilter(filters.FilterSet): - class Meta: model = Users fields = { - 'id': ['exact'], - 'username': ['exact'], + "id": ["exact"], + "username": ["exact"], } class InstanceFilter(filters.FilterSet): - class Meta: model = Instance fields = { - 'id': ['exact'], - 'instance_name': ['icontains'], - 'db_type': ['exact'], - 'host': ['exact'], + "id": ["exact"], + "instance_name": ["icontains"], + "db_type": ["exact"], + "host": ["exact"], } class WorkflowFilter(filters.FilterSet): - class Meta: model = SqlWorkflowContent fields = { - 'id': ['exact'], - 'workflow_id': ['exact'], - 'workflow__workflow_name': ['icontains'], - 'workflow__instance_id': ['exact'], - 'workflow__db_name': ['exact'], - 'workflow__engineer': ['exact'], - 'workflow__status': ['exact'], - 'workflow__create_time': ['lt', 'gte'], + "id": ["exact"], + "workflow_id": ["exact"], + "workflow__workflow_name": ["icontains"], + "workflow__instance_id": ["exact"], + "workflow__db_name": ["exact"], + "workflow__engineer": ["exact"], + "workflow__status": ["exact"], + "workflow__create_time": ["lt", "gte"], } class WorkflowAuditFilter(filters.FilterSet): - class Meta: model = WorkflowAudit fields = { - 'workflow_title': ['icontains'], - 'workflow_type': ['exact'], + "workflow_title": ["icontains"], + "workflow_type": ["exact"], } diff --git a/sql_api/pagination.py b/sql_api/pagination.py index 510015965c..b262fce9dc 100644 --- a/sql_api/pagination.py +++ b/sql_api/pagination.py @@ -8,15 +8,24 @@ class CustomizedPagination(PageNumberPagination): """ 自定义分页器 """ - page_size = settings.REST_FRAMEWORK['PAGE_SIZE'] if 'PAGE_SIZE' in settings.REST_FRAMEWORK.keys() else 20 - page_query_param = 'page' - page_size_query_param = 'size' + + page_size = ( + settings.REST_FRAMEWORK["PAGE_SIZE"] + if "PAGE_SIZE" in settings.REST_FRAMEWORK.keys() + else 20 + ) + page_query_param = "page" + page_size_query_param = "size" max_page_size = None def get_paginated_response(self, data): - return Response(OrderedDict([ - ('count', data.get('count', self.page.paginator.count)), - ('next', self.get_next_link()), - ('previous', self.get_previous_link()), - ('results', data.get('data', data)) - ])) + return Response( + OrderedDict( + [ + ("count", data.get("count", self.page.paginator.count)), + ("next", self.get_next_link()), + ("previous", self.get_previous_link()), + ("results", data.get("data", data)), + ] + ) + ) diff --git a/sql_api/permissions.py b/sql_api/permissions.py index ff017a2d49..bff4e508b1 100644 --- a/sql_api/permissions.py +++ b/sql_api/permissions.py @@ -6,9 +6,10 @@ class IsInUserWhitelist(permissions.BasePermission): """ 自定义权限,只允许白名单用户调用api """ + def has_permission(self, request, view): - config = SysConfig().get('api_user_whitelist') - user_list = config.split(',') if config else [] + config = SysConfig().get("api_user_whitelist") + user_list = config.split(",") if config else [] api_user_whitelist = [int(uid) for uid in user_list] # 只有在api_user_whitelist参数中的用户才有权限 @@ -19,9 +20,10 @@ class IsOwner(permissions.BasePermission): """ 当参数engineer与请求用户一致时才有权限 """ + def has_permission(self, request, view): try: - engineer = request.data['engineer'] + engineer = request.data["engineer"] except KeyError as e: return False diff --git a/sql_api/serializers.py b/sql_api/serializers.py index 874ef81352..93045ad673 100644 --- a/sql_api/serializers.py +++ b/sql_api/serializers.py @@ -1,6 +1,16 @@ from rest_framework import serializers -from sql.models import Users, Instance, Tunnel, AliyunRdsConfig, CloudAccessKey, \ - SqlWorkflow, SqlWorkflowContent, ResourceGroup, WorkflowAudit, WorkflowLog +from sql.models import ( + Users, + Instance, + Tunnel, + AliyunRdsConfig, + CloudAccessKey, + SqlWorkflow, + SqlWorkflowContent, + ResourceGroup, + WorkflowAudit, + WorkflowLog, +) from django.contrib.auth.models import Group from django.contrib.auth.password_validation import validate_password from django.core.exceptions import ValidationError @@ -15,7 +25,7 @@ import traceback import logging -logger = logging.getLogger('default') +logger = logging.getLogger("default") class UserSerializer(serializers.ModelSerializer): @@ -35,20 +45,13 @@ def validate_password(self, password): class Meta: model = Users fields = "__all__" - extra_kwargs = { - 'password': { - 'write_only': True - }, - 'display': { - 'required': True - } - } + extra_kwargs = {"password": {"write_only": True}, "display": {"required": True}} class UserDetailSerializer(serializers.ModelSerializer): def update(self, instance, validated_data): for attr, value in validated_data.items(): - if attr == 'password': + if attr == "password": instance.set_password(value) else: setattr(instance, attr, value) @@ -66,64 +69,58 @@ class Meta: model = Users fields = "__all__" extra_kwargs = { - 'password': { - 'write_only': True, - 'required': False - }, - 'username': { - 'required': False - } + "password": {"write_only": True, "required": False}, + "username": {"required": False}, } class GroupSerializer(serializers.ModelSerializer): - class Meta: model = Group fields = "__all__" class ResourceGroupSerializer(serializers.ModelSerializer): - class Meta: model = ResourceGroup fields = "__all__" class UserAuthSerializer(serializers.Serializer): - engineer = serializers.CharField(label='用户名') - password = serializers.CharField(label='密码') + engineer = serializers.CharField(label="用户名") + password = serializers.CharField(label="密码") class TwoFASerializer(serializers.Serializer): - engineer = serializers.CharField(label='用户名') - enable = serializers.ChoiceField(choices=['true', 'false'], label='启用or禁用') - phone = serializers.CharField(required=False, label='手机号码') - auth_type = serializers.ChoiceField(choices=['totp', 'sms'], - label='验证类型:totp-Google身份验证器,sms-短信验证码') + engineer = serializers.CharField(label="用户名") + enable = serializers.ChoiceField(choices=["true", "false"], label="启用or禁用") + phone = serializers.CharField(required=False, label="手机号码") + auth_type = serializers.ChoiceField( + choices=["totp", "sms"], label="验证类型:totp-Google身份验证器,sms-短信验证码" + ) def validate(self, attrs): - auth_type = attrs.get('auth_type') - engineer = attrs.get('engineer') - enable = attrs.get('enable') + auth_type = attrs.get("auth_type") + engineer = attrs.get("engineer") + enable = attrs.get("enable") try: Users.objects.get(username=engineer) except Users.DoesNotExist: raise serializers.ValidationError({"errors": "不存在该用户"}) - if auth_type == 'sms' and enable == 'true': - if not attrs.get('phone'): + if auth_type == "sms" and enable == "true": + if not attrs.get("phone"): raise serializers.ValidationError({"errors": "缺少 phone"}) return attrs class TwoFAStateSerializer(serializers.Serializer): - engineer = serializers.CharField(label='用户名') + engineer = serializers.CharField(label="用户名") def validate(self, attrs): - engineer = attrs.get('engineer') + engineer = attrs.get("engineer") try: Users.objects.get(username=engineer) @@ -134,23 +131,25 @@ def validate(self, attrs): class TwoFASaveSerializer(serializers.Serializer): - engineer = serializers.CharField(label='用户名') - key = serializers.CharField(required=False, label='密钥') - phone = serializers.CharField(required=False, label='手机号码') - auth_type = serializers.ChoiceField(choices=['disabled', 'totp', 'sms'], - label='验证类型:disabled-关闭,totp-Google身份验证器,sms-短信验证码') + engineer = serializers.CharField(label="用户名") + key = serializers.CharField(required=False, label="密钥") + phone = serializers.CharField(required=False, label="手机号码") + auth_type = serializers.ChoiceField( + choices=["disabled", "totp", "sms"], + label="验证类型:disabled-关闭,totp-Google身份验证器,sms-短信验证码", + ) def validate(self, attrs): - engineer = attrs.get('engineer') - auth_type = attrs.get('auth_type') - key = attrs.get('key') - phone = attrs.get('phone') + engineer = attrs.get("engineer") + auth_type = attrs.get("auth_type") + key = attrs.get("key") + phone = attrs.get("phone") - if auth_type == 'sms': + if auth_type == "sms": if not phone: raise serializers.ValidationError({"errors": "缺少 phone"}) - if auth_type == 'totp': + if auth_type == "totp": if not key: raise serializers.ValidationError({"errors": "缺少 key"}) @@ -163,18 +162,18 @@ def validate(self, attrs): class TwoFAVerifySerializer(serializers.Serializer): - engineer = serializers.CharField(label='用户名') - otp = serializers.IntegerField(label='一次性密码/验证码') - key = serializers.CharField(required=False, label='密钥') - phone = serializers.CharField(required=False, label='手机号码') - auth_type = serializers.CharField(label='验证方式') + engineer = serializers.CharField(label="用户名") + otp = serializers.IntegerField(label="一次性密码/验证码") + key = serializers.CharField(required=False, label="密钥") + phone = serializers.CharField(required=False, label="手机号码") + auth_type = serializers.CharField(label="验证方式") def validate(self, attrs): - engineer = attrs.get('engineer') - auth_type = attrs.get('auth_type') + engineer = attrs.get("engineer") + auth_type = attrs.get("auth_type") - if auth_type == 'sms': - if not attrs.get('phone'): + if auth_type == "sms": + if not attrs.get("phone"): raise serializers.ValidationError({"errors": "缺少 phone"}) try: @@ -186,51 +185,33 @@ def validate(self, attrs): class InstanceSerializer(serializers.ModelSerializer): - class Meta: model = Instance fields = "__all__" - extra_kwargs = { - 'password': { - 'write_only': True - } - } + extra_kwargs = {"password": {"write_only": True}} class InstanceDetailSerializer(serializers.ModelSerializer): - class Meta: model = Instance fields = "__all__" extra_kwargs = { - 'password': { - 'write_only': True - }, - 'instance_name': { - 'required': False - }, - 'type': { - 'required': False - }, - 'db_type': { - 'required': False - }, - 'host': { - 'required': False - } + "password": {"write_only": True}, + "instance_name": {"required": False}, + "type": {"required": False}, + "db_type": {"required": False}, + "host": {"required": False}, } class TunnelSerializer(serializers.ModelSerializer): - class Meta: model = Tunnel fields = "__all__" - write_only_fields = ['password', 'pkey', 'pkey_password'] + write_only_fields = ["password", "pkey", "pkey_password"] class CloudAccessKeySerializer(serializers.ModelSerializer): - class Meta: model = CloudAccessKey fields = "__all__" @@ -241,7 +222,7 @@ class AliyunRdsSerializer(serializers.ModelSerializer): def create(self, validated_data): """创建包含accesskey的aliyunrds实例""" - rds_data = validated_data.pop('ak') + rds_data = validated_data.pop("ak") try: with transaction.atomic(): @@ -249,24 +230,26 @@ def create(self, validated_data): rds = AliyunRdsConfig.objects.create(ak=ak, **validated_data) except Exception as e: logger.error(f"创建AliyunRds报错,错误信息:{traceback.format_exc()}") - raise serializers.ValidationError({'errors': str(e)}) + raise serializers.ValidationError({"errors": str(e)}) else: return rds class Meta: model = AliyunRdsConfig - fields = ('id', 'rds_dbinstanceid', 'is_enable', 'instance', 'ak') + fields = ("id", "rds_dbinstanceid", "is_enable", "instance", "ak") class InstanceResourceSerializer(serializers.Serializer): - instance_id = serializers.IntegerField(label='实例id') - resource_type = serializers.ChoiceField(choices=['database', 'schema', 'table', 'column'], label='资源类型') - db_name = serializers.CharField(required=False, label='数据库名') - schema_name = serializers.CharField(required=False, label='schema名') - tb_name = serializers.CharField(required=False, label='表名') + instance_id = serializers.IntegerField(label="实例id") + resource_type = serializers.ChoiceField( + choices=["database", "schema", "table", "column"], label="资源类型" + ) + db_name = serializers.CharField(required=False, label="数据库名") + schema_name = serializers.CharField(required=False, label="schema名") + tb_name = serializers.CharField(required=False, label="表名") def validate(self, attrs): - instance_id = attrs.get('instance_id') + instance_id = attrs.get("instance_id") try: Instance.objects.get(id=instance_id) @@ -282,9 +265,9 @@ class InstanceResourceListSerializer(serializers.Serializer): class ExecuteCheckSerializer(serializers.Serializer): - instance_id = serializers.IntegerField(label='实例id') - db_name = serializers.CharField(label='数据库名') - full_sql = serializers.CharField(label='SQL内容') + instance_id = serializers.IntegerField(label="实例id") + db_name = serializers.CharField(label="数据库名") + full_sql = serializers.CharField(label="SQL内容") def validate_instance_id(self, instance_id): try: @@ -311,8 +294,8 @@ class ExecuteCheckResultSerializer(serializers.Serializer): class WorkflowSerializer(serializers.ModelSerializer): def validate(self, attrs): - engineer = attrs.get('engineer') - group_id = attrs.get('group_id') + engineer = attrs.get("engineer") + group_id = attrs.get("group_id") try: Users.objects.get(username=engineer) @@ -329,13 +312,17 @@ def validate(self, attrs): class Meta: model = SqlWorkflow fields = "__all__" - read_only_fields = ['status', 'is_backup', 'syntax_type', 'audit_auth_groups', 'engineer_display', - 'group_name', 'finish_time', 'is_manual'] - extra_kwargs = { - 'demand_url': { - 'required': False - } - } + read_only_fields = [ + "status", + "is_backup", + "syntax_type", + "audit_auth_groups", + "engineer_display", + "group_name", + "finish_time", + "is_manual", + ] + extra_kwargs = {"demand_url": {"required": False}} class WorkflowContentSerializer(serializers.ModelSerializer): @@ -343,90 +330,119 @@ class WorkflowContentSerializer(serializers.ModelSerializer): def create(self, validated_data): """使用原工单submit流程创建工单""" - workflow_data = validated_data.pop('workflow') - instance = workflow_data['instance'] - sql_content = validated_data['sql_content'].strip() - user = Users.objects.get(username=workflow_data['engineer']) - group = ResourceGroup.objects.get(pk=workflow_data['group_id']) + workflow_data = validated_data.pop("workflow") + instance = workflow_data["instance"] + sql_content = validated_data["sql_content"].strip() + user = Users.objects.get(username=workflow_data["engineer"]) + group = ResourceGroup.objects.get(pk=workflow_data["group_id"]) active_user = Users.objects.filter(is_active=1) # 验证组权限(用户是否在该组、该组是否有指定实例) try: - user_instances(user, tag_codes=['can_write']).get(id=instance.id) + user_instances(user, tag_codes=["can_write"]).get(id=instance.id) except instance.DoesNotExist: - raise serializers.ValidationError({'errors': '你所在组未关联该实例!'}) + raise serializers.ValidationError({"errors": "你所在组未关联该实例!"}) # 再次交给engine进行检测,防止绕过 try: check_engine = get_engine(instance=instance) - check_result = check_engine.execute_check(db_name=workflow_data['db_name'], - sql=sql_content) + check_result = check_engine.execute_check( + db_name=workflow_data["db_name"], sql=sql_content + ) except Exception as e: - raise serializers.ValidationError({'errors': str(e)}) + raise serializers.ValidationError({"errors": str(e)}) # 未开启备份选项,并且engine支持备份,强制设置备份 - is_backup = workflow_data['is_backup'] if 'is_backup' in workflow_data.keys() else False + is_backup = ( + workflow_data["is_backup"] if "is_backup" in workflow_data.keys() else False + ) sys_config = SysConfig() - if not sys_config.get('enable_backup_switch') and check_engine.auto_backup: + if not sys_config.get("enable_backup_switch") and check_engine.auto_backup: is_backup = True # 按照系统配置确定是自动驳回还是放行 - auto_review_wrong = sys_config.get('auto_review_wrong', '') # 1表示出现警告就驳回,2和空表示出现错误才驳回 - workflow_status = 'workflow_manreviewing' - if check_result.warning_count > 0 and auto_review_wrong == '1': - workflow_status = 'workflow_autoreviewwrong' - elif check_result.error_count > 0 and auto_review_wrong in ('', '1', '2'): - workflow_status = 'workflow_autoreviewwrong' - - workflow_data.update(status=workflow_status, - is_backup=is_backup, - is_manual=0, - syntax_type=check_result.syntax_type, - engineer_display=user.display, - group_name=group.group_name, - audit_auth_groups=Audit.settings(workflow_data['group_id'], - WorkflowDict.workflow_type['sqlreview'])) + auto_review_wrong = sys_config.get( + "auto_review_wrong", "" + ) # 1表示出现警告就驳回,2和空表示出现错误才驳回 + workflow_status = "workflow_manreviewing" + if check_result.warning_count > 0 and auto_review_wrong == "1": + workflow_status = "workflow_autoreviewwrong" + elif check_result.error_count > 0 and auto_review_wrong in ("", "1", "2"): + workflow_status = "workflow_autoreviewwrong" + + workflow_data.update( + status=workflow_status, + is_backup=is_backup, + is_manual=0, + syntax_type=check_result.syntax_type, + engineer_display=user.display, + group_name=group.group_name, + audit_auth_groups=Audit.settings( + workflow_data["group_id"], WorkflowDict.workflow_type["sqlreview"] + ), + ) try: with transaction.atomic(): workflow = SqlWorkflow.objects.create(**workflow_data) - validated_data['review_content'] = check_result.json() - workflow_content = SqlWorkflowContent.objects.create(workflow=workflow, **validated_data) + validated_data["review_content"] = check_result.json() + workflow_content = SqlWorkflowContent.objects.create( + workflow=workflow, **validated_data + ) # 自动审核通过了,才调用工作流 - if workflow_status == 'workflow_manreviewing': + if workflow_status == "workflow_manreviewing": # 调用工作流插入审核信息, SQL上线权限申请workflow_type=2 - Audit.add(WorkflowDict.workflow_type['sqlreview'], workflow.id) + Audit.add(WorkflowDict.workflow_type["sqlreview"], workflow.id) except Exception as e: logger.error(f"提交工单报错,错误信息:{traceback.format_exc()}") - raise serializers.ValidationError({'errors': str(e)}) + raise serializers.ValidationError({"errors": str(e)}) else: # 自动审核通过且开启了Apply阶段通知参数才发送消息通知 - is_notified = 'Apply' in sys_config.get('notify_phase_control').split(',') \ - if sys_config.get('notify_phase_control') else True - if workflow_status == 'workflow_manreviewing' and is_notified: + is_notified = ( + "Apply" in sys_config.get("notify_phase_control").split(",") + if sys_config.get("notify_phase_control") + else True + ) + if workflow_status == "workflow_manreviewing" and is_notified: # 获取审核信息 - audit_id = Audit.detail_by_workflow_id(workflow_id=workflow.id, - workflow_type=WorkflowDict.workflow_type['sqlreview']).audit_id - async_task(notify_for_audit, audit_id=audit_id, cc_users=active_user, timeout=60, - task_name=f'sqlreview-submit-{workflow.id}') + audit_id = Audit.detail_by_workflow_id( + workflow_id=workflow.id, + workflow_type=WorkflowDict.workflow_type["sqlreview"], + ).audit_id + async_task( + notify_for_audit, + audit_id=audit_id, + cc_users=active_user, + timeout=60, + task_name=f"sqlreview-submit-{workflow.id}", + ) return workflow_content class Meta: model = SqlWorkflowContent - fields = ('id', 'workflow_id', 'workflow', 'sql_content', 'review_content', 'execute_result') - read_only_fields = ['review_content', 'execute_result'] + fields = ( + "id", + "workflow_id", + "workflow", + "sql_content", + "review_content", + "execute_result", + ) + read_only_fields = ["review_content", "execute_result"] class AuditWorkflowSerializer(serializers.Serializer): - engineer = serializers.CharField(label='操作用户') - workflow_id = serializers.IntegerField(label='工单id') - audit_remark = serializers.CharField(label='审批备注') - workflow_type = serializers.ChoiceField(choices=[1, 2, 3], label='工单类型:1-查询权限申请,2-SQL上线申请,3-数据归档申请') - audit_type = serializers.ChoiceField(choices=['pass', 'cancel'], label='审核类型') + engineer = serializers.CharField(label="操作用户") + workflow_id = serializers.IntegerField(label="工单id") + audit_remark = serializers.CharField(label="审批备注") + workflow_type = serializers.ChoiceField( + choices=[1, 2, 3], label="工单类型:1-查询权限申请,2-SQL上线申请,3-数据归档申请" + ) + audit_type = serializers.ChoiceField(choices=["pass", "cancel"], label="审核类型") def validate(self, attrs): - engineer = attrs.get('engineer') - workflow_id = attrs.get('workflow_id') - workflow_type = attrs.get('workflow_type') + engineer = attrs.get("engineer") + workflow_id = attrs.get("workflow_id") + workflow_type = attrs.get("workflow_type") try: Users.objects.get(username=engineer) @@ -434,7 +450,9 @@ def validate(self, attrs): raise serializers.ValidationError({"errors": f"不存在该用户:{engineer}"}) try: - WorkflowAudit.objects.get(workflow_id=workflow_id, workflow_type=workflow_type) + WorkflowAudit.objects.get( + workflow_id=workflow_id, workflow_type=workflow_type + ) except WorkflowAudit.DoesNotExist: raise serializers.ValidationError({"errors": "不存在该工单"}) @@ -442,7 +460,7 @@ def validate(self, attrs): class WorkflowAuditSerializer(serializers.Serializer): - engineer = serializers.CharField(label='操作用户') + engineer = serializers.CharField(label="操作用户") def validate_engineer(self, engineer): try: @@ -453,35 +471,51 @@ def validate_engineer(self, engineer): class WorkflowAuditListSerializer(serializers.ModelSerializer): - class Meta: model = WorkflowAudit - exclude = ['group_id', 'workflow_id', 'workflow_remark', 'next_audit', 'create_user', 'sys_time'] + exclude = [ + "group_id", + "workflow_id", + "workflow_remark", + "next_audit", + "create_user", + "sys_time", + ] class WorkflowLogSerializer(serializers.Serializer): - workflow_id = serializers.IntegerField(label='工单id') - workflow_type = serializers.ChoiceField(choices=[1, 2, 3], label='工单类型:1-查询权限申请,2-SQL上线申请,3-数据归档申请') + workflow_id = serializers.IntegerField(label="工单id") + workflow_type = serializers.ChoiceField( + choices=[1, 2, 3], label="工单类型:1-查询权限申请,2-SQL上线申请,3-数据归档申请" + ) class WorkflowLogListSerializer(serializers.ModelSerializer): - class Meta: model = WorkflowLog - fields = ['operation_type_desc', 'operation_info', 'operator_display', 'operation_time'] + fields = [ + "operation_type_desc", + "operation_info", + "operator_display", + "operation_time", + ] class ExecuteWorkflowSerializer(serializers.Serializer): - engineer = serializers.CharField(required=False, label='操作用户') - workflow_id = serializers.IntegerField(label='工单id') - workflow_type = serializers.ChoiceField(choices=[2, 3], label='工单类型:1-查询权限申请,2-SQL上线申请,3-数据归档申请') - mode = serializers.ChoiceField(choices=['auto', 'manual'], label='执行模式:auto-线上执行,manual-已手动执行', required=False) + engineer = serializers.CharField(required=False, label="操作用户") + workflow_id = serializers.IntegerField(label="工单id") + workflow_type = serializers.ChoiceField( + choices=[2, 3], label="工单类型:1-查询权限申请,2-SQL上线申请,3-数据归档申请" + ) + mode = serializers.ChoiceField( + choices=["auto", "manual"], label="执行模式:auto-线上执行,manual-已手动执行", required=False + ) def validate(self, attrs): - engineer = attrs.get('engineer') - workflow_id = attrs.get('workflow_id') - workflow_type = attrs.get('workflow_type') - mode = attrs.get('mode') + engineer = attrs.get("engineer") + workflow_id = attrs.get("workflow_id") + workflow_type = attrs.get("workflow_type") + mode = attrs.get("mode") # SQL上线工单的mode和engineer为必需字段 if workflow_type == 2: @@ -496,7 +530,9 @@ def validate(self, attrs): raise serializers.ValidationError({"errors": f"不存在该用户:{engineer}"}) try: - WorkflowAudit.objects.get(workflow_id=workflow_id, workflow_type=workflow_type) + WorkflowAudit.objects.get( + workflow_id=workflow_id, workflow_type=workflow_type + ) except WorkflowAudit.DoesNotExist: raise serializers.ValidationError({"errors": "不存在该工单"}) diff --git a/sql_api/tests.py b/sql_api/tests.py index 80b6ebe870..4c93e6c31a 100644 --- a/sql_api/tests.py +++ b/sql_api/tests.py @@ -5,9 +5,20 @@ from rest_framework.test import APITestCase from rest_framework import status from common.config import SysConfig -from sql.models import ResourceGroup, Instance, AliyunRdsConfig, CloudAccessKey, Tunnel, \ - SqlWorkflow, SqlWorkflowContent, WorkflowAudit, WorkflowLog, InstanceTag, WorkflowAuditSetting, \ - TwoFactorAuthConfig +from sql.models import ( + ResourceGroup, + Instance, + AliyunRdsConfig, + CloudAccessKey, + Tunnel, + SqlWorkflow, + SqlWorkflowContent, + WorkflowAudit, + WorkflowLog, + InstanceTag, + WorkflowAuditSetting, + TwoFactorAuthConfig, +) import json User = get_user_model() @@ -15,36 +26,40 @@ class InfoTest(TestCase): def setUp(self) -> None: - self.superuser = User.objects.create(username='super', is_superuser=True) + self.superuser = User.objects.create(username="super", is_superuser=True) self.client.force_login(self.superuser) def tearDown(self) -> None: self.superuser.delete() def test_info_api(self): - r = self.client.get('/api/info') + r = self.client.get("/api/info") r_json = r.json() - self.assertIsInstance(r_json['archery']['version'], str) + self.assertIsInstance(r_json["archery"]["version"], str) def test_debug_api(self): - r = self.client.get('/api/debug') + r = self.client.get("/api/debug") r_json = r.json() - self.assertIsInstance(r_json['archery']['version'], str) + self.assertIsInstance(r_json["archery"]["version"], str) class TestUser(APITestCase): """测试用户相关接口""" def setUp(self): - self.user = User(username='test_user', display='测试用户', is_active=True) - self.user.set_password('test_password') + self.user = User(username="test_user", display="测试用户", is_active=True) + self.user.set_password("test_password") self.user.save() - self.group = Group.objects.create(id=1, name='DBA') - self.res_group = ResourceGroup.objects.create(group_id=1, group_name='test') - r = self.client.post('/api/auth/token/', {'username': 'test_user', 'password': 'test_password'}, format='json') - self.token = r.data['access'] - self.client.credentials(HTTP_AUTHORIZATION='Bearer ' + self.token) - SysConfig().set('api_user_whitelist', self.user.id) + self.group = Group.objects.create(id=1, name="DBA") + self.res_group = ResourceGroup.objects.create(group_id=1, group_name="test") + r = self.client.post( + "/api/auth/token/", + {"username": "test_user", "password": "test_password"}, + format="json", + ) + self.token = r.data["access"] + self.client.credentials(HTTP_AUTHORIZATION="Bearer " + self.token) + SysConfig().set("api_user_whitelist", self.user.id) def tearDown(self): self.user.delete() @@ -54,131 +69,124 @@ def tearDown(self): def test_user_not_in_whitelist(self): """测试api用户白名单参数""" - SysConfig().set('api_user_whitelist', '') - r = self.client.get('/api/v1/user/', format='json') + SysConfig().set("api_user_whitelist", "") + r = self.client.get("/api/v1/user/", format="json") self.assertEqual(r.status_code, status.HTTP_403_FORBIDDEN) - self.assertDictEqual(r.json(), {'detail': '您没有执行该操作的权限。'}) + self.assertDictEqual(r.json(), {"detail": "您没有执行该操作的权限。"}) def test_get_user_list(self): """测试获取用户清单""" - r = self.client.get('/api/v1/user/', format='json') + r = self.client.get("/api/v1/user/", format="json") self.assertEqual(r.status_code, status.HTTP_200_OK) - self.assertEqual(r.json()['count'], 1) + self.assertEqual(r.json()["count"], 1) def test_create_user(self): """测试创建用户""" json_data = { - 'username': 'test_user2', - 'password': 'test_password2', - 'display': '测试用户2' + "username": "test_user2", + "password": "test_password2", + "display": "测试用户2", } - r = self.client.post('/api/v1/user/', json_data, format='json') + r = self.client.post("/api/v1/user/", json_data, format="json") self.assertEqual(r.status_code, status.HTTP_201_CREATED) - self.assertEqual(r.json()['username'], 'test_user2') + self.assertEqual(r.json()["username"], "test_user2") def test_update_user(self): """测试更新用户""" - json_data = { - 'display': '更新中文名' - } - r = self.client.put(f'/api/v1/user/{self.user.id}/', json_data, format='json') + json_data = {"display": "更新中文名"} + r = self.client.put(f"/api/v1/user/{self.user.id}/", json_data, format="json") user = User.objects.get(pk=self.user.id) self.assertEqual(r.status_code, status.HTTP_200_OK) - self.assertEqual(user.display, '更新中文名') + self.assertEqual(user.display, "更新中文名") def test_delete_user(self): """测试删除用户""" json_data = { - 'username': 'test_user2', - 'password': 'test_password2', - 'display': '测试用户2' + "username": "test_user2", + "password": "test_password2", + "display": "测试用户2", } - r1 = self.client.post('/api/v1/user/', json_data, format='json') - r2 = self.client.delete(f'/api/v1/user/{r1.json()["id"]}/', format='json') + r1 = self.client.post("/api/v1/user/", json_data, format="json") + r2 = self.client.delete(f'/api/v1/user/{r1.json()["id"]}/', format="json") self.assertEqual(r2.status_code, status.HTTP_204_NO_CONTENT) - self.assertEqual(User.objects.filter(username='test_user2').count(), 0) + self.assertEqual(User.objects.filter(username="test_user2").count(), 0) def test_get_user_group_list(self): """测试获取用户组清单""" - r = self.client.get('/api/v1/user/group/', format='json') + r = self.client.get("/api/v1/user/group/", format="json") self.assertEqual(r.status_code, status.HTTP_200_OK) - self.assertEqual(r.json()['count'], 1) + self.assertEqual(r.json()["count"], 1) def test_create_user_group(self): """测试创建用户组""" - json_data = { - 'name': 'RD' - } - r = self.client.post('/api/v1/user/group/', json_data, format='json') + json_data = {"name": "RD"} + r = self.client.post("/api/v1/user/group/", json_data, format="json") self.assertEqual(r.status_code, status.HTTP_201_CREATED) - self.assertEqual(r.json()['name'], 'RD') + self.assertEqual(r.json()["name"], "RD") def test_update_user_group(self): """测试更新用户组""" - json_data = { - 'name': '更新用户组名称' - } - r = self.client.put(f'/api/v1/user/group/{self.group.id}/', json_data, format='json') + json_data = {"name": "更新用户组名称"} + r = self.client.put( + f"/api/v1/user/group/{self.group.id}/", json_data, format="json" + ) group = Group.objects.get(pk=self.group.id) self.assertEqual(r.status_code, status.HTTP_200_OK) - self.assertEqual(group.name, '更新用户组名称') + self.assertEqual(group.name, "更新用户组名称") def test_delete_user_group(self): """测试删除用户组""" - r = self.client.delete(f'/api/v1/user/group/{self.group.id}/', format='json') + r = self.client.delete(f"/api/v1/user/group/{self.group.id}/", format="json") self.assertEqual(r.status_code, status.HTTP_204_NO_CONTENT) - self.assertEqual(Group.objects.filter(name='DBA').count(), 0) + self.assertEqual(Group.objects.filter(name="DBA").count(), 0) def test_get_resource_group_list(self): """测试获取资源组清单""" - r = self.client.get('/api/v1/user/resourcegroup/', format='json') + r = self.client.get("/api/v1/user/resourcegroup/", format="json") self.assertEqual(r.status_code, status.HTTP_200_OK) - self.assertEqual(r.json()['count'], 1) + self.assertEqual(r.json()["count"], 1) def test_create_resource_group(self): """测试创建资源组""" json_data = { - 'group_name': 'prod', - 'ding_webhook': 'https://oapi.dingtalk.com/robot/send?access_token=123' + "group_name": "prod", + "ding_webhook": "https://oapi.dingtalk.com/robot/send?access_token=123", } - r = self.client.post('/api/v1/user/resourcegroup/', json_data, format='json') + r = self.client.post("/api/v1/user/resourcegroup/", json_data, format="json") self.assertEqual(r.status_code, status.HTTP_201_CREATED) - self.assertEqual(r.json()['group_name'], 'prod') + self.assertEqual(r.json()["group_name"], "prod") def test_update_resource_group(self): """测试更新资源组""" - json_data = { - 'group_name': '更新资源组名称' - } - r = self.client.put(f'/api/v1/user/resourcegroup/{self.res_group.group_id}/', json_data, format='json') + json_data = {"group_name": "更新资源组名称"} + r = self.client.put( + f"/api/v1/user/resourcegroup/{self.res_group.group_id}/", + json_data, + format="json", + ) group = ResourceGroup.objects.get(pk=self.res_group.group_id) self.assertEqual(r.status_code, status.HTTP_200_OK) - self.assertEqual(group.group_name, '更新资源组名称') + self.assertEqual(group.group_name, "更新资源组名称") def test_delete_resource_group(self): """测试删除资源组""" - r = self.client.delete(f'/api/v1/user/resourcegroup/{self.res_group.group_id}/', format='json') + r = self.client.delete( + f"/api/v1/user/resourcegroup/{self.res_group.group_id}/", format="json" + ) self.assertEqual(r.status_code, status.HTTP_204_NO_CONTENT) - self.assertEqual(Group.objects.filter(name='test').count(), 0) + self.assertEqual(Group.objects.filter(name="test").count(), 0) def test_user_auth(self): """测试用户认证校验""" - json_data = { - "engineer": "test_user", - "password": "test_password" - } - r = self.client.post(f'/api/v1/user/auth/', json_data, format='json') + json_data = {"engineer": "test_user", "password": "test_password"} + r = self.client.post(f"/api/v1/user/auth/", json_data, format="json") self.assertEqual(r.status_code, status.HTTP_200_OK) - self.assertEqual(r.json(), {'status': 0, 'msg': '认证成功'}) + self.assertEqual(r.json(), {"status": 0, "msg": "认证成功"}) def test_2fa_config(self): """测试用户配置2fa""" - json_data = { - "engineer": "test_user", - "auth_type": "totp", - "enable": "false" - } - r = self.client.post(f'/api/v1/user/2fa/', json_data, format='json') + json_data = {"engineer": "test_user", "auth_type": "totp", "enable": "false"} + r = self.client.post(f"/api/v1/user/2fa/", json_data, format="json") self.assertEqual(r.status_code, status.HTTP_200_OK) self.assertEqual(TwoFactorAuthConfig.objects.count(), 0) @@ -187,9 +195,9 @@ def test_2fa_save(self): json_data = { "engineer": "test_user", "auth_type": "totp", - "key": "ZUGRIJZP6H7LIOAL4LH5JA4GSXXT3WOK" + "key": "ZUGRIJZP6H7LIOAL4LH5JA4GSXXT3WOK", } - r = self.client.post(f'/api/v1/user/2fa/save/', json_data, format='json') + r = self.client.post(f"/api/v1/user/2fa/save/", json_data, format="json") self.assertEqual(r.status_code, status.HTTP_200_OK) self.assertEqual(TwoFactorAuthConfig.objects.count(), 1) @@ -199,29 +207,46 @@ def test_2fa_verify(self): "engineer": "test_user", "otp": 123456, "key": "ZUGRIJZP6H7LIOAL4LH5JA4GSXXT3WOK", - "auth_type": "totp" + "auth_type": "totp", } - r = self.client.post(f'/api/v1/user/2fa/verify/', json_data, format='json') + r = self.client.post(f"/api/v1/user/2fa/verify/", json_data, format="json") self.assertEqual(r.status_code, status.HTTP_200_OK) - self.assertEqual(r.json()['status'], 1) + self.assertEqual(r.json()["status"], 1) class TestInstance(APITestCase): """测试实例相关接口""" def setUp(self): - self.user = User(username='test_user', display='测试用户', is_active=True) - self.user.set_password('test_password') + self.user = User(username="test_user", display="测试用户", is_active=True) + self.user.set_password("test_password") self.user.save() - self.ins = Instance.objects.create(instance_name='some_ins', type='slave', db_type='mysql', - host='some_host', port=3306, user='ins_user', password='some_str') - self.ak = CloudAccessKey.objects.create(type='aliyun', key_id='abc', key_secret='abc') - self.rds = AliyunRdsConfig.objects.create(rds_dbinstanceid='abc', ak_id=self.ak.id, instance=self.ins) - self.tunnel = Tunnel.objects.create(tunnel_name='one_tunnel', host='one_host', port=22) - r = self.client.post('/api/auth/token/', {'username': 'test_user', 'password': 'test_password'}, format='json') - self.token = r.data['access'] - self.client.credentials(HTTP_AUTHORIZATION='Bearer ' + self.token) - SysConfig().set('api_user_whitelist', self.user.id) + self.ins = Instance.objects.create( + instance_name="some_ins", + type="slave", + db_type="mysql", + host="some_host", + port=3306, + user="ins_user", + password="some_str", + ) + self.ak = CloudAccessKey.objects.create( + type="aliyun", key_id="abc", key_secret="abc" + ) + self.rds = AliyunRdsConfig.objects.create( + rds_dbinstanceid="abc", ak_id=self.ak.id, instance=self.ins + ) + self.tunnel = Tunnel.objects.create( + tunnel_name="one_tunnel", host="one_host", port=22 + ) + r = self.client.post( + "/api/auth/token/", + {"username": "test_user", "password": "test_password"}, + format="json", + ) + self.token = r.data["access"] + self.client.credentials(HTTP_AUTHORIZATION="Bearer " + self.token) + SysConfig().set("api_user_whitelist", self.user.id) def tearDown(self): self.user.delete() @@ -233,49 +258,54 @@ def tearDown(self): def test_get_instance_list(self): """测试获取实例清单""" - r = self.client.get('/api/v1/instance/', format='json') + r = self.client.get("/api/v1/instance/", format="json") self.assertEqual(r.status_code, status.HTTP_200_OK) - self.assertEqual(r.json()['count'], 1) + self.assertEqual(r.json()["count"], 1) def test_create_instance(self): """测试创建实例""" json_data = { - 'instance_name': 'test_ins', - 'type': 'master', - 'db_type': 'mysql', - 'host': 'some_host', - 'port': 3306 + "instance_name": "test_ins", + "type": "master", + "db_type": "mysql", + "host": "some_host", + "port": 3306, } - r = self.client.post('/api/v1/instance/', json_data, format='json') + r = self.client.post("/api/v1/instance/", json_data, format="json") self.assertEqual(r.status_code, status.HTTP_201_CREATED) - self.assertEqual(r.json()['instance_name'], 'test_ins') + self.assertEqual(r.json()["instance_name"], "test_ins") def test_update_instance(self): """测试更新实例""" - json_data = { - 'instance_name': '更新实例名称' - } - r = self.client.put(f'/api/v1/instance/{self.ins.id}/', json_data, format='json') + json_data = {"instance_name": "更新实例名称"} + r = self.client.put( + f"/api/v1/instance/{self.ins.id}/", json_data, format="json" + ) ins = Instance.objects.get(pk=self.ins.id) self.assertEqual(r.status_code, status.HTTP_200_OK) - self.assertEqual(ins.instance_name, '更新实例名称') + self.assertEqual(ins.instance_name, "更新实例名称") def test_delete_instance(self): """测试删除实例""" - r = self.client.delete(f'/api/v1/instance/{self.ins.id}/', format='json') + r = self.client.delete(f"/api/v1/instance/{self.ins.id}/", format="json") self.assertEqual(r.status_code, status.HTTP_204_NO_CONTENT) - self.assertEqual(Instance.objects.filter(instance_name='some_ins').count(), 0) + self.assertEqual(Instance.objects.filter(instance_name="some_ins").count(), 0) def test_get_aliyunrds_list(self): """测试获取aliyunrds清单""" - r = self.client.get('/api/v1/instance/rds/', format='json') + r = self.client.get("/api/v1/instance/rds/", format="json") self.assertEqual(r.status_code, status.HTTP_200_OK) - self.assertEqual(r.json()['count'], 1) + self.assertEqual(r.json()["count"], 1) def test_create_aliyunrds(self): """测试创建aliyunrds""" - ins = Instance.objects.create(instance_name='another_ins', type='slave', db_type='mysql', - host='another_host', port=3306) + ins = Instance.objects.create( + instance_name="another_ins", + type="slave", + db_type="mysql", + host="another_host", + port=3306, + ) json_data = { "rds_dbinstanceid": "bbc", "is_enable": True, @@ -284,29 +314,25 @@ def test_create_aliyunrds(self): "type": "aliyun", "key_id": "bbc", "key_secret": "bbc", - "remark": "bbc" - } + "remark": "bbc", + }, } - r = self.client.post('/api/v1/instance/rds/', json_data, format='json') + r = self.client.post("/api/v1/instance/rds/", json_data, format="json") self.assertEqual(r.status_code, status.HTTP_201_CREATED) - self.assertEqual(r.json()['rds_dbinstanceid'], 'bbc') + self.assertEqual(r.json()["rds_dbinstanceid"], "bbc") def test_get_tunnel_list(self): """测试获取隧道清单""" - r = self.client.get('/api/v1/instance/tunnel/', format='json') + r = self.client.get("/api/v1/instance/tunnel/", format="json") self.assertEqual(r.status_code, status.HTTP_200_OK) - self.assertEqual(r.json()['count'], 1) + self.assertEqual(r.json()["count"], 1) def test_create_tunnel(self): """测试创建隧道""" - json_data = { - "tunnel_name": "tunnel_test", - "host": "one_host", - "port": 22 - } - r = self.client.post('/api/v1/instance/tunnel/', json_data, format='json') + json_data = {"tunnel_name": "tunnel_test", "host": "one_host", "port": 22} + r = self.client.post("/api/v1/instance/tunnel/", json_data, format="json") self.assertEqual(r.status_code, status.HTTP_201_CREATED) - self.assertEqual(r.json()['tunnel_name'], 'tunnel_test') + self.assertEqual(r.json()["tunnel_name"], "tunnel_test") class TestWorkflow(APITestCase): @@ -314,63 +340,82 @@ class TestWorkflow(APITestCase): def setUp(self): self.now = datetime.now() - self.group = Group.objects.create(id=1, name='DBA') - self.res_group = ResourceGroup.objects.create(group_id=1, group_name='test') - self.ins_tag = InstanceTag.objects.create(tag_code='can_write', active=1) - self.wfs = WorkflowAuditSetting.objects.create(group_id=self.res_group.group_id, - workflow_type=2, audit_auth_groups=self.group.id) - can_execute_permission = Permission.objects.get(codename='sql_execute') - can_execute_resource_permission = Permission.objects.get(codename='sql_execute_for_resource_group') - can_review_permission = Permission.objects.get(codename='sql_review') - self.user = User(username='test_user', display='测试用户', is_active=True) - self.user.set_password('test_password') + self.group = Group.objects.create(id=1, name="DBA") + self.res_group = ResourceGroup.objects.create(group_id=1, group_name="test") + self.ins_tag = InstanceTag.objects.create(tag_code="can_write", active=1) + self.wfs = WorkflowAuditSetting.objects.create( + group_id=self.res_group.group_id, + workflow_type=2, + audit_auth_groups=self.group.id, + ) + can_execute_permission = Permission.objects.get(codename="sql_execute") + can_execute_resource_permission = Permission.objects.get( + codename="sql_execute_for_resource_group" + ) + can_review_permission = Permission.objects.get(codename="sql_review") + self.user = User(username="test_user", display="测试用户", is_active=True) + self.user.set_password("test_password") self.user.save() - self.user.user_permissions.add(can_execute_permission, can_execute_resource_permission, can_review_permission) + self.user.user_permissions.add( + can_execute_permission, + can_execute_resource_permission, + can_review_permission, + ) self.user.groups.add(self.group.id) self.user.resource_group.add(self.res_group.group_id) - self.ins = Instance.objects.create(instance_name='some_ins', type='slave', db_type='redis', - host='some_host', port=6379, user='ins_user', password='some_str') + self.ins = Instance.objects.create( + instance_name="some_ins", + type="slave", + db_type="redis", + host="some_host", + port=6379, + user="ins_user", + password="some_str", + ) self.ins.resource_group.add(self.res_group.group_id) self.ins.instance_tag.add(self.ins_tag.id) self.wf1 = SqlWorkflow.objects.create( - workflow_name='some_name', + workflow_name="some_name", group_id=1, - group_name='g1', + group_name="g1", engineer=self.user.username, engineer_display=self.user.display, - audit_auth_groups='1', + audit_auth_groups="1", create_time=self.now - timedelta(days=1), - status='workflow_manreviewing', + status="workflow_manreviewing", is_backup=False, instance=self.ins, - db_name='some_db', + db_name="some_db", syntax_type=1, ) self.wfc1 = SqlWorkflowContent.objects.create( workflow=self.wf1, - sql_content='some_sql', - execute_result=json.dumps([{ - 'id': 1, - 'sql': 'some_content' - }]) + sql_content="some_sql", + execute_result=json.dumps([{"id": 1, "sql": "some_content"}]), ) self.audit1 = WorkflowAudit.objects.create( group_id=1, - group_name='some_group', + group_name="some_group", workflow_id=self.wf1.id, workflow_type=2, - workflow_title='申请标题', - workflow_remark='申请备注', - audit_auth_groups='1', - current_audit='1', - next_audit='-1', - current_status=0) - self.wl = WorkflowLog.objects.create(audit_id=self.audit1.audit_id, - operation_type=1) - r = self.client.post('/api/auth/token/', {'username': 'test_user', 'password': 'test_password'}, format='json') - self.token = r.data['access'] - self.client.credentials(HTTP_AUTHORIZATION='Bearer ' + self.token) - SysConfig().set('api_user_whitelist', self.user.id) + workflow_title="申请标题", + workflow_remark="申请备注", + audit_auth_groups="1", + current_audit="1", + next_audit="-1", + current_status=0, + ) + self.wl = WorkflowLog.objects.create( + audit_id=self.audit1.audit_id, operation_type=1 + ) + r = self.client.post( + "/api/auth/token/", + {"username": "test_user", "password": "test_password"}, + format="json", + ) + self.token = r.data["access"] + self.client.credentials(HTTP_AUTHORIZATION="Bearer " + self.token) + SysConfig().set("api_user_whitelist", self.user.id) def tearDown(self): self.user.delete() @@ -383,45 +428,43 @@ def tearDown(self): def test_get_sql_workflow_list(self): """测试获取SQL上线工单列表""" - r = self.client.get('/api/v1/workflow/', format='json') + r = self.client.get("/api/v1/workflow/", format="json") self.assertEqual(r.status_code, status.HTTP_200_OK) - self.assertEqual(r.json()['count'], 1) + self.assertEqual(r.json()["count"], 1) def test_get_audit_list(self): """测试获取待审核工单列表""" - json_data = { - "engineer": "test_user" - } - r = self.client.post('/api/v1/workflow/auditlist/', json_data, format='json') + json_data = {"engineer": "test_user"} + r = self.client.post("/api/v1/workflow/auditlist/", json_data, format="json") self.assertEqual(r.status_code, status.HTTP_200_OK) - self.assertEqual(r.json()['count'], 1) + self.assertEqual(r.json()["count"], 1) def test_get_workflow_log_list(self): """测试获工单日志""" json_data = { "workflow_id": self.wf1.id, - "workflow_type": self.audit1.workflow_type + "workflow_type": self.audit1.workflow_type, } - r = self.client.post('/api/v1/workflow/log/', json_data, format='json') + r = self.client.post("/api/v1/workflow/log/", json_data, format="json") self.assertEqual(r.status_code, status.HTTP_200_OK) - self.assertEqual(r.json()['count'], 1) + self.assertEqual(r.json()["count"], 1) def test_submit_workflow(self): """测试提交SQL上线工单""" json_data = { "workflow": { - "workflow_name": "上线工单1", - "demand_url": "test", - "group_id": 1, - "db_name": "test_db", - "engineer": self.user.username, - "instance": self.ins.id + "workflow_name": "上线工单1", + "demand_url": "test", + "group_id": 1, + "db_name": "test_db", + "engineer": self.user.username, + "instance": self.ins.id, }, - "sql_content": "alter table abc add column note varchar(64);" + "sql_content": "alter table abc add column note varchar(64);", } - r = self.client.post('/api/v1/workflow/', json_data, format='json') + r = self.client.post("/api/v1/workflow/", json_data, format="json") self.assertEqual(r.status_code, status.HTTP_201_CREATED) - self.assertEqual(r.json()['workflow']['workflow_name'], '上线工单1') + self.assertEqual(r.json()["workflow"]["workflow_name"], "上线工单1") def test_audit_workflow(self): """测试审核工单""" @@ -430,11 +473,11 @@ def test_audit_workflow(self): "workflow_id": self.wf1.id, "audit_remark": "取消", "workflow_type": self.audit1.workflow_type, - "audit_type": "cancel" + "audit_type": "cancel", } - r = self.client.post('/api/v1/workflow/audit/', json_data, format='json') + r = self.client.post("/api/v1/workflow/audit/", json_data, format="json") self.assertEqual(r.status_code, status.HTTP_200_OK) - self.assertEqual(r.json(), {'msg': 'canceled'}) + self.assertEqual(r.json(), {"msg": "canceled"}) def test_execute_workflow(self): """测试执行工单""" @@ -444,16 +487,16 @@ def test_execute_workflow(self): "workflow_id": self.wf1.id, "audit_remark": "通过", "workflow_type": self.audit1.workflow_type, - "audit_type": "pass" + "audit_type": "pass", } - self.client.post('/api/v1/workflow/audit/', audit_data, format='json') + self.client.post("/api/v1/workflow/audit/", audit_data, format="json") # 再执行 execute_data = { "engineer": self.user.username, "workflow_id": self.wf1.id, "workflow_type": self.audit1.workflow_type, - "mode": "manual" + "mode": "manual", } - r = self.client.post('/api/v1/workflow/execute/', execute_data, format='json') + r = self.client.post("/api/v1/workflow/execute/", execute_data, format="json") self.assertEqual(r.status_code, status.HTTP_200_OK) - self.assertEqual(r.json(), {'msg': '开始执行,执行结果请到工单详情页查看'}) + self.assertEqual(r.json(), {"msg": "开始执行,执行结果请到工单详情页查看"}) diff --git a/sql_api/urls.py b/sql_api/urls.py index a4c064b33b..1c098e6b64 100644 --- a/sql_api/urls.py +++ b/sql_api/urls.py @@ -1,43 +1,57 @@ from django.urls import path, include from sql_api import views from rest_framework import routers -from rest_framework_simplejwt.views import TokenObtainPairView, TokenRefreshView, TokenVerifyView -from drf_spectacular.views import SpectacularAPIView, SpectacularRedocView, SpectacularSwaggerView +from rest_framework_simplejwt.views import ( + TokenObtainPairView, + TokenRefreshView, + TokenVerifyView, +) +from drf_spectacular.views import ( + SpectacularAPIView, + SpectacularRedocView, + SpectacularSwaggerView, +) from . import api_user, api_instance, api_workflow router = routers.DefaultRouter() urlpatterns = [ - path('v1/', include(router.urls)), - path('auth/token/', TokenObtainPairView.as_view(), name='token_obtain_pair'), - path('auth/token/refresh/', TokenRefreshView.as_view(), name='token_refresh'), - path('auth/token/verify/', TokenVerifyView.as_view(), name='token_verify'), - path('schema/', SpectacularAPIView.as_view(), name='schema'), - path('swagger/', SpectacularSwaggerView.as_view(url_name='sql_api:schema'), name='swagger'), - path('redoc/', SpectacularRedocView.as_view(url_name='sql_api:schema'), name='redoc'), - path('v1/user/', api_user.UserList.as_view()), - path('v1/user//', api_user.UserDetail.as_view()), - path('v1/user/group/', api_user.GroupList.as_view()), - path('v1/user/group//', api_user.GroupDetail.as_view()), - path('v1/user/resourcegroup/', api_user.ResourceGroupList.as_view()), - path('v1/user/resourcegroup//', api_user.ResourceGroupDetail.as_view()), - path('v1/user/auth/', api_user.UserAuth.as_view()), - path('v1/user/2fa/', api_user.TwoFA.as_view()), - path('v1/user/2fa/state/', api_user.TwoFAState.as_view()), - path('v1/user/2fa/save/', api_user.TwoFASave.as_view()), - path('v1/user/2fa/verify/', api_user.TwoFAVerify.as_view()), - path('v1/instance/', api_instance.InstanceList.as_view()), - path('v1/instance//', api_instance.InstanceDetail.as_view()), - path('v1/instance/resource/', api_instance.InstanceResource.as_view()), - path('v1/instance/tunnel/', api_instance.TunnelList.as_view()), - path('v1/instance/rds/', api_instance.AliyunRdsList.as_view()), - path('v1/workflow/', api_workflow.WorkflowList.as_view()), - path('v1/workflow/sqlcheck/', api_workflow.ExecuteCheck.as_view()), - path('v1/workflow/audit/', api_workflow.AuditWorkflow.as_view()), - path('v1/workflow/auditlist/', api_workflow.WorkflowAuditList.as_view()), - path('v1/workflow/execute/', api_workflow.ExecuteWorkflow.as_view()), - path('v1/workflow/log/', api_workflow.WorkflowLogList.as_view()), - path('info', views.info), - path('debug', views.debug), - path('do_once/mirage', views.mirage) + path("v1/", include(router.urls)), + path("auth/token/", TokenObtainPairView.as_view(), name="token_obtain_pair"), + path("auth/token/refresh/", TokenRefreshView.as_view(), name="token_refresh"), + path("auth/token/verify/", TokenVerifyView.as_view(), name="token_verify"), + path("schema/", SpectacularAPIView.as_view(), name="schema"), + path( + "swagger/", + SpectacularSwaggerView.as_view(url_name="sql_api:schema"), + name="swagger", + ), + path( + "redoc/", SpectacularRedocView.as_view(url_name="sql_api:schema"), name="redoc" + ), + path("v1/user/", api_user.UserList.as_view()), + path("v1/user//", api_user.UserDetail.as_view()), + path("v1/user/group/", api_user.GroupList.as_view()), + path("v1/user/group//", api_user.GroupDetail.as_view()), + path("v1/user/resourcegroup/", api_user.ResourceGroupList.as_view()), + path("v1/user/resourcegroup//", api_user.ResourceGroupDetail.as_view()), + path("v1/user/auth/", api_user.UserAuth.as_view()), + path("v1/user/2fa/", api_user.TwoFA.as_view()), + path("v1/user/2fa/state/", api_user.TwoFAState.as_view()), + path("v1/user/2fa/save/", api_user.TwoFASave.as_view()), + path("v1/user/2fa/verify/", api_user.TwoFAVerify.as_view()), + path("v1/instance/", api_instance.InstanceList.as_view()), + path("v1/instance//", api_instance.InstanceDetail.as_view()), + path("v1/instance/resource/", api_instance.InstanceResource.as_view()), + path("v1/instance/tunnel/", api_instance.TunnelList.as_view()), + path("v1/instance/rds/", api_instance.AliyunRdsList.as_view()), + path("v1/workflow/", api_workflow.WorkflowList.as_view()), + path("v1/workflow/sqlcheck/", api_workflow.ExecuteCheck.as_view()), + path("v1/workflow/audit/", api_workflow.AuditWorkflow.as_view()), + path("v1/workflow/auditlist/", api_workflow.WorkflowAuditList.as_view()), + path("v1/workflow/execute/", api_workflow.ExecuteWorkflow.as_view()), + path("v1/workflow/log/", api_workflow.WorkflowLogList.as_view()), + path("info", views.info), + path("debug", views.debug), + path("do_once/mirage", views.mirage), ] diff --git a/sql_api/views.py b/sql_api/views.py index ec8e17e43a..5e470f5bc6 100644 --- a/sql_api/views.py +++ b/sql_api/views.py @@ -23,15 +23,13 @@ def info(request): # 获取django_q信息 - django_q_version = '.'.join(str(i) for i in django_q.VERSION) + django_q_version = ".".join(str(i) for i in django_q.VERSION) system_info = { - 'archery': { - 'version': archery.display_version + "archery": {"version": archery.display_version}, + "django_q": { + "version": django_q_version, }, - 'django_q': { - 'version': django_q_version, - } } return JsonResponse(system_info) @@ -39,24 +37,24 @@ def info(request): @superuser_required def debug(request): # 获取完整信息 - full = request.GET.get('full') + full = request.GET.get("full") # 系统配置 sys_config = SysConfig().sys_config # 敏感信息处理 secret_keys = [ - 'inception_remote_backup_password', - 'ding_app_secret', - 'feishu_app_secret', - 'mail_smtp_password' + "inception_remote_backup_password", + "ding_app_secret", + "feishu_app_secret", + "mail_smtp_password", ] sys_config.update({k: "******" for k in secret_keys}) # MySQL信息 cursor = connection.cursor() mysql_info = { - 'mysql_server_info': cursor.db.mysql_server_info, - 'timezone_name': cursor.db.timezone_name + "mysql_server_info": cursor.db.mysql_server_info, + "timezone_name": cursor.db.timezone_name, } # Redis信息 @@ -64,30 +62,30 @@ def debug(request): redis_conn = get_redis_connection("default") full_redis_info = redis_conn.info() redis_info = { - 'redis_version': full_redis_info.get('redis_version'), - 'redis_mode': full_redis_info.get('redis_mode'), - 'role': full_redis_info.get('role'), - 'maxmemory_human': full_redis_info.get('maxmemory_human'), - 'used_memory_human': full_redis_info.get('used_memory_human'), + "redis_version": full_redis_info.get("redis_version"), + "redis_mode": full_redis_info.get("redis_mode"), + "role": full_redis_info.get("role"), + "maxmemory_human": full_redis_info.get("maxmemory_human"), + "used_memory_human": full_redis_info.get("used_memory_human"), } except Exception as e: - redis_info = f'获取Redis信息报错:{e}' + redis_info = f"获取Redis信息报错:{e}" full_redis_info = redis_info # django_q try: - django_q_version = '.'.join(str(i) for i in django_q.VERSION) + django_q_version = ".".join(str(i) for i in django_q.VERSION) broker = get_broker() stats = Stat.get_all(broker=broker) queue_size = broker.queue_size() lock_size = broker.lock_size() if lock_size: - queue_size = '{}({})'.format(queue_size, lock_size) + queue_size = "{}({})".format(queue_size, lock_size) q_broker_stats = { - 'info': broker.info(), - 'Queued': queue_size, - 'Success': Success.objects.count(), - 'Failures': Failure.objects.count(), + "info": broker.info(), + "Queued": queue_size, + "Success": Success.objects.count(), + "Failures": Failure.objects.count(), } q_cluster_stats = [] for stat in stats: @@ -95,93 +93,104 @@ def debug(request): uptime = (timezone.now() - stat.tob).total_seconds() hours, remainder = divmod(uptime, 3600) minutes, seconds = divmod(remainder, 60) - uptime = '%d:%02d:%02d' % (hours, minutes, seconds) - q_cluster_stats.append({ - 'host': stat.host, - 'cluster_id': stat.cluster_id, - 'state': stat.status, - 'pool': len(stat.workers), - 'tq': stat.task_q_size, - 'rq': stat.done_q_size, - 'rc': stat.reincarnations, - 'up': uptime - }) + uptime = "%d:%02d:%02d" % (hours, minutes, seconds) + q_cluster_stats.append( + { + "host": stat.host, + "cluster_id": stat.cluster_id, + "state": stat.status, + "pool": len(stat.workers), + "tq": stat.task_q_size, + "rq": stat.done_q_size, + "rc": stat.reincarnations, + "up": uptime, + } + ) django_q_info = { - 'version': django_q_version, - 'conf': django_q.conf.Conf.conf, - 'q_cluster_stats': q_cluster_stats if q_cluster_stats else '没有正在运行的集群信息,请检查django_q状态', - 'q_broker_stats': q_broker_stats + "version": django_q_version, + "conf": django_q.conf.Conf.conf, + "q_cluster_stats": q_cluster_stats + if q_cluster_stats + else "没有正在运行的集群信息,请检查django_q状态", + "q_broker_stats": q_broker_stats, } except Exception as e: - django_q_info = f'获取django_q信息报错:{e}' + django_q_info = f"获取django_q信息报错:{e}" # Inception和goInception信息 - go_inception_host = sys_config.get('go_inception_host') - go_inception_port = sys_config.get('go_inception_port', 0) - inception_remote_backup_host = sys_config.get('inception_remote_backup_host', '') - inception_remote_backup_port = sys_config.get('inception_remote_backup_port', '') - inception_remote_backup_user = sys_config.get('inception_remote_backup_user', '') - inception_remote_backup_password = sys_config.get('inception_remote_backup_password', '') + go_inception_host = sys_config.get("go_inception_host") + go_inception_port = sys_config.get("go_inception_port", 0) + inception_remote_backup_host = sys_config.get("inception_remote_backup_host", "") + inception_remote_backup_port = sys_config.get("inception_remote_backup_port", "") + inception_remote_backup_user = sys_config.get("inception_remote_backup_user", "") + inception_remote_backup_password = sys_config.get( + "inception_remote_backup_password", "" + ) # goInception try: - goinc_conn = MySQLdb.connect(host=go_inception_host, port=int(go_inception_port), - connect_timeout=1, cursorclass=MySQLdb.cursors.DictCursor) + goinc_conn = MySQLdb.connect( + host=go_inception_host, + port=int(go_inception_port), + connect_timeout=1, + cursorclass=MySQLdb.cursors.DictCursor, + ) cursor = goinc_conn.cursor() - cursor.execute('inception get variables') + cursor.execute("inception get variables") rows = cursor.fetchall() full_goinception_info = dict() for row in rows: - full_goinception_info[row.get('Variable_name')] = row.get('Value') + full_goinception_info[row.get("Variable_name")] = row.get("Value") goinception_info = { - 'version': full_goinception_info.get('version'), - 'max_allowed_packet': full_goinception_info.get('max_allowed_packet'), - 'lang': full_goinception_info.get('lang'), - 'osc_on': full_goinception_info.get('osc_on'), - 'osc_bin_dir': full_goinception_info.get('osc_bin_dir'), - 'ghost_on': full_goinception_info.get('ghost_on'), + "version": full_goinception_info.get("version"), + "max_allowed_packet": full_goinception_info.get("max_allowed_packet"), + "lang": full_goinception_info.get("lang"), + "osc_on": full_goinception_info.get("osc_on"), + "osc_bin_dir": full_goinception_info.get("osc_bin_dir"), + "ghost_on": full_goinception_info.get("ghost_on"), } except Exception as e: - goinception_info = f'获取goInception信息报错:{e}' + goinception_info = f"获取goInception信息报错:{e}" full_goinception_info = goinception_info # 备份库 try: - bak_conn = MySQLdb.connect(host=inception_remote_backup_host, - port=int(inception_remote_backup_port), - user=inception_remote_backup_user, - password=inception_remote_backup_password, - connect_timeout=1) + bak_conn = MySQLdb.connect( + host=inception_remote_backup_host, + port=int(inception_remote_backup_port), + user=inception_remote_backup_user, + password=inception_remote_backup_password, + connect_timeout=1, + ) cursor = bak_conn.cursor() - cursor.execute('select 1;') - backup_info = 'normal' + cursor.execute("select 1;") + backup_info = "normal" except Exception as e: - backup_info = f'无法连接goInception备份库\n{e}' + backup_info = f"无法连接goInception备份库\n{e}" # PACKAGES installed_packages = pkg_resources.working_set - installed_packages_list = sorted([ - "%s==%s" % (i.key, i.version) for i in installed_packages]) + installed_packages_list = sorted( + ["%s==%s" % (i.key, i.version) for i in installed_packages] + ) # 最终集合 system_info = { - 'archery': { - 'version': archery.display_version - }, - 'django_q': django_q_info, - 'inception': { - 'goinception_info': full_goinception_info if full else goinception_info, - 'backup_info': backup_info + "archery": {"version": archery.display_version}, + "django_q": django_q_info, + "inception": { + "goinception_info": full_goinception_info if full else goinception_info, + "backup_info": backup_info, }, - 'runtime_info': { - 'python_version': platform.python_version(), - 'mysql_info': mysql_info, - 'redis_info': full_redis_info if full else redis_info, - 'sys_argv': sys.argv, - 'platform': platform.uname() + "runtime_info": { + "python_version": platform.python_version(), + "mysql_info": mysql_info, + "redis_info": full_redis_info if full else redis_info, + "sys_argv": sys.argv, + "platform": platform.uname(), }, - 'sys_config': sys_config, - 'packages': installed_packages_list + "sys_config": sys_config, + "packages": installed_packages_list, } return JsonResponse(system_info) @@ -197,7 +206,9 @@ def mirage(request): for ins in Instance.objects.all(): # 忽略解密错误的数据(本身为异常数据) try: - Instance(pk=ins.pk, password=pc.decrypt(ins.password)).save(update_fields=['password']) + Instance(pk=ins.pk, password=pc.decrypt(ins.password)).save( + update_fields=["password"] + ) except: pass # 使用django-mirage-field重新加密 From 33ded63cb9b1cec0d40c6d1d4ec0201065de4143 Mon Sep 17 00:00:00 2001 From: hhyo Date: Fri, 15 Jul 2022 21:15:51 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=E6=9B=B4=E6=96=B0=20badge?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index d1182544d4..98ba150c83 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,6 @@ # Archery -[![star](https://gitee.com/rtttte/Archery/badge/star.svg?theme=gvp)](https://gitee.com/rtttte/Archery) [![Django CI](https://github.com/hhyo/Archery/actions/workflows/django.yml/badge.svg)](https://github.com/hhyo/Archery/actions/workflows/django.yml) [![Release](https://img.shields.io/github/release/hhyo/archery.svg)](https://github.com/hhyo/archery/releases/) [![codecov](https://codecov.io/gh/hhyo/archery/branch/master/graph/badge.svg)](https://codecov.io/gh/hhyo/archery) @@ -11,7 +10,7 @@ [![Publish Docker image](https://github.com/hhyo/Archery/actions/workflows/docker-image.yml/badge.svg)](https://github.com/hhyo/Archery/actions/workflows/docker-image.yml) [![docker_pulls](https://img.shields.io/docker/pulls/hhyo/archery.svg)](https://hub.docker.com/r/hhyo/archery/) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](http://github.com/hhyo/archery/blob/master/LICENSE) -[![996.icu](https://img.shields.io/badge/link-996.icu-red.svg)](https://996.icu) +[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [文档](https://archerydms.com/) | [FAQ](https://github.com/hhyo/archery/wiki/FAQ) | [Releases](https://github.com/hhyo/archery/releases/)