Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: store Entry suffix separately #503

Merged
merged 2 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions tagstudio/src/core/constants.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
from enum import Enum

VERSION: str = "9.3.2" # Major.Minor.Patch
VERSION_BRANCH: str = "" # Usually "" or "Pre-Release"

# The folder & file names where TagStudio keeps its data relative to a library.
TS_FOLDER_NAME: str = ".TagStudio"
BACKUP_FOLDER_NAME: str = "backups"
COLLAGE_FOLDER_NAME: str = "collages"
LIBRARY_FILENAME: str = "ts_library.json"

# TODO: Turn this whitelist into a user-configurable blacklist.
IMAGE_TYPES: list[str] = [
Expand Down Expand Up @@ -122,13 +119,5 @@
+ SHORTCUT_TYPES
)


TAG_FAVORITE = 1
TAG_ARCHIVED = 0


class LibraryPrefs(Enum):
IS_EXCLUDE_LIST = True
EXTENSION_LIST: list[str] = [".json", ".xmp", ".aae"]
PAGE_SIZE: int = 500
DB_VERSION: int = 1
40 changes: 40 additions & 0 deletions tagstudio/src/core/driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from pathlib import Path

import structlog
from PySide6.QtCore import QSettings
from src.core.constants import TS_FOLDER_NAME
from src.core.enums import SettingItems
from src.core.library.alchemy.library import LibraryStatus

logger = structlog.get_logger(__name__)


class DriverMixin:
settings: QSettings

def evaluate_path(self, open_path: str | None) -> LibraryStatus:
"""Check if the path of library is valid."""
library_path: Path | None = None
if open_path:
library_path = Path(open_path)
if not library_path.exists():
logger.error("Path does not exist.", open_path=open_path)
return LibraryStatus(success=False, message="Path does not exist.")
elif self.settings.value(
SettingItems.START_LOAD_LAST, defaultValue=True, type=bool
) and self.settings.value(SettingItems.LAST_LIBRARY):
library_path = Path(str(self.settings.value(SettingItems.LAST_LIBRARY)))
if not (library_path / TS_FOLDER_NAME).exists():
logger.error(
"TagStudio folder does not exist.",
library_path=library_path,
ts_folder=TS_FOLDER_NAME,
)
self.settings.setValue(SettingItems.LAST_LIBRARY, "")
# dont consider this a fatal error, just skip opening the library
library_path = None

return LibraryStatus(
success=True,
library_path=library_path,
)
30 changes: 30 additions & 0 deletions tagstudio/src/core/enums.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import enum
from typing import Any
from uuid import uuid4


class SettingItems(str, enum.Enum):
Expand Down Expand Up @@ -31,3 +33,31 @@ class MacroID(enum.Enum):
BUILD_URL = "build_url"
MATCH = "match"
CLEAN_URL = "clean_url"


class DefaultEnum(enum.Enum):
"""Allow saving multiple identical values in property called .default."""

default: Any

def __new__(cls, value):
# Create the enum instance
obj = object.__new__(cls)
# make value random
obj._value_ = uuid4()
# assign the actual value into .default property
obj.default = value
return obj

@property
def value(self):
raise AttributeError("access the value via .default property instead")


class LibraryPrefs(DefaultEnum):
"""Library preferences with default value accessible via .default property."""

IS_EXCLUDE_LIST = True
EXTENSION_LIST: list[str] = [".json", ".xmp", ".aae"]
PAGE_SIZE: int = 500
DB_VERSION: int = 2
16 changes: 8 additions & 8 deletions tagstudio/src/core/library/alchemy/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,27 @@ class BaseField(Base):
__abstract__ = True

@declared_attr
def id(cls) -> Mapped[int]: # noqa: N805
def id(self) -> Mapped[int]:
return mapped_column(primary_key=True, autoincrement=True)

@declared_attr
def type_key(cls) -> Mapped[str]: # noqa: N805
def type_key(self) -> Mapped[str]:
return mapped_column(ForeignKey("value_type.key"))

@declared_attr
def type(cls) -> Mapped[ValueType]: # noqa: N805
return relationship(foreign_keys=[cls.type_key], lazy=False) # type: ignore
def type(self) -> Mapped[ValueType]:
return relationship(foreign_keys=[self.type_key], lazy=False) # type: ignore

@declared_attr
def entry_id(cls) -> Mapped[int]: # noqa: N805
def entry_id(self) -> Mapped[int]:
return mapped_column(ForeignKey("entries.id"))

@declared_attr
def entry(cls) -> Mapped[Entry]: # noqa: N805
return relationship(foreign_keys=[cls.entry_id]) # type: ignore
def entry(self) -> Mapped[Entry]:
return relationship(foreign_keys=[self.entry_id]) # type: ignore

@declared_attr
def position(cls) -> Mapped[int]: # noqa: N805
def position(self) -> Mapped[int]:
return mapped_column(default=0)

def __hash__(self):
Expand Down
91 changes: 69 additions & 22 deletions tagstudio/src/core/library/alchemy/library.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
import shutil
import sys
import unicodedata
from dataclasses import dataclass
from datetime import UTC, datetime
Expand Down Expand Up @@ -34,8 +35,8 @@
TAG_ARCHIVED,
TAG_FAVORITE,
TS_FOLDER_NAME,
LibraryPrefs,
)
from ...enums import LibraryPrefs
from .db import make_tables
from .enums import FieldTypeEnum, FilterState, TagColor
from .fields import (
Expand All @@ -48,8 +49,6 @@
from .joins import TagField, TagSubtag
from .models import Entry, Folder, Preferences, Tag, TagAlias, ValueType

LIBRARY_FILENAME: str = "ts_library.sqlite"

logger = structlog.get_logger(__name__)


Expand Down Expand Up @@ -115,6 +114,15 @@ def __getitem__(self, index: int) -> Entry:
return self.items[index]


@dataclass
class LibraryStatus:
"""Keep status of library opening operation."""

success: bool
library_path: Path | None = None
message: str | None = None


class Library:
"""Class for the Library object, and all CRUD operations made upon it."""

Expand All @@ -123,30 +131,28 @@ class Library:
engine: Engine | None
folder: Folder | None

FILENAME: str = "ts_library.sqlite"

def close(self):
if self.engine:
self.engine.dispose()
self.library_dir = None
self.storage_path = None
self.folder = None

def open_library(self, library_dir: Path | str, storage_path: str | None = None) -> None:
if isinstance(library_dir, str):
library_dir = Path(library_dir)

self.library_dir = library_dir
def open_library(self, library_dir: Path, storage_path: str | None = None) -> LibraryStatus:
if storage_path == ":memory:":
self.storage_path = storage_path
else:
self.verify_ts_folders(self.library_dir)
self.storage_path = self.library_dir / TS_FOLDER_NAME / LIBRARY_FILENAME
self.verify_ts_folders(library_dir)
self.storage_path = library_dir / TS_FOLDER_NAME / self.FILENAME

connection_string = URL.create(
drivername="sqlite",
database=str(self.storage_path),
)

logger.info("opening library", connection_string=connection_string)
logger.info("opening library", library_dir=library_dir, connection_string=connection_string)
self.engine = create_engine(connection_string)
with Session(self.engine) as session:
make_tables(self.engine)
Expand All @@ -159,9 +165,24 @@ def open_library(self, library_dir: Path | str, storage_path: str | None = None)
# default tags may exist already
session.rollback()

if "pytest" not in sys.modules:
db_version = session.scalar(
select(Preferences).where(Preferences.key == LibraryPrefs.DB_VERSION.name)
)

if not db_version:
# TODO - remove after #503 is merged and LibraryPrefs.DB_VERSION increased again
return LibraryStatus(
success=False,
message=(
"Library version mismatch.\n"
f"Found: v0, expected: v{LibraryPrefs.DB_VERSION.default}"
),
)

for pref in LibraryPrefs:
try:
session.add(Preferences(key=pref.name, value=pref.value))
session.add(Preferences(key=pref.name, value=pref.default))
session.commit()
except IntegrityError:
logger.debug("preference already exists", pref=pref)
Expand All @@ -183,11 +204,30 @@ def open_library(self, library_dir: Path | str, storage_path: str | None = None)
logger.debug("ValueType already exists", field=field)
session.rollback()

db_version = session.scalar(
select(Preferences).where(Preferences.key == LibraryPrefs.DB_VERSION.name)
)
# if the db version is different, we cant proceed
if db_version.value != LibraryPrefs.DB_VERSION.default:
logger.error(
"DB version mismatch",
db_version=db_version.value,
expected=LibraryPrefs.DB_VERSION.default,
)
# TODO - handle migration
return LibraryStatus(
success=False,
message=(
"Library version mismatch.\n"
f"Found: v{db_version.value}, expected: v{LibraryPrefs.DB_VERSION.default}"
),
)

# check if folder matching current path exists already
self.folder = session.scalar(select(Folder).where(Folder.path == self.library_dir))
self.folder = session.scalar(select(Folder).where(Folder.path == library_dir))
if not self.folder:
folder = Folder(
path=self.library_dir,
path=library_dir,
uuid=str(uuid4()),
)
session.add(folder)
Expand All @@ -196,6 +236,10 @@ def open_library(self, library_dir: Path | str, storage_path: str | None = None)
session.commit()
self.folder = folder

# everything is fine, set the library path
self.library_dir = library_dir
return LibraryStatus(success=True, library_path=library_dir)

@property
def default_fields(self) -> list[BaseField]:
with Session(self.engine) as session:
Expand Down Expand Up @@ -324,15 +368,18 @@ def add_entries(self, items: list[Entry]) -> list[int]:

with Session(self.engine) as session:
# add all items
session.add_all(items)
session.flush()

new_ids = [item.id for item in items]
try:
session.add_all(items)
session.commit()
except IntegrityError:
session.rollback()
logger.exception("IntegrityError")
return []

new_ids = [item.id for item in items]
session.expunge_all()

session.commit()

return new_ids

def remove_entries(self, entry_ids: list[int]) -> None:
Expand Down Expand Up @@ -396,9 +443,9 @@ def search_library(

if not search.id: # if `id` is set, we don't need to filter by extensions
if extensions and is_exclude_list:
statement = statement.where(Entry.path.notilike(f"%.{','.join(extensions)}"))
statement = statement.where(Entry.suffix.notin_(extensions))
elif extensions:
statement = statement.where(Entry.path.ilike(f"%.{','.join(extensions)}"))
statement = statement.where(Entry.suffix.in_(extensions))

statement = statement.options(
selectinload(Entry.text_fields),
Expand Down Expand Up @@ -770,7 +817,7 @@ def save_library_backup_to_disk(self) -> Path:
target_path = self.library_dir / TS_FOLDER_NAME / BACKUP_FOLDER_NAME / filename

shutil.copy2(
self.library_dir / TS_FOLDER_NAME / LIBRARY_FILENAME,
self.library_dir / TS_FOLDER_NAME / self.FILENAME,
target_path,
)

Expand Down
3 changes: 3 additions & 0 deletions tagstudio/src/core/library/alchemy/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ class Entry(Base):
folder: Mapped[Folder] = relationship("Folder")

path: Mapped[Path] = mapped_column(PathType, unique=True)
suffix: Mapped[str] = mapped_column()

text_fields: Mapped[list[TextField]] = relationship(
back_populates="entry",
Expand Down Expand Up @@ -177,6 +178,8 @@ def __init__(
self.path = path
self.folder = folder

self.suffix = path.suffix.lstrip(".").lower()

for field in fields:
if isinstance(field, TextField):
self.text_fields.append(field)
Expand Down
2 changes: 2 additions & 0 deletions tagstudio/src/core/library/json/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,8 @@ def compressed_dict(self):
class Library:
"""Class for the Library object, and all CRUD operations made upon it."""

FILENAME: str = "ts_library.json"

def __init__(self) -> None:
# Library Info =========================================================
self.library_dir: Path = None
Expand Down
4 changes: 2 additions & 2 deletions tagstudio/src/qt/modals/file_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
QVBoxLayout,
QWidget,
)
from src.core.constants import LibraryPrefs
from src.core.enums import LibraryPrefs
from src.core.library import Library
from src.qt.widgets.panel import PanelWidget

Expand Down Expand Up @@ -104,7 +104,7 @@ def save(self):
for i in range(self.table.rowCount()):
ext = self.table.item(i, 0)
if ext and ext.text().strip():
extensions.append(ext.text().strip().lower())
extensions.append(ext.text().strip().lstrip(".").lower())

# save preference
self.lib.set_prefs(LibraryPrefs.EXTENSION_LIST, extensions)
Loading