Skip to content

Commit

Permalink
Merge pull request #43 from Ensembl/features/case-sensitive-table
Browse files Browse the repository at this point in the history
updated case sensitivity in table.sql parsing.
  • Loading branch information
JAlvarezJarreta authored Mar 7, 2024
2 parents 5ba5b67 + eb5d4eb commit aa655c3
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
name = "ensembl-py"
description = "Ensembl Python Base Library"
requires-python = ">= 3.8"
version = "1.2.2"
version = "1.2.3"
readme = "README.md"
authors = [
{name = "Ensembl", email = "[email protected]"},
Expand Down Expand Up @@ -46,8 +46,7 @@ dependencies = [
"python-dotenv ~= 0.19.2",
"PyYAML ~= 6.0",
"requests >= 2.22.0",
"sqlalchemy ~= 1.4.0",
"SQLAlchemy-Utils >= 0.37, < 0.39",
"sqlalchemy ~= 1.4.0"
]

[project.optional-dependencies]
Expand Down
20 changes: 14 additions & 6 deletions src/python/ensembl/database/unittestdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,16 @@

__all__ = ['UnitTestDB', 'UnitTestDBError', 'DataLoadingError']

from pathlib import Path
import os
import re
import subprocess
from pathlib import Path
from typing import Iterator, Optional, Union

import sqlalchemy
from sqlalchemy import create_engine, text
from sqlalchemy.engine.url import make_url
from sqlalchemy_utils import database_exists

from sqlalchemy import exc
from .dbconnection import DBConnection, Query, URL


Expand Down Expand Up @@ -82,6 +81,10 @@ def __init__(self, url: URL, dump_dir: os.PathLike, name: Optional[str] = None)
# Establish the connection to the database, load the schema and import the data
self.dbc = DBConnection(db_url)
with self.dbc.begin() as conn:
if self.dbc.dialect == 'mysql':
conn.execute(f"SET FOREIGN_KEY_CHECKS=0;")
elif self.dbc.dialect == 'sqlite':
conn.execute(f"PRAGMA foreign_keys = OFF;")
for query in self._parse_sql_file(dump_dir_path / 'table.sql'):
table = self._get_table_name(query)
try:
Expand All @@ -95,11 +98,16 @@ def __init__(self, url: URL, dump_dir: os.PathLike, name: Optional[str] = None)
conn.execute(f"TRUNCATE TABLE {table}")
else:
conn.execute(f"DELETE FROM {table}")

self._load_data(conn, table, filepath)
except:
if self.dbc.dialect == 'mysql':
conn.execute(f"SET FOREIGN_KEY_CHECKS=1;")
elif self.dbc.dialect == 'sqlite':
conn.execute(f"PRAGMA foreign_keys = ON;")
except exc.SQLAlchemyError as e:
# Make sure the database is deleted before raising the exception
self.drop()
raise
raise e
# Update the loaded metadata information of the database
self.dbc.load_metadata()

Expand Down Expand Up @@ -184,7 +192,7 @@ def _parse_sql_file(filepath: Union[str, bytes, os.PathLike]
@staticmethod
def _get_table_name(query: Query) -> str:
"""Returns the table name of a ``CREATE TABLE`` SQL query, empty string otherwise."""
match = re.search(r'^CREATE[ ]+TABLE[ ]+(`[^`]+`|[^ ]+)', str(query))
match = re.search(r'^CREATE[ ]+TABLE[ ]+(`[^`]+`|[^ ]+)', str(query), flags=re.IGNORECASE)
return match.group(1).strip('`') if match else ''


Expand Down

0 comments on commit aa655c3

Please sign in to comment.