From 2e3ee769e78a16a3dae3b0656b1570e3f7a6da77 Mon Sep 17 00:00:00 2001
From: Pang Wu <pang.wu@samsara.com>
Date: Sun, 8 Dec 2024 01:02:31 -0800
Subject: [PATCH] Support Spark 3.5.2 and 3.5.3

Register shim descriptors for Spark 3.5.2 and 3.5.3, and remove the
RayMLDataset code path from python/raydp/spark, which depended on
ray.util.data (removed in Ray 2.0). Relax the numpy, pyarrow and ray
upper bounds in setup.py, allow pyspark up to 3.5.3, rename
test_spark_on_fractional_cpu to test_spark_executor_on_fractional_cpu,
and ignore the .metals/ and .bloop/ tooling directories.
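
A minimal smoke test for the new shims (a sketch only; it assumes a
local Ray installation with pyspark 3.5.2 or 3.5.3, and the app name
and executor sizing below are arbitrary):

    import ray
    import raydp

    ray.init()
    # the spark350 shim should now also match Spark 3.5.2/3.5.3
    spark = raydp.init_spark(app_name="spark-353-shim-check",
                             num_executors=1,
                             executor_cores=1,
                             executor_memory="1GB")
    print(spark.version)  # expected: 3.5.2 or 3.5.3
    raydp.stop_spark()
    ray.shutdown()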
---
 .gitignore                                    |   3 +
 .../intel/raydp/shims/SparkShimProvider.scala |   5 +-
 python/raydp/spark/__init__.py                |   7 -
 python/raydp/spark/dataset.py                 | 355 ------------------
 python/raydp/tests/test_spark_cluster.py      |   2 +-
 python/setup.py                               |   8 +-
 6 files changed, 12 insertions(+), 368 deletions(-)

diff --git a/.gitignore b/.gitignore
index 4642a059..571df90a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -23,3 +23,6 @@ examples/.ipynb_checkpoints/
 *.parquet
 *.crc
 _SUCCESS
+
+.metals/
+.bloop/
diff --git a/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala b/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala
index 0a2ba58a..f8caeb9a 100644
--- a/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala
+++ b/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala
@@ -22,7 +22,10 @@ import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor}
 object SparkShimProvider {
   val SPARK350_DESCRIPTOR = SparkShimDescriptor(3, 5, 0)
   val SPARK351_DESCRIPTOR = SparkShimDescriptor(3, 5, 1)
-  val DESCRIPTOR_STRINGS = Seq(s"$SPARK350_DESCRIPTOR", s"$SPARK351_DESCRIPTOR")
+  val SPARK352_DESCRIPTOR = SparkShimDescriptor(3, 5, 2)
+  val SPARK353_DESCRIPTOR = SparkShimDescriptor(3, 5, 3)
+  val DESCRIPTOR_STRINGS = Seq(s"$SPARK350_DESCRIPTOR", s"$SPARK351_DESCRIPTOR", s"$SPARK352_DESCRIPTOR",
+      s"$SPARK353_DESCRIPTOR")
   val DESCRIPTOR = SPARK350_DESCRIPTOR
 }
 
diff --git a/python/raydp/spark/__init__.py b/python/raydp/spark/__init__.py
index 48a8a2ee..6e544c08 100644
--- a/python/raydp/spark/__init__.py
+++ b/python/raydp/spark/__init__.py
@@ -33,10 +33,3 @@
   "from_spark_recoverable"
 ]
 
-try:
-    import ray.util.data
-    from .dataset import RayMLDataset
-    __all__.append("RayMLDataset")
-except ImportError:
-    # Ray MLDataset is removed in Ray 2.0
-    pass
diff --git a/python/raydp/spark/dataset.py b/python/raydp/spark/dataset.py
index 61ca945b..47365430 100644
--- a/python/raydp/spark/dataset.py
+++ b/python/raydp/spark/dataset.py
@@ -18,10 +18,8 @@
 from typing import Callable, Dict, List, NoReturn, Optional, Iterable, Union
 from dataclasses import dataclass
 
-import numpy as np
 import pandas as pd
 import pyarrow as pa
-import pyarrow.parquet as pq
 import pyspark.sql as sql
 from pyspark.sql import SparkSession
 from pyspark.sql.dataframe import DataFrame
@@ -31,18 +29,7 @@
 import ray
 from ray.data import Dataset, from_arrow_refs
 from ray.types import ObjectRef
-import ray.util.iter as parallel_it
 from ray._private.client_mode_hook import client_mode_wrap
-try:
-    import ray.util.data as ml_dataset
-    from ray.util.data import MLDataset
-    from ray.util.data.interface import _SourceShard
-    HAS_MLDATASET = True
-except ImportError:
-    # Ray MLDataset is removed in Ray 2.0
-    HAS_MLDATASET = False
-from raydp.spark.parallel_iterator_worker import ParallelIteratorWorkerWithLen
-from raydp.utils import divide_blocks
 from raydp.spark.ray_cluster_master import RAYDP_SPARK_MASTER_SUFFIX
 
 
@@ -98,40 +85,6 @@ def with_row_ids(self, new_row_ids) -> "RayObjectPiece":
         return RayObjectPiece(self.obj_id, new_row_ids, num_rows)
 
 
-class ParquetPiece(RecordPiece):
-    def __init__(self,
-                 piece: pq.ParquetDatasetPiece,
-                 columns: List[str],
-                 partitions: pq.ParquetPartitions,
-                 row_ids: Optional[List[int]],
-                 num_rows: int):
-        super().__init__(row_ids, num_rows)
-        self.piece = piece
-        self.columns = columns
-        self.partitions = partitions
-
-    def read(self, shuffle: bool) -> pd.DataFrame:
-        pdf = self.piece.read(columns=self.columns,
-                              use_threads=False,
-                              partitions=self.partitions).to_pandas()
-        if self.row_ids:
-            pdf = pdf.loc[self.row_ids]
-
-        if shuffle:
-            pdf = pdf.sample(frac=1.0)
-        return pdf
-
-    def with_row_ids(self, new_row_ids) -> "ParquetPiece":
-        """
-        chang the num_rows to the length of new_row_ids. Keep the original size if
-        the new_row_ids is None.
-        """
-        if new_row_ids:
-            num_rows = len(new_row_ids)
-        else:
-            num_rows = self.num_rows
-        return ParquetPiece(self.piece, self.columns, self.partitions, new_row_ids, num_rows)
-
 
 @dataclass
 class PartitionObjectsOwner:
@@ -303,311 +256,3 @@ def ray_dataset_to_spark_dataframe(spark: sql.SparkSession,
     else:
         raise RuntimeError("ray.to_spark only supports arrow type blocks")
 
-if HAS_MLDATASET:
-    class RecordBatch(_SourceShard):
-        def __init__(self,
-                    shard_id: int,
-                    prefix: str,
-                    record_pieces: List[RecordPiece],
-                    shuffle: bool,
-                    shuffle_seed: int):
-            self._shard_id = shard_id
-            self._prefix = prefix
-            self.record_pieces = record_pieces
-            self.shuffle = shuffle
-            self.shuffle_seed = shuffle_seed
-
-        def prefix(self) -> str:
-            return self._prefix
-
-        @property
-        def shard_id(self) -> int:
-            return self._shard_id
-
-        def __iter__(self) -> Iterable[pd.DataFrame]:
-            if self.shuffle:
-                np.random.seed(self.shuffle_seed)
-                np.random.shuffle(self.record_pieces)
-
-            for piece in self.record_pieces:
-                yield piece.read(self.shuffle)
-
-        def __len__(self):
-            return sum([len(piece) for piece in self.record_pieces])
-
-
-    class RayRecordBatch(RecordBatch):
-        def __init__(self,
-                    shard_id: int,
-                    prefix: str,
-                    record_pieces: List[RecordPiece],
-                    shuffle: bool,
-                    shuffle_seed: int):
-            super().__init__(shard_id, prefix, record_pieces, shuffle, shuffle_seed)
-            self.resolved: bool = False
-
-        def resolve(self, timeout: Optional[float] = None) -> NoReturn:
-            """
-            This is just fetch object from remote object store to local and without deserialization.
-            :param timeout: The maximum amount of time in seconds to wait before returning.
-            """
-            if self.resolved:
-                return
-
-            worker = ray.worker.global_worker
-            worker.check_connected()
-            timeout_ms = int(timeout * 1000) if timeout else -1
-            object_ids = [record.obj_id for record in self.record_pieces]
-            worker.core_worker.get_objects(object_ids, worker.current_task_id, timeout_ms)
-            self.resolved = True
-
-
-    def _create_ml_dataset(name: str,
-                        record_pieces: List[RecordPiece],
-                        record_sizes: List[int],
-                        num_shards: int,
-                        shuffle: bool,
-                        shuffle_seed: int,
-                        RecordBatchCls,
-                        node_hints: List[str] = None) -> MLDataset:
-        if node_hints is not None:
-            assert num_shards % len(node_hints) == 0,\
-                (f"num_shards: {num_shards} should be a multiple"
-                 f" of length of node_hints: {node_hints}")
-        if shuffle_seed:
-            np.random.seed(shuffle_seed)
-        else:
-            np.random.seed(0)
-
-        # split the piece into num_shards partitions
-        divided_blocks = divide_blocks(blocks=record_sizes,
-                                    world_size=num_shards,
-                                    shuffle=shuffle,
-                                    shuffle_seed=shuffle_seed)
-
-        record_batches = []
-
-        for rank, blocks in divided_blocks.items():
-            pieces = []
-            for index, num_samples in blocks:
-                record_size = record_sizes[index]
-                piece = record_pieces[index]
-                if num_samples != record_size:
-                    assert num_samples < record_size
-                    new_row_ids = np.random.choice(
-                        record_size, size=num_samples).tolist()
-                    piece = piece.with_row_ids(new_row_ids)
-                pieces.append(piece)
-
-            if shuffle:
-                np.random.shuffle(pieces)
-            record_batches.append(RecordBatchCls(shard_id=rank,
-                                                prefix=name,
-                                                record_pieces=pieces,
-                                                shuffle=shuffle,
-                                                shuffle_seed=shuffle_seed))
-
-        worker_cls = ray.remote(ParallelIteratorWorkerWithLen)
-        if node_hints is not None:
-            actors = []
-            multiplier = num_shards // len(node_hints)
-            resource_keys = [f"node:{node_hints[i // multiplier]}" for i in range(num_shards)]
-            for g, resource_key in zip(record_batches, resource_keys):
-                actor = worker_cls.options(resources={resource_key: 0.01}).remote(g, False, len(g))
-                actors.append(actor)
-        else:
-            worker_cls = ray.remote(ParallelIteratorWorkerWithLen)
-            actors = [worker_cls.remote(g, False, len(g)) for g in record_batches]
-
-        it = parallel_it.from_actors(actors, name)
-        ds = ml_dataset.from_parallel_iter(
-            it, need_convert=False, batch_size=0, repeated=False)
-        return ds
-
-
-    class RayMLDataset:
-        @staticmethod
-        def from_spark(df: sql.DataFrame,
-                    num_shards: int,
-                    shuffle: bool = True,
-                    shuffle_seed: int = None,
-                    fs_directory: Optional[str] = None,
-                    compression: Optional[str] = None,
-                    node_hints: List[str] = None) -> MLDataset:
-            """ Create a MLDataset from Spark DataFrame
-
-            This method will create a MLDataset from Spark DataFrame.
-
-            :param df: the pyspark.sql.DataFrame
-            :param num_shards: the number of shards will be created for the MLDataset
-            :param shuffle: whether need to shuffle the blocks when create the MLDataset
-            :param shuffle_seed: the shuffle seed, default is 0
-            :param fs_directory: an optional distributed file system directory for cache the
-                DataFrame. We will write the DataFrame to the given directory with parquet
-                format if this is provided. Otherwise, we will write the DataFrame to ray
-                object store.
-            :param compression: the optional compression for write the DataFrame as parquet
-                file. This is only useful when the fs_directory set.
-            :param node_hints: the node hints to create MLDataset actors
-            :return: a MLDataset
-            """
-            df = df.repartition(num_shards)
-            if fs_directory is None:
-                # fs_directory has not provided, we save the Spark DataFrame to ray object store
-                blocks, block_sizes = _save_spark_df_to_object_store(df)
-                record_pieces = [RayObjectPiece(obj, None, num_rows)
-                                for obj, num_rows in zip(blocks, block_sizes)]
-
-                return _create_ml_dataset("from_spark", record_pieces, block_sizes, num_shards,
-                                        shuffle, shuffle_seed, RayRecordBatch,
-                                        node_hints)
-            else:
-                # fs_directory has provided, we write the Spark DataFrame as Parquet files
-                df.write.parquet(fs_directory, compression=compression)
-                # create the MLDataset from the parquet file
-                ds = RayMLDataset.from_parquet(
-                    fs_directory, num_shards, shuffle, shuffle_seed, node_hints)
-                return ds
-
-        @staticmethod
-        def from_parquet(paths: Union[str, List[str]],
-                        num_shards: int,
-                        shuffle: bool = True,
-                        shuffle_seed: int = None,
-                        columns: Optional[List[str]] = None,
-                        node_hints: List[str] = None,
-                        extra_parquet_arguments: Dict = None) -> MLDataset:
-            """ Create a MLDataset from Parquet files.
-
-            :param paths: the parquet files path
-            :param num_shards: the number of shards will be created for the MLDataset
-            :param shuffle: whether need to shuffle the blocks when create the MLDataset
-            :param shuffle_seed: the shuffle seed, default is 0
-            :param columns: the columns that need to read
-            :param node_hints: the node hints to create MLDataset actors
-            :param extra_parquet_arguments: the extra arguments need to pass into the parquet file
-                reading
-            :return: a MLDataset
-            """
-            if not extra_parquet_arguments:
-                extra_parquet_arguments = {}
-            ds = pq.ParquetDataset(paths, **extra_parquet_arguments)
-            pieces = ds.pieces
-            record_pieces = []
-            record_sizes = []
-
-            for piece in pieces:
-                meta_data = piece.get_metadata().to_dict()
-                num_row_groups = meta_data["num_row_groups"]
-                row_groups = meta_data["row_groups"]
-                for i in range(num_row_groups):
-                    num_rows = row_groups[i]["num_rows"]
-                    parquet_ds_piece = pq.ParquetDatasetPiece(piece.path, piece.open_file_func,
-                                                            piece.file_options, i,
-                                                            piece.partition_keys)
-                    # row_ids will be set later
-                    record_pieces.append(ParquetPiece(piece=parquet_ds_piece,
-                                                    columns=columns,
-                                                    partitions=ds.partitions,
-                                                    row_ids=None,
-                                                    num_rows=num_rows))
-                    record_sizes.append(num_rows)
-
-            return _create_ml_dataset("from_parquet", record_pieces, record_sizes, num_shards,
-                                    shuffle, shuffle_seed, RecordBatch, node_hints)
-
-        @staticmethod
-        def to_torch(
-                ds: MLDataset,
-                world_size: int,
-                world_rank: int,
-                batch_size: int,
-                collate_fn: Callable,
-                shuffle: bool = False,
-                shuffle_seed: int = None,
-                local_rank: int = -1,
-                prefer_node: str = None,
-                prefetch: bool = False):
-            """
-            Create DataLoader from a MLDataset
-            :param ds: the MLDataset
-            :param world_size: the world_size of distributed model training
-            :param world_rank: create the DataLoader for the given world_rank
-            :param batch_size: the batch_size of the DtaLoader
-            :param collate_fn: the collate_fn that create tensors from a pandas DataFrame
-            :param shuffle: whether shuffle each batch of data
-            :param shuffle_seed: the shuffle seed
-            :param local_rank: the node local rank. It must be provided if prefer_node is
-                not None.
-            :param prefer_node: the prefer node for create the MLDataset actor
-            :param prefetch: prefetch the data of DataLoader with one thread
-            :return: a pytorch DataLoader
-            """
-            # pylint: disable=C0415
-            import torch
-            from raydp.torch.torch_ml_dataset import PrefetchedDataLoader, TorchMLDataset
-
-            num_shards = ds.num_shards()
-            assert num_shards % world_size == 0, \
-                (f"The number shards of MLDataset({ds}) should be a multiple of "
-                f"world_size({world_size})")
-            multiplier = num_shards // world_size
-
-            selected_ds = None
-            if prefer_node is not None:
-                assert 0 <= local_rank < world_size
-
-                # get all actors
-                # there should be only one actor_set because of select_shards() is not allowed
-                # after union()
-
-                def location_check(actor):
-                    address = ray.actors(actor._actor_id.hex())["Address"]["IPAddress"]
-                    return address == prefer_node
-
-                actors = ds.actor_sets[0].actors
-                actor_indexes = [i for i, actor in enumerate(actors) if location_check(actor)]
-                if len(actor_indexes) % multiplier != 0:
-                    selected_ds = None
-                    logger.warning(f"We could not find enough shard actor in prefer "
-                                f"node({prefer_node}), fail back to normal select_shards(). "
-                                f"Found: ({actor_indexes}) which length is not multiple of "
-                                f"num_shards({num_shards}) // world_size({world_size}).")
-                else:
-                    shard_ids = actor_indexes[local_rank: local_rank + multiplier]
-                    selected_ds = ds.select_shards(shard_ids)
-
-            if selected_ds is None:
-                shard_ids = []
-                i = world_rank
-                step = world_size
-                while i < num_shards:
-                    shard_ids.append(i)
-                    i += step
-                selected_ds = ds.select_shards(shard_ids)
-
-            selected_ds = selected_ds.batch(batch_size)
-            torch_ds = TorchMLDataset(selected_ds, collate_fn, shuffle, shuffle_seed)
-            data_loader = torch.utils.data.DataLoader(dataset=torch_ds,
-                                                    batch_size=None,
-                                                    batch_sampler=None,
-                                                    shuffle=False,
-                                                    num_workers=0,
-                                                    collate_fn=None,
-                                                    pin_memory=False,
-                                                    drop_last=False,
-                                                    sampler=None)
-            if prefetch:
-                data_loader = PrefetchedDataLoader(data_loader)
-            return data_loader
-
-
-    def create_ml_dataset_from_spark(df: sql.DataFrame,
-                                    num_shards: int,
-                                    shuffle: bool,
-                                    shuffle_seed: int,
-                                    fs_directory: Optional[str] = None,
-                                    compression: Optional[str] = None,
-                                    node_hints: List[str] = None) -> MLDataset:
-        return RayMLDataset.from_spark(
-            df, num_shards, shuffle, shuffle_seed, fs_directory, compression, node_hints)
diff --git a/python/raydp/tests/test_spark_cluster.py b/python/raydp/tests/test_spark_cluster.py
index 3e793aeb..5d0cdae2 100644
--- a/python/raydp/tests/test_spark_cluster.py
+++ b/python/raydp/tests/test_spark_cluster.py
@@ -60,7 +60,7 @@ def test_legacy_spark_on_fractional_cpu():
     cluster.shutdown()
 
 
-def test_spark_on_fractional_cpu():
+def test_spark_executor_on_fractional_cpu():
     cluster = Cluster(
         initialize_head=True,
         connect=True,
diff --git a/python/setup.py b/python/setup.py
index 72a62b96..b3715ce6 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -95,12 +95,12 @@ def run(self):
     copy2(SCRIPT_PATH, SCRIPT_TARGET)
 
     install_requires = [
-        "numpy < 2.0.0",
+        "numpy",
         "pandas >= 1.1.4",
         "psutil",
-        "pyarrow >= 4.0.1, <15.0.0",
-        "ray >= 2.1.0, <= 2.38.0",
-        "pyspark >= 3.1.1, <=3.5.1",
+        "pyarrow >= 4.0.1",
+        "ray >= 2.1.0",
+        "pyspark >= 3.1.1, <=3.5.3",
         "netifaces",
         "protobuf > 3.19.5, <= 3.20.3"
     ]