Commit b6f2b0a

Add run_sorcha: uses ray to parallelize calls to sorcha
moeyensj committed Sep 6, 2024
1 parent 08751fe commit b6f2b0a
Showing 3 changed files with 217 additions and 2 deletions.
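
For orientation, here is a minimal, hypothetical sketch of how the new run_sorcha entry point added in this commit might be called. The small_bodies, pointings, and observatory objects and the output path are placeholders that would be constructed elsewhere with the package's SmallBodies, Pointings, and Observatory classes:

# Hypothetical usage sketch of run_sorcha; the input objects and the
# output directory below are placeholders, not part of this commit.
from adam_test_data import run_sorcha

sorcha_outputs, sorcha_stats = run_sorcha(
    "sorcha_output/",   # directory where per-chunk Sorcha outputs are written
    small_bodies,       # SmallBodies population to simulate
    pointings,          # Pointings (survey visits) to observe against
    observatory,        # Observatory configuration
    tag="sorcha",
    chunk_size=1000,    # orbits handled per worker chunk
    max_processes=4,    # parallel ray workers; None uses all available CPUs
)
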
2 changes: 1 addition & 1 deletion src/adam_test_data/__init__.py
@@ -1,2 +1,2 @@
# ruff: noqa: F401
-from .main import write_sorcha_inputs
+from .main import write_sorcha_inputs, run_sorcha
215 changes: 215 additions & 0 deletions src/adam_test_data/main.py
@@ -1,9 +1,15 @@
import multiprocessing as mp
import os
import sqlite3 as sql
import subprocess
from typing import Literal, Optional, Union

import pyarrow as pa
import pyarrow.compute as pc
import quivr as qv
import ray
from adam_core.propagator.utils import _iterate_chunk_indices
from adam_core.ray_cluster import initialize_use_ray

from .observatory import Observatory, observatory_to_sorcha_config
from .pointings import Pointings
@@ -308,3 +314,212 @@ def sorcha(
sorcha_stats = SorchaOutputStats.from_csv(f"{output_dir}/{stats_file}.csv")

return sorcha_outputs, sorcha_stats


def sorcha_worker(
orbit_ids: pa.Array,
orbit_ids_indices: tuple[int, int],
output_dir: str,
small_bodies: SmallBodies,
pointings: Pointings,
observatory: Observatory,
time_range: Optional[list[float]] = None,
tag: str = "sorcha",
overwrite: bool = True,
randomization: bool = True,
output_columns: Literal["basic", "all"] = "all",
) -> tuple[Union[SorchaOutputBasic, SorchaOutputAll], SorchaOutputStats]:
"""
Run sorcha on a subset of the input small bodies.

Parameters
----------
orbit_ids : pa.Array
The orbit IDs of the small bodies.
orbit_ids_indices : tuple[int, int]
The indices of the orbit IDs to process.
output_dir : str
The directory to write the Sorcha output to.
small_bodies : SmallBodies
The small body population to run Sorcha on.
pointings : Pointings
The pointings to run Sorcha on.
observatory : Observatory
The observatory to run Sorcha on.
time_range : list[float], optional
The time range to filter the pointings by, by default None.
tag : str, optional
The tag to use for the output files, by default "sorcha".
overwrite : bool, optional
Whether to overwrite existing files, by default True.
randomization : bool, optional
Randomize the photometry and astrometry using the calculated uncertainties, by default True.
output_columns : Literal["basic", "all"], optional
The columns to output in the Sorcha output, by default "all".

Returns
-------
tuple[Union[SorchaOutputBasic, SorchaOutputAll], SorchaOutputStats]
Sorcha output observations (in basic or all formats) and statistics
per object and filter.
"""
orbit_ids_chunk = orbit_ids[orbit_ids_indices[0] : orbit_ids_indices[1]]

small_bodies_chunk = small_bodies.apply_mask(
pc.is_in(small_bodies.orbits.orbit_id, orbit_ids_chunk)
)

# Create a subdirectory for this chunk
output_dir_chunk = os.path.join(
output_dir, f"chunk_{orbit_ids_indices[0]:07d}_{orbit_ids_indices[1]:07d}"
)

return sorcha(
output_dir_chunk,
small_bodies_chunk,
pointings,
observatory,
time_range=time_range,
tag=tag,
overwrite=overwrite,
randomization=randomization,
output_columns=output_columns,
)


sorcha_worker_remote = ray.remote(sorcha_worker)
# .options() returns a new remote function, so reassign it for num_cpus=1 to take effect
sorcha_worker_remote = sorcha_worker_remote.options(num_cpus=1)


def run_sorcha(
output_dir: str,
small_bodies: SmallBodies,
pointings: Pointings,
observatory: Observatory,
time_range: Optional[list[float]] = None,
tag: str = "sorcha",
overwrite: bool = True,
randomization: bool = True,
output_columns: Literal["basic", "all"] = "all",
chunk_size: int = 1000,
max_processes: Optional[int] = 1,
) -> tuple[Union[SorchaOutputBasic, SorchaOutputAll], SorchaOutputStats]:
"""
Run sorcha on the given small bodies, pointings, and observatory.

Parameters
----------
output_dir : str
The directory to write the Sorcha output to.
small_bodies : SmallBodies
The small body population to run Sorcha on.
pointings : Pointings
The pointings to run Sorcha on.
observatory : Observatory
The observatory to run Sorcha on.
time_range : list[float], optional
The time range to filter the pointings by, by default None.
tag : str, optional
The tag to use for the output files, by default "sorcha".
overwrite : bool, optional
Whether to overwrite existing files, by default True.
randomization : bool, optional
Randomize the photometry and astrometry using the calculated uncertainties, by default True.
output_columns : Literal["basic", "all"], optional
The columns to output in the Sorcha output, by default "all".
chunk_size : int, optional
The number of small bodies to process in each chunk, by default 1000.
max_processes : Optional[int], optional
The maximum number of processes to use, by default 1.

Returns
-------
tuple[Union[SorchaOutputBasic, SorchaOutputAll], SorchaOutputStats]
Sorcha output observations (in basic or all formats) and statistics
per object and filter.
"""
if max_processes is None:
max_processes = mp.cpu_count()

orbit_ids = small_bodies.orbits.orbit_id

if output_columns == "basic":
sorcha_outputs = SorchaOutputBasic.empty()
else:
sorcha_outputs = SorchaOutputAll.empty()
sorcha_stats = SorchaOutputStats.empty()

use_ray = initialize_use_ray(num_cpus=max_processes)
if use_ray:

orbit_ids_ref = ray.put(orbit_ids)
small_bodies_ref = ray.put(small_bodies)
pointings_ref = ray.put(pointings)
observatory_ref = ray.put(observatory)

futures = []
for orbit_ids_indices in _iterate_chunk_indices(orbit_ids, chunk_size):
futures.append(
sorcha_worker_remote.remote(
orbit_ids_ref,
orbit_ids_indices,
output_dir,
small_bodies_ref,
pointings_ref,
observatory=observatory_ref,
time_range=time_range,
tag=tag,
overwrite=overwrite,
randomization=randomization,
output_columns=output_columns,
)
)

if len(futures) >= max_processes * 1.5:
finished, futures = ray.wait(futures, num_returns=1)

sorcha_outputs_chunk, sorcha_stats_chunk = ray.get(finished[0])
sorcha_outputs = qv.concatenate([sorcha_outputs, sorcha_outputs_chunk])
if sorcha_outputs.fragmented():
sorcha_outputs = qv.defragment(sorcha_outputs)

sorcha_stats = qv.concatenate([sorcha_stats, sorcha_stats_chunk])
if sorcha_stats.fragmented():
sorcha_stats = qv.defragment(sorcha_stats)

while futures:
finished, futures = ray.wait(futures, num_returns=1)
sorcha_outputs_chunk, sorcha_stats_chunk = ray.get(finished[0])

sorcha_outputs = qv.concatenate([sorcha_outputs, sorcha_outputs_chunk])
if sorcha_outputs.fragmented():
sorcha_outputs = qv.defragment(sorcha_outputs)

sorcha_stats = qv.concatenate([sorcha_stats, sorcha_stats_chunk])
if sorcha_stats.fragmented():
sorcha_stats = qv.defragment(sorcha_stats)

else:

for orbit_ids_indices in _iterate_chunk_indices(orbit_ids, chunk_size):
sorcha_outputs_chunk, sorcha_stats_chunk = sorcha_worker(
orbit_ids,
orbit_ids_indices,
output_dir,
small_bodies,
pointings,
observatory,
time_range=time_range,
tag=tag,
overwrite=overwrite,
randomization=randomization,
output_columns=output_columns,
)

sorcha_outputs = qv.concatenate([sorcha_outputs, sorcha_outputs_chunk])
if sorcha_outputs.fragmented():
sorcha_outputs = qv.defragment(sorcha_outputs)

sorcha_stats = qv.concatenate([sorcha_stats, sorcha_stats_chunk])
if sorcha_stats.fragmented():
sorcha_stats = qv.defragment(sorcha_stats)

return sorcha_outputs, sorcha_stats
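
The run_sorcha loop above keeps back-pressure on ray by holding at most roughly 1.5x max_processes chunk futures in flight: large shared inputs are placed in the object store once with ray.put, and ray.wait drains one finished future before more work is queued, with results concatenated (and defragmented) as they arrive. A stripped-down sketch of the same pattern follows; the process_chunk worker and its payload are stand-ins, not part of this commit:

# Minimal sketch of the bounded-in-flight pattern used by run_sorcha.
# The worker and data here are hypothetical stand-ins for the Sorcha call.
import ray

ray.init(ignore_reinit_error=True)

@ray.remote(num_cpus=1)
def process_chunk(shared_data, start, end):
    # Placeholder for per-chunk work (e.g. running Sorcha on a slice of orbits).
    return sum(shared_data[start:end])

data_ref = ray.put(list(range(10_000)))  # shared once via the object store
max_processes = 4
chunk_size = 1000

results = []
futures = []
for start in range(0, 10_000, chunk_size):
    futures.append(process_chunk.remote(data_ref, start, start + chunk_size))
    # Once ~1.5x max_processes futures are in flight, drain one before submitting more.
    if len(futures) >= max_processes * 1.5:
        finished, futures = ray.wait(futures, num_returns=1)
        results.append(ray.get(finished[0]))

# Collect whatever is still outstanding.
while futures:
    finished, futures = ray.wait(futures, num_returns=1)
    results.append(ray.get(finished[0]))
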
2 changes: 1 addition & 1 deletion src/adam_test_data/tests/test_main.py
@@ -9,7 +9,7 @@
from adam_core.orbits import Orbits
from adam_core.time import Timestamp

-from ..main import sorcha, write_sorcha_inputs
+from ..main import run_sorcha, sorcha, write_sorcha_inputs
from ..observatory import FieldOfView, Observatory, Simulation
from ..pointings import Pointings
from ..populations import PhotometricProperties, SmallBodies