From 9584ff6ea8b4127e1fb7eb9ca4302eb553918c37 Mon Sep 17 00:00:00 2001
From: zhuofeng
Date: Sat, 29 Jul 2023 11:20:51 +0800
Subject: [PATCH] Remove py4j dependency; detect pinned thread mode via Spark version

---
 hyperopt/spark.py                        | 23 +++++++---------
 hyperopt/tests/integration/test_spark.py | 34 ------------------------
 setup.py                                 |  1 -
 3 files changed, 9 insertions(+), 49 deletions(-)

diff --git a/hyperopt/spark.py b/hyperopt/spark.py
index 7c1e739..d90d36f 100644
--- a/hyperopt/spark.py
+++ b/hyperopt/spark.py
@@ -8,8 +8,6 @@
 from hyperopt import base, fmin, Trials
 from hyperopt.base import validate_timeout, validate_loss_threshold
 from hyperopt.utils import coarse_utcnow, _get_logger, _get_random_id
-from py4j.clientserver import ClientServer
-
 try:
     from pyspark.sql import SparkSession
     from pyspark.util import VersionUtils
@@ -88,12 +86,13 @@ class SparkTrials(Trials):
             else spark_session
         )
         self._spark_context = self._spark.sparkContext
-        self._spark_pinned_threads_enabled = isinstance(
-            self._spark_context._gateway, ClientServer
-        )
         # The feature to support controlling jobGroupIds is in SPARK-22340
         self._spark_supports_job_cancelling = (
-            self._spark_pinned_threads_enabled
+            _spark_major_minor_version
+            >= (
+                3,
+                2,
+            )
             or hasattr(self._spark_context.parallelize([1]), "collectWithJobGroup")
         )
         spark_default_parallelism = self._spark_context.defaultParallelism
@@ -479,7 +478,7 @@ class _SparkFMinState:
             try:
                 worker_rdd = self.spark.sparkContext.parallelize([0], 1)
                 if self.trials._spark_supports_job_cancelling:
-                    if self.trials._spark_pinned_threads_enabled:
+                    if _spark_major_minor_version >= (3, 2):
                         spark_context = self.spark.sparkContext
                         spark_context.setLocalProperty(
                             "spark.jobGroup.id", self._job_group_id
@@ -520,14 +519,10 @@ class _SparkFMinState:
             # The exceptions captured in run_task_on_executor would be returned in the result_or_e
             finish_trial_run(result_or_e)
 
-        if self.trials._spark_pinned_threads_enabled:
-            try:
-                # pylint: disable=no-name-in-module,import-outside-toplevel
-                from pyspark import inheritable_thread_target
+        if _spark_major_minor_version >= (3, 2):
+            from pyspark import inheritable_thread_target
 
-                run_task_thread = inheritable_thread_target(run_task_thread)
-            except ImportError:
-                pass
+            run_task_thread = inheritable_thread_target(run_task_thread)
 
         task_thread = threading.Thread(target=run_task_thread)
         task_thread.setDaemon(True)
diff --git a/hyperopt/tests/integration/test_spark.py b/hyperopt/tests/integration/test_spark.py
index 9ea0f19..3146d74 100644
--- a/hyperopt/tests/integration/test_spark.py
+++ b/hyperopt/tests/integration/test_spark.py
@@ -14,7 +14,6 @@
 from six import StringIO
 
 from hyperopt import SparkTrials, anneal, base, fmin, hp, rand
 from hyperopt.tests.unit.test_fmin import test_quadratic1_tpe
-from py4j.clientserver import ClientServer
 
 @contextlib.contextmanager
@@ -62,7 +61,6 @@ class BaseSparkContext:
             .getOrCreate()
         )
         cls._sc = cls._spark.sparkContext
-        cls._pin_mode_enabled = isinstance(cls._sc._gateway, ClientServer)
         cls.checkpointDir = tempfile.mkdtemp()
         cls._sc.setCheckpointDir(cls.checkpointDir)
         # Small tests run much faster with spark.sql.shuffle.partitions=4
@@ -590,35 +588,3 @@ class FMinTestCase(unittest.TestCase, BaseSparkContext):
 
         call_count = len(os.listdir(output_dir))
         self.assertEqual(NUM_TRIALS, call_count)
-
-    def test_pin_thread_off(self):
-        if self._pin_mode_enabled:
-            raise unittest.SkipTest()
-
-        spark_trials = SparkTrials(parallelism=2)
-        self.assertFalse(spark_trials._spark_pinned_threads_enabled)
-        self.assertTrue(spark_trials._spark_supports_job_cancelling)
-        fmin(
-            fn=lambda x: x + 1,
-            space=hp.uniform("x", -1, 1),
-            algo=rand.suggest,
-            max_evals=5,
-            trials=spark_trials,
-        )
-        self.assertEqual(spark_trials.count_successful_trials(), 5)
-
-    def test_pin_thread_on(self):
-        if not self._pin_mode_enabled:
-            raise unittest.SkipTest()
-
-        spark_trials = SparkTrials(parallelism=2)
-        self.assertTrue(spark_trials._spark_pinned_threads_enabled)
-        self.assertTrue(spark_trials._spark_supports_job_cancelling)
-        fmin(
-            fn=lambda x: x + 1,
-            space=hp.uniform("x", -1, 1),
-            algo=rand.suggest,
-            max_evals=5,
-            trials=spark_trials,
-        )
-        self.assertEqual(spark_trials.count_successful_trials(), 5)
diff --git a/setup.py b/setup.py
index d21c3a1..cdfcd62 100644
--- a/setup.py
+++ b/setup.py
@@ -48,7 +48,6 @@ setuptools.setup(
         "future",
         "tqdm",
         "cloudpickle",
-        "py4j",
     ],
     extras_require={
         "SparkTrials": "pyspark",
-- 
2.37.3.1
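
Note: the standalone sketch below is not part of the patch; it illustrates
the version gate the change introduces. It assumes a local pyspark
installation; the "local[1]" master, run_task, and "demo-group" names are
illustrative only, while VersionUtils.majorMinorVersion is the same helper
hyperopt uses to compute _spark_major_minor_version.

    # Sketch: gate thread handling on the Spark version, as the patch does,
    # instead of probing the py4j gateway type.
    import threading

    import pyspark
    from pyspark.sql import SparkSession
    from pyspark.util import VersionUtils

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    major_minor = VersionUtils.majorMinorVersion(pyspark.__version__)

    def run_task():
        # The job group id is a thread-local property; getting it into the
        # worker thread is what the gate protects (see SPARK-22340).
        spark.sparkContext.setLocalProperty("spark.jobGroup.id", "demo-group")
        print(spark.sparkContext.parallelize([0], 1).collect())

    if major_minor >= (3, 2):
        # Pinned thread mode is the default from Spark 3.2 onward, and
        # wrapping the target with inheritable_thread_target carries the
        # caller's local properties into the newly spawned thread.
        from pyspark import inheritable_thread_target

        run_task = inheritable_thread_target(run_task)

    task_thread = threading.Thread(target=run_task)
    task_thread.daemon = True
    task_thread.start()
    task_thread.join()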