Skip to content

Commit

Permalink
add sst_generated
Browse files Browse the repository at this point in the history
  • Loading branch information
CrysR committed Jan 31, 2025
1 parent 3bf60a3 commit 58687a9
Showing 1 changed file with 22 additions and 9 deletions.
31 changes: 22 additions & 9 deletions src/webapp/routers/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class BatchCreationRequest(BaseModel):
description: str | None = None
# Disabled data means it is no longer in use or not available for use.
batch_disabled: bool = False
file_ids: set[str] | None = None


class BatchInfo(BaseModel):
Expand Down Expand Up @@ -404,14 +405,25 @@ def create_batch(
.all()
)
if len(query_result) == 0:
local_session.get().add(
BatchTable(
name=req.name,
inst_id=str_to_uuid(inst_id),
description=req.description,
creator=str_to_uuid(current_user.user_id),
batch = BatchTable(
name=req.name,
inst_id=str_to_uuid(inst_id),
description=req.description,
creator=str_to_uuid(current_user.user_id),
)
# xxx todo: Query all the files and add them to this batch.
"""
for f_id in req.file_ids:
file_result = (
local_session.get()
.execute(select(FileTable).where(FileTable.id == str_to_uuid(f_id)))
.all()
)
)
for e in file_result:
batch.files.add(e)
"""
local_session.get().add(batch)
local_session.get().commit()
query_result = (
local_session.get()
.execute(select(BatchTable).where(BatchTable.name == req.name))
Expand All @@ -430,14 +442,14 @@ def create_batch(
if len(query_result) > 1:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Batch duplicates found.",
detail="Batch with this name already exists.",
)
return {
"batch_id": uuid_to_str(query_result[0][0].id),
"inst_id": uuid_to_str(query_result[0][0].inst_id),
"name": query_result[0][0].name,
"description": query_result[0][0].description,
"file_ids": [],
"file_ids": query_result[0][0].files,
"creator": uuid_to_str(query_result[0][0].creator),
"deleted": False,
"completed": False,
Expand Down Expand Up @@ -837,6 +849,7 @@ def validate_file(
inst_id=str_to_uuid(inst_id),
uploader=str_to_uuid(current_user.user_id),
source="MANUAL_UPLOAD",
sst_generated=False,
schemas=list(inferred_schemas),
valid=True,
)
Expand Down

0 comments on commit 58687a9

Please sign in to comment.