apply Black 2024 style in fbcode (11/16)
Summary:
Formats the covered files with pyfmt.

paintitblack

Reviewed By: aleivag

Differential Revision: D54447738

fbshipit-source-id: acf3c943c52ac1acaeee46bc1b9c8a3173f9aef4
amyreese authored and facebook-github-bot committed Mar 3, 2024
Parent: fa04b43 · Commit: 0390a00
Showing 60 changed files with 373 additions and 282 deletions.
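
The hunks below are mechanical rewrites from Black's 2024 stable style. As a rough sketch of the four patterns that recur throughout the commit (every identifier here is illustrative, not taken from the diff):

# A minimal sketch of the recurring Black 2024 rewrites; all names below
# are hypothetical.

pairs = [(1, "a"), (2, "b")]
table = {}
flag = True
a_fairly_long_value_expression = 42

# 1. Redundant parentheses around for-loop tuple targets are removed:
#    `for (n, s) in pairs:` becomes the form below.
for n, s in pairs:
    print(n, s)

# 2. A subscript assignment too long for one line keeps the key intact
#    and parenthesizes the right-hand side, instead of splitting the
#    subscript across lines.
table["a_rather_long_and_descriptive_key_name"] = (
    a_fairly_long_value_expression
)

# 3. A conditional expression that has to wrap gets its own parentheses
#    rather than being left as a bare multi-line expression.
mode = (
    "cuda"
    if flag
    else "rocm"
)

# 4. Chained assignments wrap on the right-hand side, so an inline
#    comment stays attached to the value.
x = y = (
    a_fairly_long_value_expression  # comment rides along with the value
)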

examples/01_resnet-50/modeling/resnet.py (2 changes: 1 addition & 1 deletion)

@@ -373,7 +373,7 @@ def make_default_stages(depth, block_class=None, **kwargs):
     in_channels = [64, 256, 512, 1024]
     out_channels = [256, 512, 1024, 2048]
     ret = []
-    for (n, s, i, o) in zip(
+    for n, s, i, o in zip(
         num_blocks_per_stage, [1, 2, 2, 2], in_channels, out_channels
     ):
         if depth >= 50:

examples/02_detectron2/modeling/backbone/resnet.py (2 changes: 1 addition & 1 deletion)

@@ -360,7 +360,7 @@ def make_default_stages(depth, block_class=None, **kwargs):
     in_channels = [64, 256, 512, 1024]
     out_channels = [256, 512, 1024, 2048]
     ret = []
-    for (n, s, i, o) in zip(
+    for n, s, i, o in zip(
         num_blocks_per_stage, [1, 2, 2, 2], in_channels, out_channels
     ):
         if depth >= 50:

examples/04_vit/modeling/vision_transformer.py (6 changes: 3 additions & 3 deletions)

@@ -208,9 +208,9 @@ def __init__(
 
         self.num_classes = num_classes
         self.global_pool = global_pool
-        self.num_features = (
-            self.embed_dim
-        ) = embed_dim  # num_features for consistency with other models
+        self.num_features = self.embed_dim = (
+            embed_dim  # num_features for consistency with other models
+        )
         self.num_prefix_tokens = 1 if class_token else 0
         self.no_embed_class = no_embed_class
         self.grad_checkpointing = False

(file name not captured)

@@ -254,12 +254,12 @@ def convert_ldm_vae_checkpoint(vae_state_dict):
         ]
 
         if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
-            new_checkpoint[
-                f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"
-            ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight")
-            new_checkpoint[
-                f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"
-            ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias")
+            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = (
+                vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight")
+            )
+            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = (
+                vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias")
+            )
 
         paths = renew_vae_resnet_paths(resnets)
         meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}

@@ -295,12 +295,12 @@ def convert_ldm_vae_checkpoint(vae_state_dict):
         ]
 
         if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
-            new_checkpoint[
-                f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"
-            ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"]
-            new_checkpoint[
-                f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"
-            ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"]
+            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = (
+                vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"]
+            )
+            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = (
+                vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"]
+            )
 
         paths = renew_vae_resnet_paths(resnets)
         meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}

@@ -410,12 +410,12 @@ def convert_ldm_unet_checkpoint(unet_state_dict, layers_per_block=2):
         attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
 
         if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
-            new_checkpoint[
-                f"down_blocks.{block_id}.downsamplers.0.conv.weight"
-            ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.weight")
-            new_checkpoint[
-                f"down_blocks.{block_id}.downsamplers.0.conv.bias"
-            ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = (
+                unet_state_dict.pop(f"input_blocks.{i}.0.op.weight")
+            )
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = (
+                unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
+            )
 
         paths = renew_resnet_paths(resnets)
         meta_path = {

@@ -496,12 +496,12 @@ def convert_ldm_unet_checkpoint(unet_state_dict, layers_per_block=2):
                 index = list(output_block_list.values()).index(
                     ["conv.bias", "conv.weight"]
                 )
-                new_checkpoint[
-                    f"up_blocks.{block_id}.upsamplers.0.conv.weight"
-                ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.weight"]
-                new_checkpoint[
-                    f"up_blocks.{block_id}.upsamplers.0.conv.bias"
-                ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.bias"]
+                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = (
+                    unet_state_dict[f"output_blocks.{i}.{index}.conv.weight"]
+                )
+                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = (
+                    unet_state_dict[f"output_blocks.{i}.{index}.conv.bias"]
+                )
 
                 # Clear attentions as they have been attributed above.
                 if len(attentions) == 2:

@@ -703,9 +703,9 @@ def __init__(self, hf_hub_or_path, ckpt):
             hf_hub_or_path, subfolder="text_encoder"
         )
         self.clip_pt = CLIPTextModel(config)
-        clip_state_dict[
-            "text_model.embeddings.position_ids"
-        ] = self.clip_pt.text_model.embeddings.get_buffer("position_ids")
+        clip_state_dict["text_model.embeddings.position_ids"] = (
+            self.clip_pt.text_model.embeddings.get_buffer("position_ids")
+        )
         self.clip_pt.load_state_dict(clip_state_dict)
         clip_params_ait = map_clip_state_dict(dict(self.clip_pt.named_parameters()))
         print("Setting constants")

(file name not captured)

@@ -254,12 +254,12 @@ def convert_ldm_vae_checkpoint(vae_state_dict):
         ]
 
         if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
-            new_checkpoint[
-                f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"
-            ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight")
-            new_checkpoint[
-                f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"
-            ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias")
+            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = (
+                vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight")
+            )
+            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = (
+                vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias")
+            )
 
         paths = renew_vae_resnet_paths(resnets)
         meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}

@@ -295,12 +295,12 @@ def convert_ldm_vae_checkpoint(vae_state_dict):
         ]
 
         if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
-            new_checkpoint[
-                f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"
-            ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"]
-            new_checkpoint[
-                f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"
-            ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"]
+            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = (
+                vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"]
+            )
+            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = (
+                vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"]
+            )
 
         paths = renew_vae_resnet_paths(resnets)
         meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}

@@ -410,12 +410,12 @@ def convert_ldm_unet_checkpoint(unet_state_dict, layers_per_block=2):
         attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
 
         if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
-            new_checkpoint[
-                f"down_blocks.{block_id}.downsamplers.0.conv.weight"
-            ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.weight")
-            new_checkpoint[
-                f"down_blocks.{block_id}.downsamplers.0.conv.bias"
-            ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = (
+                unet_state_dict.pop(f"input_blocks.{i}.0.op.weight")
+            )
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = (
+                unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
+            )
 
         paths = renew_resnet_paths(resnets)
         meta_path = {

@@ -496,12 +496,12 @@ def convert_ldm_unet_checkpoint(unet_state_dict, layers_per_block=2):
                 index = list(output_block_list.values()).index(
                     ["conv.bias", "conv.weight"]
                 )
-                new_checkpoint[
-                    f"up_blocks.{block_id}.upsamplers.0.conv.weight"
-                ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.weight"]
-                new_checkpoint[
-                    f"up_blocks.{block_id}.upsamplers.0.conv.bias"
-                ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.bias"]
+                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = (
+                    unet_state_dict[f"output_blocks.{i}.{index}.conv.weight"]
+                )
+                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = (
+                    unet_state_dict[f"output_blocks.{i}.{index}.conv.bias"]
+                )
 
                 # Clear attentions as they have been attributed above.
                 if len(attentions) == 2:

fx2ait/fx2ait/converters/ait_converters.py (36 changes: 20 additions & 16 deletions)

@@ -957,14 +957,16 @@ def make_slice(x, slice_idx, name):
                 ),
                 f"{name}.weight.slice_{i}",
             ),
-            None
-            if bias is None
-            else make_slice(  # bias[wgs*i:wgs*i + wgs,]
-                bias,
-                get_batch_dim_slice_idx(
-                    i * w_group_size, i * w_group_size + w_group_size, 1
-                ),
-                f"{name}.bias.slice_{i}",
-            ),
+            (
+                None
+                if bias is None
+                else make_slice(  # bias[wgs*i:wgs*i + wgs,]
+                    bias,
+                    get_batch_dim_slice_idx(
+                        i * w_group_size, i * w_group_size + w_group_size, 1
+                    ),
+                    f"{name}.bias.slice_{i}",
+                )
+            ),
             transposed=True,
         )

@@ -1435,14 +1437,16 @@ def make_slice(x, slice_idx, name):
                 ),
                 f"{name}.weight.slice_{i}",
             ),
-            None
-            if bias is None
-            else make_slice(  # bias[wgs*i:wgs*i + wgs,]
-                bias,
-                get_batch_dim_slice_idx(
-                    i * w_group_size, i * w_group_size + w_group_size, 1
-                ),
-                f"{name}.bias.slice_{i}",
-            ),
+            (
+                None
+                if bias is None
+                else make_slice(  # bias[wgs*i:wgs*i + wgs,]
+                    bias,
+                    get_batch_dim_slice_idx(
+                        i * w_group_size, i * w_group_size + w_group_size, 1
+                    ),
+                    f"{name}.bias.slice_{i}",
+                )
+            ),
             transposed=False,
         )

fx2ait/fx2ait/converters/aten2ait_converters.py (20 changes: 11 additions & 9 deletions)

@@ -365,15 +365,17 @@ def make_slice(x, dim, start, end, step, name):
                 1,
                 f"{name}.weight.slice_{i}",
             ),
-            None
-            if bias is None
-            else make_slice(  # bias[wgs*i:wgs*i + wgs,]
-                bias,
-                0,
-                i * w_group_size,
-                i * w_group_size + w_group_size,
-                1,
-                f"{name}.bias.slice_{i}",
-            ),
+            (
+                None
+                if bias is None
+                else make_slice(  # bias[wgs*i:wgs*i + wgs,]
+                    bias,
+                    0,
+                    i * w_group_size,
+                    i * w_group_size + w_group_size,
+                    1,
+                    f"{name}.bias.slice_{i}",
+                )
+            ),
             transposed=transposed,
         )

fx2ait/fx2ait/lower/lower.py (8 changes: 5 additions & 3 deletions)

@@ -208,9 +208,11 @@ def lower_func(
            lowering_start_time = datetime.datetime.now()
 
            self.lower_settings.additional_inputs = (
-                additional_submodule_inputs[submod_name]
-                if additional_submodule_inputs
-                else None,
+                (
+                    additional_submodule_inputs[submod_name]
+                    if additional_submodule_inputs
+                    else None
+                ),
            )
 
            lowered_module = self.lower_pass(
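
One subtlety in the lower.py hunk above: the trailing comma after the conditional makes additional_inputs a one-element tuple, and Black 2024 keeps that comma outside the new wrapping parentheses, so the assigned value's type is unchanged. A minimal sketch with hypothetical names:

# The comma after the wrapped conditional still produces a 1-tuple,
# exactly as the pre-format code did.
maybe_inputs = None  # hypothetical stand-in for the submodule lookup
additional_inputs = (
    (maybe_inputs if maybe_inputs else None),
)
assert additional_inputs == (None,)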

fx2ait/fx2ait/tools/common_fx2ait.py (16 changes: 10 additions & 6 deletions)

@@ -149,9 +149,11 @@ def run_test(
        mod.to(torch_dtype)
        inputs = map_aggregate(
            inputs,
-            lambda inp: inp.to(torch_dtype).contiguous()
-            if inp.dtype not in (torch.bool, torch.int64)
-            else inp.contiguous(),
+            lambda inp: (
+                inp.to(torch_dtype).contiguous()
+                if inp.dtype not in (torch.bool, torch.int64)
+                else inp.contiguous()
+            ),
        )
        interp = AITInterpreter(
            mod,

@@ -427,9 +429,11 @@ def benchmark_function(
        mod.to(torch_dtype)
        inputs = map_aggregate(
            inputs,
-            lambda inp: inp.to(torch_dtype).contiguous()
-            if inp.dtype not in (torch.bool, torch.int64)
-            else inp.contiguous(),
+            lambda inp: (
+                inp.to(torch_dtype).contiguous()
+                if inp.dtype not in (torch.bool, torch.int64)
+                else inp.contiguous()
+            ),
        )
        interp = AITInterpreter(
            mod,

python/aitemplate/backend/backend_spec.py (6 changes: 3 additions & 3 deletions)

@@ -157,9 +157,9 @@ class GPUBackendSpec(BackendSpec):
            "half2": "h2sin",
            "bfloat16_2": "h2sin",
            "half": "hsin" if Target.current().name() == "cuda" else "hsin_custom",
-            "bfloat16": "hsin"
-            if Target.current().name() == "cuda"
-            else "hsin_custom",
+            "bfloat16": (
+                "hsin" if Target.current().name() == "cuda" else "hsin_custom"
+            ),
            "float": "sinf",
        },
        FuncEnum.TANH: {

python/aitemplate/backend/codegen.py (6 changes: 3 additions & 3 deletions)

@@ -1066,9 +1066,9 @@ def generate_source(self) -> Dict[str, str]:
        """
        device_functions_header_name = f"{self.target.name()}_device_functions.h"
        result = {}
-        result[
-            "device_functions-generated.h"
-        ] = f'#include "{device_functions_header_name}"'
+        result["device_functions-generated.h"] = (
+            f'#include "{device_functions_header_name}"'
+        )
 
        result["model-generated.h"] = self.generate_model()
 

python/aitemplate/backend/common/concatenate_common.py (2 changes: 1 addition & 1 deletion)

@@ -789,7 +789,7 @@ def _make_loop_ranges(init_values: List[str]):
 
    loop_ranges = _make_loop_ranges(init_values)
    loop_range_strs = []
-    for (start_idx, end_idx, val) in loop_ranges:
+    for start_idx, end_idx, val in loop_ranges:
        loop_range_strs.append(
            INITIALIZATION_LOOP_TEMPLATE.render(
                var_name=init_var,

python/aitemplate/backend/common/elementwise_common.py (8 changes: 5 additions & 3 deletions)

@@ -591,9 +591,11 @@ def _get_input_alignments(
 
    # Cap alignments based on global_max_alignment.
    alignments = [
-        min(alignment, global_max_alignment)
-        if alignment is not None
-        else global_max_alignment
+        (
+            min(alignment, global_max_alignment)
+            if alignment is not None
+            else global_max_alignment
+        )
        for alignment in alignments
    ]
    return alignments

(file name not captured)

@@ -466,9 +466,9 @@ def gen_function_call(func_attrs, backend_spec, indent=" "):
        sampling_ratio=func_attrs["sampling_ratio"],
        spatial_scale=func_attrs["spatial_scale"],
        position_sensitive="true" if func_attrs["position_sensitive"] else "false",
-        continuous_coordinate="true"
-        if func_attrs["continuous_coordinate"]
-        else "false",
+        continuous_coordinate=(
+            "true" if func_attrs["continuous_coordinate"] else "false"
+        ),
        backend_spec=backend_spec,
        elem_input_type=input_type,
        elem_output_type=output_type,

(file name not captured)

@@ -407,9 +407,9 @@ def gen_function_call(func_attrs, backend_spec, indent=" "):
        sampling_ratio=func_attrs["sampling_ratio"],
        spatial_scale=func_attrs["spatial_scale"],
        position_sensitive="true" if func_attrs["position_sensitive"] else "false",
-        continuous_coordinate="true"
-        if func_attrs["continuous_coordinate"]
-        else "false",
+        continuous_coordinate=(
+            "true" if func_attrs["continuous_coordinate"] else "false"
+        ),
        backend_spec=backend_spec,
        indent=indent,
    )
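
For reproducing this formatting outside fbcode: pyfmt wraps Black (together with import sorting), and the 2024 stable style shipped in Black 24.1.0, so a plain Black run should yield equivalent output for these files. A minimal sketch, assuming black >= 24.1.0 is installed and with illustrative directory arguments:

# Run the open-source Black CLI over the affected trees.
import subprocess

subprocess.run(["black", "examples", "fx2ait", "python"], check=True)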