Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] fix ruff rule C416: unnecessary-comprehension #49852

Merged
merged 5 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ ignore = [
"B027",
"B035",
"B904",
"C416",
"C419",
# Below are auto-fixable rules
"I001",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,7 @@ class ScaleRequest:

def get_non_terminated(self) -> Dict[CloudInstanceId, CloudInstance]:
self._sync_with_api_server()
return copy.deepcopy(
{id: instance for id, instance in self._cached_instances.items()}
)
return copy.deepcopy(dict(self._cached_instances))

def terminate(self, ids: List[CloudInstanceId], request_id: str) -> None:
if request_id in self._requests:
Expand Down
7 changes: 1 addition & 6 deletions python/ray/dashboard/modules/job/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,12 +373,7 @@ async def get_job_info(job_id: str):
job_info = await self.get_info(job_id, timeout)
return job_id, job_info

return {
job_id: job_info
for job_id, job_info in await asyncio.gather(
*[get_job_info(job_id) for job_id in job_ids]
)
}
return dict(await asyncio.gather(*[get_job_info(job_id) for job_id in job_ids]))
kenchung285 marked this conversation as resolved.
Show resolved Hide resolved


def uri_to_http_components(package_uri: str) -> Tuple[str, str]:
Expand Down
2 changes: 1 addition & 1 deletion python/ray/dashboard/modules/metrics/metrics_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def _create_default_grafana_configs(self):
if isinstance(prometheus_headers, list):
prometheus_header_pairs = prometheus_headers
elif isinstance(prometheus_headers, dict):
prometheus_header_pairs = [(k, v) for k, v in prometheus_headers.items()]
prometheus_header_pairs = list(prometheus_headers.items())

data_sources_path = os.path.join(grafana_provisioning_folder, "datasources")
os.makedirs(
Expand Down
10 changes: 2 additions & 8 deletions python/ray/data/tests/test_arrow_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,7 @@ def fixed_size_list_array():
@pytest.fixture
def map_array():
return pa.array(
[
[(key, item) for key, item in zip("abcdefghij", range(10))]
for _ in range(1000)
],
[list(zip("abcdefghij", range(10))) for _ in range(1000)],
type=pa.map_(pa.string(), pa.int64()),
)

Expand Down Expand Up @@ -349,10 +346,7 @@ def complex_nested_array():
]
),
pa.array(
[
[(key, item) for key, item in zip("abcdefghij", range(10))]
for _ in range(1000)
],
[list(zip("abcdefghij", range(10))) for _ in range(1000)],
type=pa.map_(pa.string(), pa.int64()),
),
],
Expand Down
4 changes: 2 additions & 2 deletions python/ray/serve/grpc_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ def __init__(self, grpc_context: grpc._cython.cygrpc._ServicerContext):
self._auth_context = grpc_context.auth_context()
self._code = grpc_context.code()
self._details = grpc_context.details()
self._invocation_metadata = [
self._invocation_metadata = [ # noqa: C416
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why? Isn't invocation_metadata() an iterator of tuples?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CI failed on my previous commit, which used list().
But CI passed on the following commit, after I reverted to the explicit list comprehension.
I think the elements of invocation_metadata() are not always tuples.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zcin or @edoakes , could you provide an opinion?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can just do: list(grpc_context.invocation_metadata)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's my new commit that uses list(grpc_context.invocation_metadata()), but its CI failed.
0764a0b

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here's the error: "Can't pickle <class 'importlib._bootstrap._Metadatum'>: attribute lookup _Metadatum on importlib._bootstrap failed"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so somehow these two lines don't produce the same resulting list... amazing

i'm fine to just add the noqa

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I'll revert to the version that uses noqa and rebase onto the master branch after work.

(key, value) for key, value in grpc_context.invocation_metadata()
]
self._peer = grpc_context.peer()
self._peer_identities = grpc_context.peer_identities()
self._peer_identity_key = grpc_context.peer_identity_key()
self._trailing_metadata = [
self._trailing_metadata = [ # noqa: C416
(key, value) for key, value in grpc_context.trailing_metadata()
]
self._compression = None
Expand Down
8 changes: 2 additions & 6 deletions python/ray/serve/tests/test_proxy_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,9 +621,7 @@ def test_proxy_state_manager_timing_out_on_start(number_of_worker_nodes, all_nod
proxy_state._actor_proxy_wrapper.is_ready_response = False

# Capture current proxy states (prior to updating)
prev_proxy_states = {
node_id: state for node_id, state in proxy_state_manager._proxy_states.items()
}
prev_proxy_states = dict(proxy_state_manager._proxy_states)

# Trigger PSM to reconcile
proxy_state_manager.update(proxy_nodes=node_ids)
Expand All @@ -644,9 +642,7 @@ def test_proxy_state_manager_timing_out_on_start(number_of_worker_nodes, all_nod
proxy_state._actor_proxy_wrapper.is_ready_response = True

# Capture current proxy states again (prior to updating)
prev_proxy_states = {
node_id: state for node_id, state in proxy_state_manager._proxy_states.items()
}
prev_proxy_states = dict(proxy_state_manager._proxy_states)

# Trigger PSM to reconcile
proxy_state_manager.update(proxy_nodes=node_ids)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -598,10 +598,7 @@ def poll_status(self, timeout: Optional[float] = None) -> WorkerGroupStatus:
worker_group_status = WorkerGroupStatus(
num_workers=len(self._workers),
latest_start_time=self._latest_start_time,
worker_statuses={
world_rank: worker_status
for world_rank, worker_status in enumerate(poll_results)
},
worker_statuses=dict(enumerate(poll_results)),
)

for callback in self._callbacks:
Expand Down
4 changes: 1 addition & 3 deletions python/ray/train/v2/tests/test_report_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ def generate_worker_group_status(num_workers, num_ckpt, num_dummy, num_none):
)
random.shuffle(worker_statuses)

return WorkerGroupStatus(
num_workers, 0.0, {i: ws for i, ws in enumerate(worker_statuses)}
)
return WorkerGroupStatus(num_workers, 0.0, dict(enumerate(worker_statuses)))


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion python/ray/tune/execution/placement_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def resource_dict_to_pg_factory(spec: Optional[Dict[str, float]] = None):
memory = spec.pop("memory", 0.0)

# If there is a custom_resources key, use as base for bundle
bundle = {k: v for k, v in spec.pop("custom_resources", {}).items()}
bundle = dict(spec.pop("custom_resources", {}))

# Otherwise, consider all other keys as custom resources
if not bundle:
Expand Down
4 changes: 1 addition & 3 deletions python/ray/tune/search/optuna/optuna_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,9 +607,7 @@ def add_evaluated_point(
ot_trial_state = OptunaTrialState.PRUNED

if intermediate_values:
intermediate_values_dict = {
i: value for i, value in enumerate(intermediate_values)
}
intermediate_values_dict = dict(enumerate(intermediate_values))
else:
intermediate_values_dict = None

Expand Down
2 changes: 1 addition & 1 deletion python/ray/util/client/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,7 +707,7 @@ def _get_client_id_from_context(context: Any) -> str:
Get `client_id` from gRPC metadata. If the `client_id` is not present,
this function logs an error and sets the status_code.
"""
metadata = {k: v for k, v in context.invocation_metadata()}
metadata = dict(context.invocation_metadata())
client_id = metadata.get("client_id") or ""
if client_id == "":
logger.error("Client connecting with no client_id")
Expand Down
4 changes: 2 additions & 2 deletions python/ray/util/client/server/dataservicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def _get_reconnecting_from_context(context: Any) -> bool:
"""
Get `reconnecting` from gRPC metadata, or False if missing.
"""
metadata = {k: v for k, v in context.invocation_metadata()}
metadata = dict(context.invocation_metadata())
val = metadata.get("reconnecting")
if val is None or val not in ("True", "False"):
logger.error(
Expand Down Expand Up @@ -155,7 +155,7 @@ def Datapath(self, request_iterator, context):
start_time = time.time()
# set to True if client shuts down gracefully
cleanup_requested = False
metadata = {k: v for k, v in context.invocation_metadata()}
metadata = dict(context.invocation_metadata())
client_id = metadata.get("client_id")
if client_id is None:
logger.error("Client connecting with no client_id")
Expand Down
2 changes: 1 addition & 1 deletion python/ray/util/client/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,7 @@ def get_cluster_info(
resp = self.server.ClusterInfo(req, timeout=timeout, metadata=self.metadata)
if resp.WhichOneof("response_type") == "resource_table":
# translate from a proto map to a python dict
output_dict = {k: v for k, v in resp.resource_table.table.items()}
output_dict = dict(resp.resource_table.table)
return output_dict
elif resp.WhichOneof("response_type") == "runtime_context":
return resp.runtime_context
Expand Down
14 changes: 6 additions & 8 deletions rllib/algorithms/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,23 +903,21 @@ def setup(self, config: AlgorithmConfig) -> None:
),
)
# Get the devices of each learner.
learner_locations = [
(i, loc)
for i, loc in enumerate(
learner_locations = list(
enumerate(
self.learner_group.foreach_learner(
func=lambda _learner: (_learner.node, _learner.device),
)
)
]
)
# Get the devices of each AggregatorActor.
aggregator_locations = [
(i, loc)
for i, loc in enumerate(
aggregator_locations = list(
enumerate(
self._aggregator_actor_manager.foreach_actor(
func=lambda actor: (actor._node, actor._device)
)
)
]
)
self._aggregator_actor_to_learner = {}
for agg_idx, aggregator_location in aggregator_locations:
for learner_idx, learner_location in learner_locations:
Expand Down
6 changes: 2 additions & 4 deletions rllib/algorithms/marwil/tests/test_marwil.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,7 @@ def possibly_masked_mean(data_):
# Calculate our own expected values (to then compare against the
# agent's loss output).
module = algo.learner_group._learner.module[DEFAULT_MODULE_ID].unwrapped()
fwd_out = module.forward_train(
{k: v for k, v in batch[DEFAULT_MODULE_ID].items()}
)
fwd_out = module.forward_train(dict(batch[DEFAULT_MODULE_ID]))
advantages = (
batch[DEFAULT_MODULE_ID][Columns.VALUE_TARGETS].detach().cpu().numpy()
- module.compute_values(batch[DEFAULT_MODULE_ID]).detach().cpu().numpy()
Expand Down Expand Up @@ -210,7 +208,7 @@ def possibly_masked_mean(data_):
# calculation above).
total_loss = algo.learner_group._learner.compute_loss_for_module(
module_id=DEFAULT_MODULE_ID,
batch={k: v for k, v in batch[DEFAULT_MODULE_ID].items()},
batch=dict(batch[DEFAULT_MODULE_ID]),
fwd_out=fwd_out,
config=config,
)
Expand Down
4 changes: 1 addition & 3 deletions rllib/algorithms/sac/tests/test_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,7 @@ def test_sac_dict_obs_order(self):

# Dict space .sample() returns an ordered dict.
# Make sure the keys in samples are ordered differently.
dict_samples = [
{k: v for k, v in reversed(dict_space.sample().items())} for _ in range(10)
]
dict_samples = [dict(reversed(dict_space.sample().items())) for _ in range(10)]

class NestedDictEnv(gym.Env):
def __init__(self):
Expand Down
2 changes: 1 addition & 1 deletion rllib/env/remote_base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def stop(self) -> None:
@override(BaseEnv)
def get_sub_environments(self, as_dict: bool = False) -> List[EnvType]:
if as_dict:
return {env_id: actor for env_id, actor in enumerate(self.actors)}
return dict(enumerate(self.actors))
return self.actors

@property
Expand Down
5 changes: 1 addition & 4 deletions rllib/env/tests/test_multi_agent_episode.py
Original file line number Diff line number Diff line change
Expand Up @@ -3498,10 +3498,7 @@ def _mock_multi_agent_records_from_env(
# In the other case we need at least the last observations for the next
# actions.
else:
obs = {
agent_id: agent_obs
for agent_id, agent_obs in episode.get_observations(-1).items()
}
obs = dict(episode.get_observations(-1))

# Sample `size` many records.
done_agents = {aid for aid, t in episode.get_terminateds().items() if t}
Expand Down
5 changes: 1 addition & 4 deletions rllib/env/vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,10 +498,7 @@ def get_sub_environments(self, as_dict: bool = False) -> Union[List[EnvType], di
if not as_dict:
return self.vector_env.get_sub_environments()
else:
return {
_id: env
for _id, env in enumerate(self.vector_env.get_sub_environments())
}
return dict(enumerate(self.vector_env.get_sub_environments()))

@override(BaseEnv)
def try_render(self, env_id: Optional[EnvID] = None) -> None:
Expand Down
4 changes: 2 additions & 2 deletions rllib/env/wrappers/open_spiel.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def step(self, action):
penalties[curr_player] = -0.1

# Compile rewards dict.
rewards = {ag: r for ag, r in enumerate(self.state.returns())}
rewards = dict(enumerate(self.state.returns()))
# Simultaneous game.
else:
assert self.state.current_player() == -2
Expand All @@ -73,7 +73,7 @@ def step(self, action):

# Compile rewards dict and add the accumulated penalties
# (for taking invalid actions).
rewards = {ag: r for ag, r in enumerate(self.state.returns())}
rewards = dict(enumerate(self.state.returns()))
for ag, penalty in penalties.items():
rewards[ag] += penalty

Expand Down
2 changes: 1 addition & 1 deletion rllib/examples/rl_modules/classes/modelv2_to_rlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def compute_values(self, batch: Dict[str, Any], embeddings: Optional[Any] = None
def get_initial_state(self):
"""Converts the initial state list of ModelV2 into a dict (new API stack)."""
init_state_list = self._model_v2.get_initial_state()
return {i: s for i, s in enumerate(init_state_list)}
return dict(enumerate(init_state_list))

def _translate_dist_class(self, old_dist_class):
map_ = {
Expand Down
2 changes: 1 addition & 1 deletion rllib/models/torch/mingpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def configure_gpt_optimizer(
no_decay.add(fpn)

# validate that we considered every parameter
param_dict = {pn: p for pn, p in model.named_parameters()}
param_dict = dict(model.named_parameters())
inter_params = decay & no_decay
union_params = decay | no_decay
assert (
Expand Down
2 changes: 1 addition & 1 deletion rllib/policy/dynamic_tf_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@ def _initialize_loss_from_dummy_batch(
{SampleBatch.SEQ_LENS: train_batch[SampleBatch.SEQ_LENS]}
)

self._loss_input_dict.update({k: v for k, v in train_batch.items()})
self._loss_input_dict.update(dict(train_batch))

if log_once("loss_init"):
logger.debug(
Expand Down
2 changes: 1 addition & 1 deletion rllib/policy/dynamic_tf_policy_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,7 @@ def _initialize_loss_from_dummy_batch(
{SampleBatch.SEQ_LENS: train_batch[SampleBatch.SEQ_LENS]}
)

self._loss_input_dict.update({k: v for k, v in train_batch.items()})
self._loss_input_dict.update(dict(train_batch))

if log_once("loss_init"):
logger.debug(
Expand Down
2 changes: 1 addition & 1 deletion rllib/policy/tf_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def compute_log_likelihoods(
self._state_inputs, state_batches
)
)
builder.add_feed_dict({k: v for k, v in zip(self._state_inputs, state_batches)})
builder.add_feed_dict(dict(zip(self._state_inputs, state_batches)))
if state_batches:
builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))})
# Prev-a and r.
Expand Down
Loading