diff --git a/paddlemix/examples/sam2/README.md b/paddlemix/examples/sam2/README.md
index 24d4eb5fd..205b3427f 100644
--- a/paddlemix/examples/sam2/README.md
+++ b/paddlemix/examples/sam2/README.md
@@ -11,6 +11,7 @@
 
+
 ## 2. 快速开始
 ### 获取权重
@@ -29,8 +30,3 @@ python paddlemix/examples/sam2/grounded_sam2_tracking_demo.py \
 --output_path output.mp4 \
 --prompt "input your prompt here"
 ```
-
-
-
-
-
diff --git a/paddlemix/examples/sam2/grounded_sam2_tracking_demo.py b/paddlemix/examples/sam2/grounded_sam2_tracking_demo.py
index 961e4b36c..d80361cd3 100644
--- a/paddlemix/examples/sam2/grounded_sam2_tracking_demo.py
+++ b/paddlemix/examples/sam2/grounded_sam2_tracking_demo.py
@@ -90,7 +90,7 @@
 # filter output
 logits_filt = logits.clone()
 boxes_filt = boxes.clone()
-filt_mask = logits_filt.max(axis=1)[0] > args.box_threshold
+filt_mask = logits_filt.max(axis=1) > args.box_threshold
 logits_filt = logits_filt[filt_mask]  # num_filt, 256
 boxes_filt = boxes_filt[filt_mask]  # num_filt, 4
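Note on the hunk above: dropping the `[0]` reflects an API difference, not a behavior change. PyTorch's `Tensor.max(dim)` returns a `(values, indices)` tuple, so the original code indexed `[0]` to get the values; Paddle's `Tensor.max(axis)` returns the values directly, and keeping the `[0]` would instead select the score of the first box. A minimal sketch of the difference (shapes and names are illustrative, not from the PR):

```python
import paddle

logits_filt = paddle.rand([8, 256])  # (num_boxes, num_tokens), illustrative

# Paddle: max(axis=1) returns the per-box maxima directly, shape (8,).
scores = logits_filt.max(axis=1)
filt_mask = scores > 0.3  # boolean mask over the 8 boxes

# With the old torch-style `[0]`, `scores[0]` would be a single scalar,
# so the comparison would yield one bool instead of a per-box mask.
```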
diff --git a/paddlemix/models/sam2/modeling/backbones/hieradet.py b/paddlemix/models/sam2/modeling/backbones/hieradet.py
index 76f109388..a60110a23 100644
--- a/paddlemix/models/sam2/modeling/backbones/hieradet.py
+++ b/paddlemix/models/sam2/modeling/backbones/hieradet.py
@@ -49,16 +49,17 @@ def __init__(self, dim: int, dim_out: int, num_heads: int, q_pool: paddle.nn.Lay
     def forward(self, x: paddle.Tensor) -> paddle.Tensor:
         B, H, W, _ = tuple(x.shape)
-        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
+
+        qkv = self.qkv(x).reshape([B, H * W, 3, self.num_heads, -1])
         q, k, v = paddle.unbind(input=qkv, axis=2)
         if self.q_pool:
-            q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
+            q = do_pool(q.reshape([B, H, W, -1]), self.q_pool)
             H, W = tuple(q.shape)[1:3]
-            q = q.reshape(B, H * W, self.num_heads, -1)
+            q = q.reshape([B, H * W, self.num_heads, -1])
         x = paddle.nn.functional.scaled_dot_product_attention_(q, k, v)
-        x = x.reshape(B, H, W, -1)
+        x = x.reshape([B, H, W, -1])
         x = self.proj(x)
         return x
diff --git a/paddlemix/models/sam2/modeling/backbones/utils.py b/paddlemix/models/sam2/modeling/backbones/utils.py
index db67232a2..d344eda39 100644
--- a/paddlemix/models/sam2/modeling/backbones/utils.py
+++ b/paddlemix/models/sam2/modeling/backbones/utils.py
@@ -35,8 +35,8 @@ def window_partition(x, window_size):
     if pad_h > 0 or pad_w > 0:
         x = paddle.nn.functional.pad(x=x, pad=(0, 0, 0, pad_w, 0, pad_h, 0, 0))
     Hp, Wp = H + pad_h, W + pad_w
-    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
-    windows = x.transpose(perm=[0, 1, 3, 2, 4, 5]).view(-1, window_size, window_size, C)
+    x = x.reshape([B, Hp // window_size, window_size, Wp // window_size, window_size, C])
+    windows = x.transpose(perm=[0, 1, 3, 2, 4, 5]).reshape([-1, window_size, window_size, C])
     return windows, (Hp, Wp)
@@ -54,8 +54,8 @@ def window_unpartition(windows, window_size, pad_hw, hw):
     Hp, Wp = pad_hw
     H, W = hw
     B = tuple(windows.shape)[0] // (Hp * Wp // window_size // window_size)
-    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
-    x = x.transpose(perm=[0, 1, 3, 2, 4, 5]).view(B, Hp, Wp, -1)
+    x = windows.reshape([B, Hp // window_size, Wp // window_size, window_size, window_size, -1])
+    x = x.transpose(perm=[0, 1, 3, 2, 4, 5]).reshape([B, Hp, Wp, -1])
     if Hp > H or Wp > W:
         x = x[:, :H, :W, :]
     return x
diff --git a/paddlemix/models/sam2/modeling/position_encoding.py b/paddlemix/models/sam2/modeling/position_encoding.py
index 86e0387fb..f6c5b577f 100644
--- a/paddlemix/models/sam2/modeling/position_encoding.py
+++ b/paddlemix/models/sam2/modeling/position_encoding.py
@@ -63,7 +63,7 @@ def encode_points(self, x, y, labels):
         (bx, nx), (by, ny), (bl, nl) = tuple(x.shape), tuple(y.shape), tuple(labels.shape)
         assert bx == by and nx == ny and bx == bl and nx == nl
         pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
-        pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
+        pos_x, pos_y = pos_x.reshape([bx, nx, -1]), pos_y.reshape([by, ny, -1])
         pos = paddle.concat(x=(pos_y, pos_x, labels[:, :, None]), axis=2)
         return pos
@@ -71,16 +71,16 @@ def encode_points(self, x, y, labels):
     def forward(self, x: paddle.Tensor):
         cache_key = tuple(x.shape)[-2], tuple(x.shape)[-1]
         if cache_key in self.cache:
-            return self.cache[cache_key][None].repeat(tuple(x.shape)[0], 1, 1, 1)
+            return self.cache[cache_key][None].tile([tuple(x.shape)[0], 1, 1, 1])
         y_embed = (
             paddle.arange(start=1, end=tuple(x.shape)[-2] + 1, dtype="float32")
-            .view(1, -1, 1)
-            .repeat(tuple(x.shape)[0], 1, tuple(x.shape)[-1])
+            .reshape([1, -1, 1])
+            .tile((tuple(x.shape)[0], 1, tuple(x.shape)[-1]))
         )
         x_embed = (
             paddle.arange(start=1, end=tuple(x.shape)[-1] + 1, dtype="float32")
-            .view(1, 1, -1)
-            .repeat(tuple(x.shape)[0], tuple(x.shape)[-2], 1)
+            .reshape([1, 1, -1])
+            .tile((tuple(x.shape)[0], tuple(x.shape)[-2], 1))
         )
         if self.normalize:
             eps = 1e-06
@@ -165,13 +165,13 @@ def reshape_for_broadcast(freqs_cis: paddle.Tensor, x: paddle.Tensor):
     assert 0 <= 1 < ndim
     assert tuple(freqs_cis.shape) == (tuple(x.shape)[-2], tuple(x.shape)[-1])
     shape = [(d if i >= ndim - 2 else 1) for i, d in enumerate(tuple(x.shape))]
-    return freqs_cis.view(*shape)
+    return freqs_cis.reshape(shape)


 def apply_rotary_enc(xq: paddle.Tensor, xk: paddle.Tensor, freqs_cis: paddle.Tensor, repeat_freqs_k: bool = False):
-    xq_ = paddle.as_complex(x=xq.astype(dtype="float32").reshape(*tuple(xq.shape)[:-1], -1, 2))
+    xq_ = paddle.as_complex(x=xq.astype(dtype="float32").reshape([*tuple(xq.shape)[:-1], -1, 2]))
     xk_ = (
-        paddle.as_complex(x=xk.astype(dtype="float32").reshape(*tuple(xk.shape)[:-1], -1, 2))
+        paddle.as_complex(x=xk.astype(dtype="float32").reshape([*tuple(xk.shape)[:-1], -1, 2]))
         if tuple(xk.shape)[-2] != 0
         else None
     )
@@ -182,7 +182,7 @@ def apply_rotary_enc(xq: paddle.Tensor, xk: paddle.Tensor, freqs_cis: paddle.Ten
     if repeat_freqs_k:
         r = tuple(xk_.shape)[-2] // tuple(xq_.shape)[-2]
         if "gpu" in str(freqs_cis.place):
-            freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
+            freqs_cis = freqs_cis.tile((*([1] * (freqs_cis.ndim - 2)), r, 1))
         else:
             freqs_cis = (
                 freqs_cis.unsqueeze(axis=2).expand(shape=[-1, -1, r, -1, -1]).flatten(start_axis=2, stop_axis=3)
diff --git a/paddlemix/models/sam2/modeling/sam/mask_decoder.py b/paddlemix/models/sam2/modeling/sam/mask_decoder.py
index ced60e493..e95d9b8ba 100644
--- a/paddlemix/models/sam2/modeling/sam/mask_decoder.py
+++ b/paddlemix/models/sam2/modeling/sam/mask_decoder.py
@@ -188,7 +188,7 @@ def predict_masks(
         perm_53 = list(range(x.ndim))
         perm_53[1] = 2
         perm_53[2] = 1
-        src = x.transpose(perm=perm_53).view(b, c, h, w)
+        src = x.transpose(perm=perm_53).reshape([b, c, h, w])
         if not self.use_high_res_features:
             upscaled_embedding = self.output_upscaling(src)
         else:
@@ -201,7 +201,7 @@ def predict_masks(
             hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
         hyper_in = paddle.stack(x=hyper_in_list, axis=1)
         b, c, h, w = tuple(upscaled_embedding.shape)
-        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
+        masks = (hyper_in @ upscaled_embedding.reshape([b, c, h * w])).reshape([b, -1, h, w])
         iou_pred = self.iou_prediction_head(iou_token_out)
         if self.pred_obj_scores:
             assert s == 1
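The recurring `.view(...)` and `.reshape(a, b, c)` → `.reshape([a, b, c])` edits in the hunks above follow one rule: Paddle's `Tensor.reshape` takes the target shape as a single list or tuple rather than unpacked integers, and it stands in for torch's `Tensor.view` here. Likewise, torch's whole-tensor `Tensor.repeat` maps to Paddle's `Tensor.tile`. A small sketch of both conversions (tensor shapes are illustrative, not from the PR):

```python
import paddle

x = paddle.rand([2, 16, 8, 32])  # (batch, tokens, heads, head_dim), illustrative

# torch: x.reshape(2, 16, 256) or x.view(2, 16, 256)
y = x.reshape([2, 16, 8 * 32])   # shape passed as one list; -1 still infers a dim

# torch: y[None].repeat(4, 1, 1, 1)  ->  paddle: Tensor.tile(repeat_times)
batched = y[None].tile([4, 1, 1, 1])  # (4, 2, 16, 256)

# Caution: Paddle's Tensor.repeat_interleave matches torch.repeat_interleave
# (element-wise repetition), not torch.Tensor.repeat.
```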
diff --git a/paddlemix/models/sam2/modeling/sam/prompt_encoder.py b/paddlemix/models/sam2/modeling/sam/prompt_encoder.py
index 53da8e2b5..c3a2ced0b 100644
--- a/paddlemix/models/sam2/modeling/sam/prompt_encoder.py
+++ b/paddlemix/models/sam2/modeling/sam/prompt_encoder.py
@@ -104,7 +104,7 @@ def _embed_points(self, points: paddle.Tensor, labels: paddle.Tensor, pad: bool)
     def _embed_boxes(self, boxes: paddle.Tensor) -> paddle.Tensor:
         """Embeds box prompts."""
         boxes = boxes + 0.5
-        coords = boxes.reshape(-1, 2, 2)
+        coords = boxes.reshape([-1, 2, 2])
         corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
         corner_embedding[:, 0, :] += self.point_embeddings[2].weight
         corner_embedding[:, 1, :] += self.point_embeddings[3].weight
@@ -171,7 +171,7 @@ def forward(
         if masks is not None:
             dense_embeddings = self._embed_masks(masks)
         else:
-            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
+            dense_embeddings = self.no_mask_embed.weight.reshape([1, -1, 1, 1]).expand(
                 shape=[bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]]
             )
         return sparse_embeddings, dense_embeddings
diff --git a/paddlemix/models/sam2/modeling/sam/transformer.py b/paddlemix/models/sam2/modeling/sam/transformer.py
index 0efaa066b..2e3e3ea71 100644
--- a/paddlemix/models/sam2/modeling/sam/transformer.py
+++ b/paddlemix/models/sam2/modeling/sam/transformer.py
@@ -184,7 +184,7 @@ def __init__(
     def _separate_heads(self, x: paddle.Tensor, num_heads: int) -> paddle.Tensor:
         b, n, c = tuple(x.shape)
-        x = x.reshape(b, n, num_heads, c // num_heads)
+        x = x.reshape([b, n, num_heads, c // num_heads])
         return x.transpose([0, 2, 1, 3])
@@ -192,7 +192,8 @@ def _recombine_heads(self, x: paddle.Tensor) -> paddle.Tensor:
         b, n_heads, n_tokens, c_per_head = tuple(x.shape)
         x = x.transpose(perm=[0, 2, 1, 3])
-        return x.reshape(b, n_tokens, n_heads * c_per_head)
+
+        return x.reshape([b, n_tokens, n_heads * c_per_head])

     def forward(self, q: paddle.Tensor, k: paddle.Tensor, v: paddle.Tensor) -> paddle.Tensor:
         q = self.q_proj(q)
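In `_separate_heads` and `_recombine_heads` above, only the reshape call syntax changes; the transpose-then-reshape pattern itself is framework-neutral. A quick shape walk-through of the recombine step, under illustrative sizes:

```python
import paddle

b, n_heads, n_tokens, c_per_head = 2, 8, 16, 32
x = paddle.rand([b, n_heads, n_tokens, c_per_head])

# Bring the token axis before the head axis, then merge the heads back
# into one channel dimension: (b, n_tokens, n_heads * c_per_head).
x = x.transpose(perm=[0, 2, 1, 3])                     # (2, 16, 8, 32)
out = x.reshape([b, n_tokens, n_heads * c_per_head])   # (2, 16, 256)
assert tuple(out.shape) == (2, 16, 256)
```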
diff --git a/paddlemix/models/sam2/modeling/sam2_base.py b/paddlemix/models/sam2/modeling/sam2_base.py
index 5894b3ef3..9ad56bcfc 100644
--- a/paddlemix/models/sam2/modeling/sam2_base.py
+++ b/paddlemix/models/sam2/modeling/sam2_base.py
@@ -468,7 +468,7 @@ def _prepare_memory_conditioned_features(
                 else:
                     obj_pos = paddle.zeros(shape=[len(pos_list), B, self.mem_dim], dtype=obj_ptrs.dtype)
                 if self.mem_dim < C:
-                    obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim)
+                    obj_ptrs = obj_ptrs.reshape([-1, B, C // self.mem_dim, self.mem_dim])
                     obj_ptrs = obj_ptrs.transpose(perm=[0, 2, 1, 3]).flatten(start_axis=0, stop_axis=1)
                     obj_pos = obj_pos.repeat_interleave(repeats=C // self.mem_dim, axis=0)
                 to_cat_memory.append(obj_ptrs)
@@ -479,7 +479,7 @@ def _prepare_memory_conditioned_features(
         else:
             if self.directly_add_no_mem_embed:
                 pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
-                pix_feat_with_mem = pix_feat_with_mem.transpose(perm=[1, 2, 0]).view(B, C, H, W)
+                pix_feat_with_mem = pix_feat_with_mem.transpose(perm=[1, 2, 0]).reshape([B, C, H, W])
                 return pix_feat_with_mem
             to_cat_memory = [self.no_mem_embed.expand(shape=[1, B, self.mem_dim])]
@@ -494,7 +494,7 @@ def _prepare_memory_conditioned_features(
             memory_pos=memory_pos_embed,
             num_obj_ptr_tokens=num_obj_ptr_tokens,
         )
-        pix_feat_with_mem = pix_feat_with_mem.transpose(perm=[1, 2, 0]).view(B, C, H, W)
+        pix_feat_with_mem = pix_feat_with_mem.transpose(perm=[1, 2, 0]).reshape([B, C, H, W])
         return pix_feat_with_mem

     def _encode_new_memory(
@@ -504,7 +504,7 @@ def _encode_new_memory(
         B = current_vision_feats[-1].shape[1]
         C = self.hidden_dim
         H, W = feat_sizes[-1]
-        pix_feat = current_vision_feats[-1].transpose(perm=[1, 2, 0]).view(B, C, H, W)
+        pix_feat = current_vision_feats[-1].transpose(perm=[1, 2, 0]).reshape([B, C, H, W])
         if self.non_overlap_masks_for_mem_enc and not self.training:
             pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res)
         binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
@@ -543,14 +543,14 @@ def _track_step(
         current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
         if len(current_vision_feats) > 1:
             high_res_features = [
-                x.transpose(perm=[1, 2, 0]).view(x.shape[1], x.shape[2], *s)
+                x.transpose(perm=[1, 2, 0]).reshape([x.shape[1], x.shape[2], *s])
                 for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
             ]
         else:
             high_res_features = None
         if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
             pix_feat = current_vision_feats[-1].transpose(perm=[1, 2, 0])
-            pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
+            pix_feat = pix_feat.reshape([-1, self.hidden_dim, *feat_sizes[-1]])
             sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
         else:
             pix_feat = self._prepare_memory_conditioned_features(
diff --git a/paddlemix/models/sam2/sam2_image_predictor.py b/paddlemix/models/sam2/sam2_image_predictor.py
index 13d2c55a9..772e8d929 100644
--- a/paddlemix/models/sam2/sam2_image_predictor.py
+++ b/paddlemix/models/sam2/sam2_image_predictor.py
@@ -324,7 +324,7 @@ def _predict(
         else:
             concat_points = None
         if boxes is not None:
-            box_coords = boxes.reshape(-1, 2, 2)
+            box_coords = boxes.reshape([-1, 2, 2])
             box_labels = paddle.to_tensor(data=[[2, 3]], dtype="int32", place=boxes.place)
             box_labels = box_labels.repeat(boxes.shape[0], 1)
             if concat_points is not None:
diff --git a/paddlemix/models/sam2/sam2_video_predictor.py b/paddlemix/models/sam2/sam2_video_predictor.py
index 9186ac3aa..606904e18 100644
--- a/paddlemix/models/sam2/sam2_video_predictor.py
+++ b/paddlemix/models/sam2/sam2_video_predictor.py
@@ -174,9 +174,10 @@ def add_new_points_or_box(
             )
         if not isinstance(box, paddle.Tensor):
             box = paddle.to_tensor(data=box, dtype="float32", place=points.place)
-        box_coords = box.reshape(1, 2, 2)
+
+        box_coords = box.reshape([1, 2, 2])
         box_labels = paddle.to_tensor(data=[2, 3], dtype="int32", place=labels.place)
-        box_labels = box_labels.reshape(1, 2)
+        box_labels = box_labels.reshape([1, 2])
         points = paddle.concat(x=[box_coords, points], axis=1)
         labels = paddle.concat(x=[box_labels, labels], axis=1)
         if normalize_coords:
diff --git a/paddlemix/models/sam2/utils/transforms.py b/paddlemix/models/sam2/utils/transforms.py
index df5ad2de2..2dea5815a 100644
--- a/paddlemix/models/sam2/utils/transforms.py
+++ b/paddlemix/models/sam2/utils/transforms.py
@@ -68,7 +68,8 @@ def transform_boxes(self, boxes: paddle.Tensor, normalize=False, orig_hw=None) -
         Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates,
         if the coords are in absolute image coordinates, normalize should be set to True and original image size
         is required.
         """
-        boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw)
+
+        boxes = self.transform_coords(boxes.reshape([-1, 2, 2]), normalize, orig_hw)
         return boxes

     def postprocess_masks(self, masks: paddle.Tensor, orig_hw) -> paddle.Tensor:
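The `reshape([-1, 2, 2])` seen in the prompt-encoder, predictor, and transforms hunks reinterprets each `Bx4` box as two `(x, y)` corner points, which is the layout the box-prompt path works with (the corners are then embedded as points with labels 2 and 3, per the hunks above). An illustrative check (values are made up):

```python
import paddle

boxes = paddle.to_tensor([[10.0, 20.0, 110.0, 220.0],
                          [30.0, 40.0, 130.0, 240.0]])  # (B, 4), xyxy format

corners = boxes.reshape([-1, 2, 2])  # (B, 2, 2): top-left and bottom-right points
print(corners[0].numpy())  # [[ 10.  20.] [110. 220.]]
```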