Skip to content

Commit

Permalink
feat: Refactor concurrency handling in CassandraOnlineStore using Sem…
Browse files Browse the repository at this point in the history
…aphore
  • Loading branch information
Bhargav Dodla committed Jan 29, 2025
1 parent a208568 commit 2ad5ce7
Showing 1 changed file with 14 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
"""

import logging
import time
from datetime import datetime
from functools import partial
from threading import Condition, Lock, Semaphore
from threading import Semaphore
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple

from cassandra.auth import PlainTextAuthProvider
Expand Down Expand Up @@ -552,9 +553,7 @@ def online_write_batch(
print(
f"{get_current_time_in_ms()} Started writing data of size {len(data)} to CassandraOnlineStore"
)
active_tasks = 0
lock = Lock()
condition = Condition(lock)
write_concurrency = config.online_store.write_concurrency

# def clusterStatus(cluster, client_id):
# if cluster is None:
Expand Down Expand Up @@ -583,16 +582,11 @@ def online_write_batch(
# )

def on_success(result, semaphore):
global active_tasks
with condition:
active_tasks -= 1
if active_tasks == 0:
print(f"{get_current_time_in_ms()} Notifying all tasks to complete")
condition.notify_all()
semaphore.release()

def on_failure(exc):
logger.error(f"Error writing a batch: {exc}")
def on_failure(exc, semaphore):
semaphore.release()
logger.exception(f"Error writing a batch: {exc}")
print(f"Error writing a batch: {exc}")
raise Exception("Error writing a batch") from exc

Expand All @@ -605,7 +599,7 @@ def on_failure(exc):
config, "insert4", fqtable=fqtable, session=session
)

semaphore = Semaphore(config.online_store.write_concurrency)
semaphore = Semaphore(write_concurrency)
for entity_key, values, timestamp, created_ts in data:
batch = BatchStatement(batch_type=BatchType.UNLOGGED)
entity_key_bin = serialize_entity_key(
Expand All @@ -620,23 +614,21 @@ def on_failure(exc):
timestamp,
)
batch.add(insert_cql, params)
with condition:
active_tasks += 1
semaphore.acquire()
future = session.execute_async(batch)
future.add_callbacks(partial(on_success, semaphore=semaphore), on_failure)
future.add_callbacks(
partial(on_success, semaphore=semaphore),
partial(on_failure, semaphore=semaphore),
)

# this happens N-1 times, will be corrected outside:
if progress:
progress(1)

# Wait for all tasks to complete
with condition:
while active_tasks > 0:
print(
f"{get_current_time_in_ms()} Waiting for active tasks to complete"
)
condition.wait()
while semaphore._value < write_concurrency:
print(f"{get_current_time_in_ms()} Waiting for active tasks to complete")
time.sleep(0.01)

# with cluster.connect(keyspace) as session:
# for connection in cluster.get_connection_holders():
Expand Down

0 comments on commit 2ad5ce7

Please sign in to comment.