diff --git a/core/Cargo.toml b/core/Cargo.toml
index 6154b38be437..d1619ef6c5ca 100644
--- a/core/Cargo.toml
+++ b/core/Cargo.toml
@@ -66,6 +66,10 @@ path = "bin/migrations/20241204_create_nodes.rs"
 name = "fix_created_dsdocs"
 path = "bin/migrations/20241203_fix_created_dsdocs.rs"
 
+[[bin]]
+name = "elasticsearch_backfill_document_tags_index"
+path = "bin/migrations/20250205_backfill_document_tags_index.rs"
+
 [[test]]
 name = "oauth_connections_test"
 path = "src/oauth/tests/functional_connections.rs"
diff --git a/core/bin/core_api.rs b/core/bin/core_api.rs
index 51f1e60102d9..98cc15cc3dce 100644
--- a/core/bin/core_api.rs
+++ b/core/bin/core_api.rs
@@ -50,7 +50,8 @@ use dust::{
     run,
     search_filter::{Filterable, SearchFilter},
     search_stores::search_store::{
-        ElasticsearchSearchStore, NodesSearchFilter, NodesSearchOptions, SearchStore,
+        DatasourceViewFilter, ElasticsearchSearchStore, NodesSearchFilter, NodesSearchOptions,
+        SearchStore, TagsQueryType,
     },
     sqlite_workers::client::{self, HEARTBEAT_INTERVAL_MS},
     stores::{
@@ -3265,6 +3266,58 @@ async fn nodes_search(
     )
 }
 
+#[derive(serde::Deserialize)]
+#[serde(deny_unknown_fields)]
+struct TagsSearchPayload {
+    query: Option<String>,
+    query_type: Option<TagsQueryType>,
+    data_source_views: Vec<DatasourceViewFilter>,
+    node_ids: Option<Vec<String>>,
+    limit: Option<u64>,
+}
+
+async fn tags_search(
+    State(state): State<Arc<APIState>>,
+    Json(payload): Json<TagsSearchPayload>,
+) -> (StatusCode, Json<APIResponse>) {
+    match state
+        .search_store
+        .search_tags(
+            payload.query,
+            payload.query_type,
+            payload.data_source_views,
+            payload.node_ids,
+            payload.limit,
+        )
+        .await
+    {
+        Ok(tags) => (
+            StatusCode::OK,
+            Json(APIResponse {
+                error: None,
+                response: Some(json!({
+                    "tags": tags
+                        .into_iter()
+                        .map(|(k, v, ds)| json!({
+                            "tag": k,
+                            "match_count": v,
+                            "data_sources": ds.into_iter()
+                                .map(|(k, _v)| k)
+                                .collect::<Vec<_>>()
+                        }))
+                        .collect::<Vec<_>>()
+                })),
+            }),
+        ),
+        Err(e) => error_response(
+            StatusCode::INTERNAL_SERVER_ERROR,
+            "internal_server_error",
+            "Failed to list tags",
+            Some(e),
+        ),
+    }
+}
+
 #[derive(serde::Deserialize)]
 struct DatabaseQueryRunPayload {
     query: String,
@@ -3753,6 +3806,7 @@ fn main() {
 
         //Search
         .route("/nodes/search", post(nodes_search))
+        .route("/tags/search", post(tags_search))
 
         // Misc
         .route("/tokenize", post(tokenize))
diff --git a/core/bin/elasticsearch/backfill_folders_index.rs b/core/bin/elasticsearch/backfill_folders_index.rs
index 9dab728f2c95..a1d98991792b 100644
--- a/core/bin/elasticsearch/backfill_folders_index.rs
+++ b/core/bin/elasticsearch/backfill_folders_index.rs
@@ -94,6 +94,7 @@ async fn list_data_source_nodes(
                     parents.get(1).cloned(),
                     parents,
                     source_url,
+                    None,
                 ),
                 row_id,
                 element_row_id,
diff --git a/core/bin/migrations/20250205_backfill_document_tags_index.rs b/core/bin/migrations/20250205_backfill_document_tags_index.rs
new file mode 100644
index 000000000000..ab3c3be12b85
--- /dev/null
+++ b/core/bin/migrations/20250205_backfill_document_tags_index.rs
@@ -0,0 +1,231 @@
+use bb8::Pool;
+use bb8_postgres::PostgresConnectionManager;
+use clap::Parser;
+use dust::{
+    search_stores::search_store::ElasticsearchSearchStore,
+    stores::{postgres::PostgresStore, store::Store},
+};
+use elasticsearch::{http::request::JsonBody, indices::IndicesExistsParts, BulkParts};
+use http::StatusCode;
+use serde_json::json;
+use tokio_postgres::NoTls;
+
+#[derive(Clone, Copy, Debug, clap::ValueEnum)]
+enum NodeType {
+    Document,
+    Table,
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    #[arg(long, help = "The version of the index")]
+    index_version: u32,
+
+    #[arg(long, help = "Skip confirmation")]
+    skip_confirmation: bool,
+
+    #[arg(long, help = "The cursor to start from", default_value = "0")]
+    start_cursor: i64,
+
+    #[arg(long, help = "The batch size", default_value = "100")]
+    batch_size: usize,
+
+    #[arg(long, help = "The type of query to run", default_value = "document")]
+    query_type: NodeType,
+}
+
+/*
+ * Backfills tags for documents in Elasticsearch using the postgres tables `data_sources_documents` and `tables`
+ *
+ * Usage:
+ * cargo run --bin elasticsearch_backfill_document_tags_index -- --index-version <index_version> [--skip-confirmation] [--start-cursor <start_cursor>] [--batch-size <batch_size>]
+ *
+ */
+#[tokio::main]
+async fn main() {
+    if let Err(e) = run().await {
+        eprintln!("Error: {}", e);
+        std::process::exit(1);
+    }
+}
+
+async fn list_data_source_documents(
+    pool: &Pool<PostgresConnectionManager<NoTls>>,
+    id_cursor: i64,
+    batch_size: i64,
+    query_type: NodeType,
+) -> Result<Vec<(i64, String, Vec<String>, String, String)>, Box<dyn std::error::Error>> {
+    let c = pool.get().await?;
+
+    let q = match query_type {
+        NodeType::Document => {
+            "SELECT dsd.id,dsd.document_id, dsd.tags_array, ds.data_source_id, ds.internal_id \
+             FROM data_sources_documents dsd JOIN data_sources ds ON dsd.data_source = ds.id \
+             WHERE dsd.id > $1 ORDER BY dsd.id ASC LIMIT $2"
+        }
+        NodeType::Table => {
+            "SELECT t.id,t.table_id, t.tags_array, ds.data_source_id, ds.internal_id \
+             FROM tables t JOIN data_sources ds ON t.data_source = ds.id \
+             WHERE t.id > $1 ORDER BY t.id ASC LIMIT $2"
+        }
+    };
+
+    let stmt = c.prepare(q).await?;
+    let rows = c.query(&stmt, &[&id_cursor, &batch_size]).await?;
+
+    let nodes: Vec<(i64, String, Vec<String>, String, String)> = rows
+        .iter()
+        .map(|row| {
+            let id: i64 = row.get::<_, i64>(0);
+            let document_id: String = row.get::<_, String>(1);
+            let tags: Vec<String> = row.get::<_, Vec<String>>(2);
+            let ds_id: String = row.get::<_, String>(3);
+            let ds_internal_id: String = row.get::<_, String>(4);
+            (id, document_id, tags, ds_id, ds_internal_id)
+        })
+        .collect::<Vec<_>>();
+    Ok(nodes)
+}
+
+async fn run() -> Result<(), Box<dyn std::error::Error>> {
+    // parse args and env vars
+    let args = Args::parse();
+    let index_name = "data_sources_nodes";
+    let index_version = args.index_version;
+    let batch_size = args.batch_size;
+    let start_cursor = args.start_cursor;
+    let query_type = args.query_type;
+
+    let url = std::env::var("ELASTICSEARCH_URL").expect("ELASTICSEARCH_URL must be set");
+    let username =
+        std::env::var("ELASTICSEARCH_USERNAME").expect("ELASTICSEARCH_USERNAME must be set");
+    let password =
+        std::env::var("ELASTICSEARCH_PASSWORD").expect("ELASTICSEARCH_PASSWORD must be set");
+
+    let region = std::env::var("DUST_REGION").expect("DUST_REGION must be set");
+
+    // create ES client
+    let search_store = ElasticsearchSearchStore::new(&url, &username, &password).await?;
+
+    let index_fullname = format!("core.{}_{}", index_name, index_version);
+
+    // err if index does not exist
+    let response = search_store
+        .client
+        .indices()
+        .exists(IndicesExistsParts::Index(&[index_fullname.as_str()]))
+        .send()
+        .await?;
+
+    if response.status_code() != StatusCode::OK {
+        return Err(anyhow::anyhow!("Index does not exist").into());
+    }
+
+    if !args.skip_confirmation {
+        println!(
+            "Are you sure you want to backfill the index {} in region {}? (y/N)",
+            index_fullname, region
+        );
+        let mut input = String::new();
+        std::io::stdin().read_line(&mut input).unwrap();
+        if input.trim() != "y" {
+            return Err(anyhow::anyhow!("Aborted").into());
+        }
+    }
+
+    let db_uri = std::env::var("CORE_DATABASE_READ_REPLICA_URI")
+        .expect("CORE_DATABASE_READ_REPLICA_URI must be set");
+    let store = PostgresStore::new(&db_uri).await?;
+    // loop on all nodes in postgres using id as cursor, stopping when id is
+    // greater than the last id in data_sources_nodes at start of backfill
+    let mut next_cursor = start_cursor;
+
+    // grab last id in data_sources_nodes
+    let pool = store.raw_pool();
+    let c = pool.get().await?;
+    let last_id = c
+        .query_one("SELECT MAX(id) FROM data_sources_documents", &[])
+        .await?;
+    let last_id: i64 = last_id.get(0);
+    println!("Last id in data_sources_nodes: {}", last_id);
+    while next_cursor <= last_id {
+        println!(
+            "Processing {} nodes, starting at id {}.",
+            batch_size, next_cursor
+        );
+        let (nodes, next_id_cursor) =
+            get_node_batch(pool, next_cursor, batch_size, query_type).await?;
+
+        next_cursor = match next_id_cursor {
+            Some(cursor) => cursor,
+            None => {
+                println!(
+                    "No more nodes to process (last id: {}). \nBackfill complete.",
+                    last_id
+                );
+                break;
+            }
+        };
+
+        let nodes_values: Vec<_> = nodes
+            .into_iter()
+            .filter(|node| node.2.len() > 0)
+            .flat_map(|node| {
+                [
+                    json!({"update": {"_id": format!("{}__{}", node.4, node.1) }}),
+                    json!({"doc": {"tags": node.2}}),
+                ]
+            })
+            .collect();
+
+        let nodes_body: Vec<JsonBody<_>> = nodes_values.into_iter().map(|v| v.into()).collect();
+
+        let response = search_store
+            .client
+            .bulk(BulkParts::Index(index_fullname.as_str()))
+            .body(nodes_body)
+            .send()
+            .await?;
+        match response.status_code() {
+            StatusCode::OK => println!("Succeeded."),
+            _ => {
+                let body = response.json::<serde_json::Value>().await?;
+                eprintln!("\n{:?}", body);
+                return Err(anyhow::anyhow!("Failed to insert nodes").into());
+            }
+        }
+    }
+
+    Ok(())
+}
+
+async fn get_node_batch(
+    pool: &Pool<PostgresConnectionManager<NoTls>>,
+    next_cursor: i64,
+    batch_size: usize,
+    query_type: NodeType,
+) -> Result<
+    (Vec<(i64, String, Vec<String>, String, String)>, Option<i64>),
+    Box<dyn std::error::Error>,
+> {
+    let nodes = list_data_source_documents(
+        &pool,
+        next_cursor,
+        batch_size.try_into().unwrap(),
+        query_type,
+    )
+    .await?;
+    let last_node = nodes.last().cloned();
+    let nodes_length = nodes.len();
+    match last_node {
+        Some((last_row_id, _, _, _, _)) => Ok((
+            nodes,
+            match nodes_length == batch_size {
+                true => Some(last_row_id),
+                false => None,
+            },
+        )),
+        None => Ok((vec![], None)),
+    }
+}
diff --git a/core/src/data_sources/data_source.rs b/core/src/data_sources/data_source.rs
index cd271fecc412..a4da59484477 100644
--- a/core/src/data_sources/data_source.rs
+++ b/core/src/data_sources/data_source.rs
@@ -236,6 +236,7 @@ impl From<Document> for Node {
             document.parent_id,
             document.parents.clone(),
             document.source_url,
+            Some(document.tags),
         )
     }
 }
diff --git a/core/src/data_sources/folder.rs b/core/src/data_sources/folder.rs
index 06a75655034e..75c148dad1ce 100644
--- a/core/src/data_sources/folder.rs
+++ b/core/src/data_sources/folder.rs
@@ -89,6 +89,7 @@ impl From<Folder> for Node {
             folder.parent_id,
             folder.parents,
             folder.source_url,
+            None,
         )
     }
 }
diff --git a/core/src/data_sources/node.rs b/core/src/data_sources/node.rs
index ed1ad5304efb..1b7a57b71896 100644
--- a/core/src/data_sources/node.rs
+++ b/core/src/data_sources/node.rs
@@ -89,6 +89,7 @@ pub struct Node {
     pub parent_id: Option<String>,
     pub parents: Vec<String>,
     pub source_url: Option<String>,
+    pub tags: Option<Vec<String>>,
 }
 
 impl Node {
@@ -104,6 +105,7 @@
        parent_id: Option<String>,
         parents: Vec<String>,
         source_url: Option<String>,
+        tags: Option<Vec<String>>,
     ) -> Self {
         Node {
             data_source_id: data_source_id.to_string(),
@@ -117,6 +119,7 @@
             parent_id: parent_id.clone(),
             parents,
             source_url,
+            tags,
         }
     }
diff --git a/core/src/databases/table.rs b/core/src/databases/table.rs
index dbd6ef17e393..11581dbaebb2 100644
--- a/core/src/databases/table.rs
+++ b/core/src/databases/table.rs
@@ -324,6 +324,7 @@ impl From<Table> for Node {
             table.parents.get(1).cloned(),
             table.parents,
             table.source_url,
+            Some(table.tags),
         )
     }
 }
diff --git a/core/src/search_stores/indices/data_sources_nodes_3.mappings.json b/core/src/search_stores/indices/data_sources_nodes_3.mappings.json
index 275cdc62b949..03be45d2c031 100644
--- a/core/src/search_stores/indices/data_sources_nodes_3.mappings.json
+++ b/core/src/search_stores/indices/data_sources_nodes_3.mappings.json
@@ -45,6 +45,19 @@
     "provider_visibility": {
       "type": "keyword",
       "index": false
+    },
+    "tags": {
+      "type": "text",
+      "analyzer": "standard",
+      "fields": {
+        "edge": {
+          "type": "text",
+          "analyzer": "edge_analyzer"
+        },
+        "keyword": {
+          "type": "keyword"
+        }
+      }
     }
   }
-}
+}
\ No newline at end of file
diff --git a/core/src/search_stores/migrations/20250131_add_tags.http b/core/src/search_stores/migrations/20250131_add_tags.http
new file mode 100644
index 000000000000..219f691cef80
--- /dev/null
+++ b/core/src/search_stores/migrations/20250131_add_tags.http
@@ -0,0 +1,18 @@
+PUT core.data_sources_nodes/_mapping
+{
+  "properties": {
+    "tags": {
+      "type": "text",
+      "analyzer": "standard",
+      "fields": {
+        "edge": {
+          "type": "text",
+          "analyzer": "edge_analyzer"
+        },
+        "keyword": {
+          "type": "keyword"
+        }
+      }
+    }
+  }
+}
diff --git a/core/src/search_stores/search_store.rs b/core/src/search_stores/search_store.rs
index 2a6662ebb2d1..bf8c510f05b5 100644
--- a/core/src/search_stores/search_store.rs
+++ b/core/src/search_stores/search_store.rs
@@ -9,7 +9,8 @@ use elasticsearch::{
     DeleteByQueryParts, DeleteParts, Elasticsearch, IndexParts, SearchParts,
 };
 use elasticsearch_dsl::{
-    BoolQuery, FieldSort, Query, Script, ScriptSort, ScriptSortType, Search, Sort, SortOrder,
+    Aggregation, BoolQuery, FieldSort, Query, Script, ScriptSort, ScriptSortType, Search, Sort,
+    SortOrder,
 };
 use serde_json::json;
 use tracing::{error, info};
@@ -30,6 +31,13 @@ pub enum SortDirection {
     Desc,
 }
 
+#[derive(serde::Deserialize, Clone, Copy, Debug)]
+#[serde(rename_all = "lowercase")]
+pub enum TagsQueryType {
+    Exact,
+    Prefix,
+}
+
 #[derive(serde::Deserialize, Debug)]
 pub struct SortSpec {
     pub field: String,
@@ -75,6 +83,15 @@ pub trait SearchStore {
     async fn delete_node(&self, node: Node) -> Result<()>;
     async fn delete_data_source_nodes(&self, data_source_id: &str) -> Result<()>;
 
+    async fn search_tags(
+        &self,
+        query: Option<String>,
+        query_type: Option<TagsQueryType>,
+        data_source_views: Vec<DatasourceViewFilter>,
+        node_ids: Option<Vec<String>>,
+        limit: Option<u64>,
+    ) -> Result<Vec<(String, u64, Vec<(String, u64)>)>>;
+
     fn clone_box(&self) -> Box<dyn SearchStore + Sync + Send>;
 }
@@ -329,6 +346,134 @@ impl SearchStore for ElasticsearchSearchStore {
         }
     }
 
+    async fn search_tags(
+        &self,
+        query: Option<String>,
+        query_type: Option<TagsQueryType>,
+        data_source_views: Vec<DatasourceViewFilter>,
+        node_ids: Option<Vec<String>>,
+        limit: Option<u64>,
+    ) -> Result<Vec<(String, u64, Vec<(String, u64)>)>> {
+        let query_type = query_type.unwrap_or(TagsQueryType::Exact);
+
+        // check there is at least one data source view filter
+        // !! do not remove; without data source view filter this endpoint is
+        // dangerous as any data from any workspace can be retrieved
+        if data_source_views.is_empty() {
+            return Err(anyhow::anyhow!("No data source views provided"));
+        }
+
+        let bool_query = Query::bool().must(
+            Query::bool()
+                .should(
+                    data_source_views
+                        .into_iter()
+                        .map(|f| {
+                            let mut bool_query = Query::bool();
+
+                            bool_query =
+                                bool_query.filter(Query::term("data_source_id", f.data_source_id));
+
+                            if !f.view_filter.is_empty() {
+                                bool_query =
+                                    bool_query.filter(Query::terms("parents", f.view_filter));
+                            }
+
+                            Query::Bool(bool_query)
+                        })
+                        .collect::<Vec<_>>(),
+                )
+                .minimum_should_match(1),
+        );
+
+        let bool_query = match node_ids {
+            None => bool_query,
+            Some(node_ids) => bool_query.must(Query::terms("node_id", node_ids)),
+        };
+        let bool_query = match query.clone() {
+            None => bool_query,
+            Some(p) => match query_type {
+                TagsQueryType::Exact => bool_query.must(Query::term("tags.keyword", p)),
+                TagsQueryType::Prefix => bool_query.must(Query::match_phrase("tags.edge", p)),
+            },
+        };
+        let aggregate = Aggregation::terms("tags.keyword");
+        let aggregate = match query.clone() {
+            None => aggregate,
+            Some(p) => match query_type {
+                TagsQueryType::Exact => aggregate.include(p),
+                // Prefix will be filtered in the code, as it needs to be filtered case insensitive
+                TagsQueryType::Prefix => aggregate,
+            },
+        };
+        let aggregate =
+            aggregate.aggregate("tags_in_datasource", Aggregation::terms("data_source_id"));
+        let search = Search::new()
+            .size(0)
+            .query(bool_query)
+            .aggregate("unique_tags", aggregate.size(limit.unwrap_or(100)));
+
+        let response = self
+            .client
+            .search(SearchParts::Index(&[NODES_INDEX_NAME]))
+            .body(search)
+            .send()
+            .await?;
+
+        // Parse response and return tags
+        match response.status_code().is_success() {
+            true => {
+                let response_body = response.json::<serde_json::Value>().await?;
+                Ok(response_body["aggregations"]["unique_tags"]["buckets"]
+                    .as_array()
+                    .unwrap_or(&vec![])
+                    .iter()
+                    .filter_map(|bucket| {
+                        bucket["key"]
+                            .as_str()
+                            .map(|key| {
+                                match query_type {
+                                    // For prefix query - only include if key matches query (case insensitive)
+                                    TagsQueryType::Prefix => {
+                                        if let Some(q) = query.as_ref() {
+                                            if !key.to_lowercase().starts_with(&q.to_lowercase()) {
+                                                return None;
+                                            }
+                                        }
+                                    }
+                                    // Exact query is already filtered in the aggregation
+                                    TagsQueryType::Exact => {}
+                                }
+
+                                Some((
+                                    key.to_string(),
+                                    bucket["doc_count"].as_u64().unwrap_or(0),
+                                    bucket["tags_in_datasource"]["buckets"]
+                                        .as_array()
+                                        .unwrap_or(&vec![])
+                                        .iter()
+                                        .filter_map(|bucket| {
+                                            bucket["key"].as_str().map(|key| {
+                                                (
+                                                    key.to_string(),
+                                                    bucket["doc_count"].as_u64().unwrap_or(0),
+                                                )
+                                            })
+                                        })
+                                        .collect::<Vec<_>>(),
+                                ))
+                            })
+                            .flatten()
+                    })
+                    .collect())
+            }
+            false => {
+                let error = response.json::<serde_json::Value>().await?;
+                Err(anyhow::anyhow!("Failed to list tags: {}", error))
+            }
+        }
+    }
+
     fn clone_box(&self) -> Box<dyn SearchStore + Sync + Send> {
         Box::new(self.clone())
     }
diff --git a/core/src/stores/migrations/20250204_nodes_table_tags.sql b/core/src/stores/migrations/20250204_nodes_table_tags.sql
new file mode 100644
index 000000000000..2d56b8138e3d
--- /dev/null
+++ b/core/src/stores/migrations/20250204_nodes_table_tags.sql
@@ -0,0 +1,4 @@
+ALTER TABLE
+    data_sources_nodes
+ADD COLUMN IF NOT EXISTS
+    tags_array text[] NOT NULL DEFAULT array[]::text[];
diff --git a/core/src/stores/postgres.rs b/core/src/stores/postgres.rs
index e66520b895de..4f1b99ffb25b 100644
--- a/core/src/stores/postgres.rs
+++ b/core/src/stores/postgres.rs
@@ -14,6 +14,7 @@ use
tokio_postgres::{NoTls, Transaction}; use crate::data_sources::data_source::DocumentStatus; use crate::data_sources::node::{Node, NodeType, ProviderVisibility}; +use crate::search_filter::Filterable; use crate::{ blocks::block::BlockType, cached_request::CachedRequest, @@ -55,6 +56,7 @@ pub struct UpsertNode<'a> { pub provider_visibility: &'a Option, pub parents: &'a Vec, pub source_url: &'a Option, + pub tags: &'a Vec, } impl PostgresStore { @@ -167,9 +169,9 @@ impl PostgresStore { let stmt = tx .prepare( "INSERT INTO data_sources_nodes \ - (id, data_source, created, node_id, timestamp, title, mime_type, provider_visibility, parents, source_url, \ + (id, data_source, created, node_id, timestamp, title, mime_type, provider_visibility, parents, source_url, tags_array, \ document, \"table\", folder) \ - VALUES (DEFAULT, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) \ + VALUES (DEFAULT, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) \ ON CONFLICT (data_source, node_id) DO UPDATE \ SET timestamp = EXCLUDED.timestamp, title = EXCLUDED.title, \ mime_type = EXCLUDED.mime_type, parents = EXCLUDED.parents, \ @@ -193,6 +195,7 @@ impl PostgresStore { &upsert_params.provider_visibility, &upsert_params.parents, &upsert_params.source_url, + &upsert_params.tags, &document_row_id, &table_row_id, &folder_row_id, @@ -1331,7 +1334,7 @@ impl Store for PostgresStore { 1 => (r[0].get(0), r[0].get(1)), _ => unreachable!(), }; - + // TODO(Thomas-020425): Read tags from nodes table. let r = match version_hash { None => { c.query( @@ -1929,6 +1932,7 @@ impl Store for PostgresStore { provider_visibility: &document.provider_visibility, parents: &document.parents, source_url: &document.source_url, + tags: &document.tags, }, data_source_row_id, document_row_id, @@ -2708,6 +2712,7 @@ impl Store for PostgresStore { provider_visibility: table.provider_visibility(), parents: table.parents(), source_url: table.source_url(), + tags: &table.get_tags(), }, data_source_row_id, table_row_id, @@ -3225,6 +3230,7 @@ impl Store for PostgresStore { mime_type: folder.mime_type(), parents: folder.parents(), source_url: folder.source_url(), + tags: &vec![], }, data_source_row_id, folder_row_id, @@ -3469,7 +3475,7 @@ impl Store for PostgresStore { let stmt = c .prepare( - "SELECT timestamp, title, mime_type, provider_visibility, parents, node_id, document, \"table\", folder, source_url \ + "SELECT timestamp, title, mime_type, provider_visibility, parents, node_id, document, \"table\", folder, source_url, tags_array \ FROM data_sources_nodes \ WHERE data_source = $1 AND node_id = $2 LIMIT 1", ) @@ -3496,6 +3502,7 @@ impl Store for PostgresStore { _ => unreachable!(), }; let source_url: Option = row[0].get::<_, Option>(9); + let tags: Option> = row[0].get::<_, Option>>(10); Ok(Some(( Node::new( &data_source_id, @@ -3509,6 +3516,7 @@ impl Store for PostgresStore { parents.get(1).cloned(), parents, source_url, + tags, ), row_id, ))) @@ -3527,7 +3535,7 @@ impl Store for PostgresStore { let stmt = c .prepare( - "SELECT dsn.timestamp, dsn.title, dsn.mime_type, dsn.provider_visibility, dsn.parents, dsn.node_id, dsn.document, dsn.\"table\", dsn.folder, ds.data_source_id, ds.internal_id, dsn.source_url, dsn.id \ + "SELECT dsn.timestamp, dsn.title, dsn.mime_type, dsn.provider_visibility, dsn.parents, dsn.node_id, dsn.document, dsn.\"table\", dsn.folder, ds.data_source_id, ds.internal_id, dsn.source_url, dsn.tags_array, dsn.id \ FROM data_sources_nodes dsn JOIN data_sources ds ON dsn.data_source = ds.id \ WHERE dsn.id > $1 ORDER BY 
dsn.id ASC LIMIT $2", ) @@ -3557,7 +3565,8 @@ impl Store for PostgresStore { _ => unreachable!(), }; let source_url: Option = row.get::<_, Option>(11); - let row_id = row.get::<_, i64>(12); + let tags: Option> = row.get::<_, Option>>(12); + let row_id = row.get::<_, i64>(13); ( Node::new( &data_source_id, @@ -3571,6 +3580,7 @@ impl Store for PostgresStore { parents.get(1).cloned(), parents, source_url, + tags, ), row_id, element_row_id, diff --git a/front/migrations/20250204_backfill_tags.ts b/front/migrations/20250204_backfill_tags.ts new file mode 100644 index 000000000000..3ed162f74362 --- /dev/null +++ b/front/migrations/20250204_backfill_tags.ts @@ -0,0 +1,197 @@ +import { removeNulls } from "@dust-tt/types"; +import type { Sequelize } from "sequelize"; +import { QueryTypes } from "sequelize"; + +import { getCorePrimaryDbConnection } from "@app/lib/production_checks/utils"; +import { DataSourceModel } from "@app/lib/resources/storage/models/data_source"; +import type Logger from "@app/logger/logger"; +import { makeScript } from "@app/scripts/helpers"; + +const BATCH_SIZE = 128; + +async function backfillDataSource( + frontDataSource: DataSourceModel, + coreSequelize: Sequelize, + execute: boolean, + logger: typeof Logger +) { + logger.info("Processing data source"); + + // get datasource id from core + const rows: { id: number }[] = await coreSequelize.query( + `SELECT id FROM data_sources WHERE data_source_id = :dataSourceId;`, + { + replacements: { dataSourceId: frontDataSource.dustAPIDataSourceId }, + type: QueryTypes.SELECT, + } + ); + + if (rows.length === 0) { + logger.error(`Data source ${frontDataSource.id} not found in core`); + return; + } + + const dataSourceId = rows[0].id; + + await backfillDocuments( + dataSourceId, + coreSequelize, + execute, + logger.child({ type: "folders" }) + ); + + await backfillSpreadsheets( + dataSourceId, + coreSequelize, + execute, + logger.child({ type: "spreadsheets" }) + ); +} + +async function backfillSpreadsheets( + dataSourceId: number, + coreSequelize: Sequelize, + execute: boolean, + logger: typeof Logger +) { + logger.info("Processing spreadsheets"); + + // processing the spreadsheets chunk by chunk + let lastId = 0; + let rows: { id: number; tags_array: string[] }[] = []; + + do { + // querying connectors for the next batch of spreadsheets + + rows = await coreSequelize.query( + `SELECT id, "tags_array" + FROM "tables" + WHERE id > :lastId + AND "data_source" = :data_source + AND "tags_array" IS NOT NULL + ORDER BY id + LIMIT :batchSize;`, + { + replacements: { + batchSize: BATCH_SIZE, + lastId, + data_source: dataSourceId, + }, + type: QueryTypes.SELECT, + } + ); + + if (rows.length === 0) { + break; + } + // reconstructing the URLs and node IDs + const tableIds = rows.map((row) => row.id); + const tags = rows.map((row) => `{"${row.tags_array.join('","')}"}`); + + if (execute) { + // updating on core on the nodeIds + await coreSequelize.query( + `UPDATE data_sources_nodes + SET tags_array = CAST(unnest_tags.tags_array AS text[]) + FROM (SELECT unnest(ARRAY [:tableIds]::bigint[]) as table_id, + unnest(ARRAY [:tags]::text[][]) as tags_array) unnest_tags + WHERE data_sources_nodes.data_source = :dataSourceId AND data_sources_nodes.table = unnest_tags.table_id;`, + { replacements: { tags, tableIds, dataSourceId } } + ); + logger.info( + `Updated ${rows.length} spreadsheets from id ${rows[0].id} to id ${rows[rows.length - 1].id}.` + ); + } else { + logger.info( + `Would update ${rows.length} spreadsheets from id ${rows[0].id} to id 
${rows[rows.length - 1].id}.` + ); + } + + lastId = rows[rows.length - 1].id; + } while (rows.length === BATCH_SIZE); +} + +async function backfillDocuments( + dataSourceId: number, + coreSequelize: Sequelize, + execute: boolean, + logger: typeof Logger +) { + logger.info("Processing folders"); + + // processing the folders chunk by chunk + let lastId = 0; + let rows: { + id: number; + tags_array: string[]; + }[] = []; + + do { + rows = await coreSequelize.query( + `SELECT id, "tags_array" + FROM data_sources_documents + WHERE id > :lastId + AND "data_source" = :data_source + AND "tags_array" IS NOT NULL + AND "status" = 'latest' + ORDER BY id + LIMIT :batchSize;`, + { + replacements: { + batchSize: BATCH_SIZE, + lastId, + data_source: dataSourceId, + }, + type: QueryTypes.SELECT, + } + ); + + if (rows.length === 0) { + break; + } + + // reconstructing the URLs and node IDs + const documentIds = rows.map((row) => row.id); + const tags = rows.map((row) => `{"${row.tags_array.join('","')}"}`); + + if (execute) { + // updating on core on the nodeIds + await coreSequelize.query( + `UPDATE data_sources_nodes + SET tags_array = CAST(unnest_tags.tags_array AS text[]) + FROM (SELECT unnest(ARRAY[:documentIds]::bigint[]) as document_id, + unnest(ARRAY[:tags]::text[]) as tags_array) unnest_tags + WHERE data_sources_nodes.data_source = :dataSourceId AND data_sources_nodes.document = unnest_tags.document_id;`, + { replacements: { tags, documentIds, dataSourceId } } + ); + logger.info( + `Updated ${rows.length} documents from id ${rows[0].id} to id ${rows[rows.length - 1].id}.` + ); + } else { + logger.info( + `Would update ${rows.length} documents from id ${rows[0].id} to id ${rows[rows.length - 1].id}.` + ); + } + + lastId = rows[rows.length - 1].id; + } while (rows.length === BATCH_SIZE); +} + +makeScript({}, async ({ execute }, logger) => { + const coreSequelize = getCorePrimaryDbConnection(); + const frontDataSources = await DataSourceModel.findAll(); + logger.info(`Found ${frontDataSources.length} Google Drive data sources`); + + for (const frontDataSource of frontDataSources) { + await backfillDataSource( + frontDataSource, + coreSequelize, + execute, + logger.child({ + dataSourceId: frontDataSource.id, + connectorId: frontDataSource.connectorId, + name: frontDataSource.name, + }) + ); + } +});
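
For reference, a minimal sketch of how the new endpoint could be exercised once this lands, written in the same style as the 20250131_add_tags.http migration above. The route and the payload fields (query, query_type, data_source_views, node_ids, limit) come from tags_search and TagsSearchPayload in core_api.rs; the data_source_id and view_filter values are illustrative placeholders, node_ids is simply omitted here, and the JSON field names of DatasourceViewFilter are assumed to match its Rust field names.

POST /tags/search
{
  "query": "proj",
  "query_type": "prefix",
  "data_source_views": [
    {
      "data_source_id": "example_data_source",
      "view_filter": ["example_parent_node_id"]
    }
  ],
  "limit": 10
}

With query_type "exact" the query must equal a full tags.keyword value, while "prefix" is matched against the edge-ngram subfield and re-filtered case-insensitively in search_tags. The handler answers with a response of the form {"tags": [{"tag": ..., "match_count": ..., "data_sources": [...]}]}, where data_sources lists the data_source_ids contributing matches for that tag.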