Skip to content

Commit

Permalink
Sam2 (#956)
Browse files Browse the repository at this point in the history
  • Loading branch information
LokeZhou authored Jan 8, 2025
1 parent 2dfd2db commit 178bebe
Show file tree
Hide file tree
Showing 12 changed files with 40 additions and 40 deletions.
6 changes: 1 addition & 5 deletions paddlemix/examples/sam2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
<img src="https://github.com/user-attachments/assets/62626ba4-d81f-4c09-bc79-dc8310eddd5d" align="middle" width = "600" />
</p>


## 2. 快速开始

### 获取权重
Expand All @@ -29,8 +30,3 @@ python paddlemix/examples/sam2/grounded_sam2_tracking_demo.py \
--output_path output.mp4 \
--prompt "input your prompt here"
```





2 changes: 1 addition & 1 deletion paddlemix/examples/sam2/grounded_sam2_tracking_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
# filter output
logits_filt = logits.clone()
boxes_filt = boxes.clone()
filt_mask = logits_filt.max(axis=1)[0] > args.box_threshold
filt_mask = logits_filt.max(axis=1) > args.box_threshold
logits_filt = logits_filt[filt_mask] # num_filt, 256
boxes_filt = boxes_filt[filt_mask] # num_filt, 4

Expand Down
9 changes: 5 additions & 4 deletions paddlemix/models/sam2/modeling/backbones/hieradet.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,17 @@ def __init__(self, dim: int, dim_out: int, num_heads: int, q_pool: paddle.nn.Lay

def forward(self, x: paddle.Tensor) -> paddle.Tensor:
B, H, W, _ = tuple(x.shape)
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)

qkv = self.qkv(x).reshape([B, H * W, 3, self.num_heads, -1])
q, k, v = paddle.unbind(input=qkv, axis=2)
if self.q_pool:
q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
q = do_pool(q.reshape([B, H, W, -1]), self.q_pool)
H, W = tuple(q.shape)[1:3]
q = q.reshape(B, H * W, self.num_heads, -1)
q = q.reshape([B, H * W, self.num_heads, -1])

x = paddle.nn.functional.scaled_dot_product_attention_(q, k, v)

x = x.reshape(B, H, W, -1)
x = x.reshape([B, H, W, -1])
x = self.proj(x)
return x

Expand Down
8 changes: 4 additions & 4 deletions paddlemix/models/sam2/modeling/backbones/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def window_partition(x, window_size):
if pad_h > 0 or pad_w > 0:
x = paddle.nn.functional.pad(x=x, pad=(0, 0, 0, pad_w, 0, pad_h, 0, 0))
Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = x.transpose(perm=[0, 1, 3, 2, 4, 5]).view(-1, window_size, window_size, C)
x = x.reshape([B, Hp // window_size, window_size, Wp // window_size, window_size, C])
windows = x.transpose(perm=[0, 1, 3, 2, 4, 5]).reshape([-1, window_size, window_size, C])
return windows, (Hp, Wp)


Expand All @@ -54,8 +54,8 @@ def window_unpartition(windows, window_size, pad_hw, hw):
Hp, Wp = pad_hw
H, W = hw
B = tuple(windows.shape)[0] // (Hp * Wp // window_size // window_size)
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
x = x.transpose(perm=[0, 1, 3, 2, 4, 5]).view(B, Hp, Wp, -1)
x = windows.reshape([B, Hp // window_size, Wp // window_size, window_size, window_size, -1])
x = x.transpose(perm=[0, 1, 3, 2, 4, 5]).reshape([B, Hp, Wp, -1])
if Hp > H or Wp > W:
x = x[:, :H, :W, :]
return x
Expand Down
20 changes: 10 additions & 10 deletions paddlemix/models/sam2/modeling/position_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,24 +63,24 @@ def encode_points(self, x, y, labels):
(bx, nx), (by, ny), (bl, nl) = tuple(x.shape), tuple(y.shape), tuple(labels.shape)
assert bx == by and nx == ny and bx == bl and nx == nl
pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
pos_x, pos_y = pos_x.reshape([bx, nx, -1]), pos_y.reshape([by, ny, -1])
pos = paddle.concat(x=(pos_y, pos_x, labels[:, :, None]), axis=2)
return pos

@paddle.no_grad()
def forward(self, x: paddle.Tensor):
cache_key = tuple(x.shape)[-2], tuple(x.shape)[-1]
if cache_key in self.cache:
return self.cache[cache_key][None].repeat(tuple(x.shape)[0], 1, 1, 1)
return self.cache[cache_key][None].tile([tuple(x.shape)[0], 1, 1, 1])
y_embed = (
paddle.arange(start=1, end=tuple(x.shape)[-2] + 1, dtype="float32")
.view(1, -1, 1)
.repeat(tuple(x.shape)[0], 1, tuple(x.shape)[-1])
.reshape([1, -1, 1])
.tile((tuple(x.shape)[0], 1, tuple(x.shape)[-1]))
)
x_embed = (
paddle.arange(start=1, end=tuple(x.shape)[-1] + 1, dtype="float32")
.view(1, 1, -1)
.repeat(tuple(x.shape)[0], tuple(x.shape)[-2], 1)
.reshape([1, 1, -1])
.tile((tuple(x.shape)[0], tuple(x.shape)[-2], 1))
)
if self.normalize:
eps = 1e-06
Expand Down Expand Up @@ -165,13 +165,13 @@ def reshape_for_broadcast(freqs_cis: paddle.Tensor, x: paddle.Tensor):
assert 0 <= 1 < ndim
assert tuple(freqs_cis.shape) == (tuple(x.shape)[-2], tuple(x.shape)[-1])
shape = [(d if i >= ndim - 2 else 1) for i, d in enumerate(tuple(x.shape))]
return freqs_cis.view(*shape)
return freqs_cis.reshape(shape)


def apply_rotary_enc(xq: paddle.Tensor, xk: paddle.Tensor, freqs_cis: paddle.Tensor, repeat_freqs_k: bool = False):
xq_ = paddle.as_complex(x=xq.astype(dtype="float32").reshape(*tuple(xq.shape)[:-1], -1, 2))
xq_ = paddle.as_complex(x=xq.astype(dtype="float32").reshape([*tuple(xq.shape)[:-1], -1, 2]))
xk_ = (
paddle.as_complex(x=xk.astype(dtype="float32").reshape(*tuple(xk.shape)[:-1], -1, 2))
paddle.as_complex(x=xk.astype(dtype="float32").reshape([*tuple(xk.shape)[:-1], -1, 2]))
if tuple(xk.shape)[-2] != 0
else None
)
Expand All @@ -182,7 +182,7 @@ def apply_rotary_enc(xq: paddle.Tensor, xk: paddle.Tensor, freqs_cis: paddle.Ten
if repeat_freqs_k:
r = tuple(xk_.shape)[-2] // tuple(xq_.shape)[-2]
if "gpu" in str(freqs_cis.place):
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
freqs_cis = freqs_cis.tile((*([1] * (freqs_cis.ndim - 2)), r, 1))
else:
freqs_cis = (
freqs_cis.unsqueeze(axis=2).expand(shape=[-1, -1, r, -1, -1]).flatten(start_axis=2, stop_axis=3)
Expand Down
4 changes: 2 additions & 2 deletions paddlemix/models/sam2/modeling/sam/mask_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def predict_masks(
perm_53 = list(range(x.ndim))
perm_53[1] = 2
perm_53[2] = 1
src = x.transpose(perm=perm_53).view(b, c, h, w)
src = x.transpose(perm=perm_53).reshape([b, c, h, w])
if not self.use_high_res_features:
upscaled_embedding = self.output_upscaling(src)
else:
Expand All @@ -201,7 +201,7 @@ def predict_masks(
hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
hyper_in = paddle.stack(x=hyper_in_list, axis=1)
b, c, h, w = tuple(upscaled_embedding.shape)
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
masks = (hyper_in @ upscaled_embedding.reshape([b, c, h * w])).reshape([b, -1, h, w])
iou_pred = self.iou_prediction_head(iou_token_out)
if self.pred_obj_scores:
assert s == 1
Expand Down
4 changes: 2 additions & 2 deletions paddlemix/models/sam2/modeling/sam/prompt_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def _embed_points(self, points: paddle.Tensor, labels: paddle.Tensor, pad: bool)
def _embed_boxes(self, boxes: paddle.Tensor) -> paddle.Tensor:
"""Embeds box prompts."""
boxes = boxes + 0.5
coords = boxes.reshape(-1, 2, 2)
coords = boxes.reshape([-1, 2, 2])
corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
Expand Down Expand Up @@ -171,7 +171,7 @@ def forward(
if masks is not None:
dense_embeddings = self._embed_masks(masks)
else:
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
dense_embeddings = self.no_mask_embed.weight.reshape([1, -1, 1, 1]).expand(
shape=[bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]]
)
return sparse_embeddings, dense_embeddings
5 changes: 3 additions & 2 deletions paddlemix/models/sam2/modeling/sam/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,15 +184,16 @@ def __init__(

def _separate_heads(self, x: paddle.Tensor, num_heads: int) -> paddle.Tensor:
b, n, c = tuple(x.shape)
x = x.reshape(b, n, num_heads, c // num_heads)
x = x.reshape([b, n, num_heads, c // num_heads])

return x.transpose([0, 2, 1, 3])

def _recombine_heads(self, x: paddle.Tensor) -> paddle.Tensor:
b, n_heads, n_tokens, c_per_head = tuple(x.shape)

x = x.transpose(perm=[0, 2, 1, 3])
return x.reshape(b, n_tokens, n_heads * c_per_head)

return x.reshape([b, n_tokens, n_heads * c_per_head])

def forward(self, q: paddle.Tensor, k: paddle.Tensor, v: paddle.Tensor) -> paddle.Tensor:
q = self.q_proj(q)
Expand Down
12 changes: 6 additions & 6 deletions paddlemix/models/sam2/modeling/sam2_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ def _prepare_memory_conditioned_features(
else:
obj_pos = paddle.zeros(shape=[len(pos_list), B, self.mem_dim], dtype=obj_ptrs.dtype)
if self.mem_dim < C:
obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim)
obj_ptrs = obj_ptrs.reshape([-1, B, C // self.mem_dim, self.mem_dim])
obj_ptrs = obj_ptrs.transpose(perm=[0, 2, 1, 3]).flatten(start_axis=0, stop_axis=1)
obj_pos = obj_pos.repeat_interleave(repeats=C // self.mem_dim, axis=0)
to_cat_memory.append(obj_ptrs)
Expand All @@ -479,7 +479,7 @@ def _prepare_memory_conditioned_features(
else:
if self.directly_add_no_mem_embed:
pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
pix_feat_with_mem = pix_feat_with_mem.transpose(perm=[1, 2, 0]).view(B, C, H, W)
pix_feat_with_mem = pix_feat_with_mem.transpose(perm=[1, 2, 0]).reshape([B, C, H, W])
return pix_feat_with_mem

to_cat_memory = [self.no_mem_embed.expand(shape=[1, B, self.mem_dim])]
Expand All @@ -494,7 +494,7 @@ def _prepare_memory_conditioned_features(
memory_pos=memory_pos_embed,
num_obj_ptr_tokens=num_obj_ptr_tokens,
)
pix_feat_with_mem = pix_feat_with_mem.transpose(perm=[1, 2, 0]).view(B, C, H, W)
pix_feat_with_mem = pix_feat_with_mem.transpose(perm=[1, 2, 0]).reshape([B, C, H, W])
return pix_feat_with_mem

def _encode_new_memory(
Expand All @@ -504,7 +504,7 @@ def _encode_new_memory(
B = current_vision_feats[-1].shape[1]
C = self.hidden_dim
H, W = feat_sizes[-1]
pix_feat = current_vision_feats[-1].transpose(perm=[1, 2, 0]).view(B, C, H, W)
pix_feat = current_vision_feats[-1].transpose(perm=[1, 2, 0]).reshape([B, C, H, W])
if self.non_overlap_masks_for_mem_enc and not self.training:
pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res)
binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
Expand Down Expand Up @@ -543,14 +543,14 @@ def _track_step(
current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
if len(current_vision_feats) > 1:
high_res_features = [
x.transpose(perm=[1, 2, 0]).view(x.shape[1], x.shape[2], *s)
x.transpose(perm=[1, 2, 0]).reshape([x.shape[1], x.shape[2], *s])
for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
]
else:
high_res_features = None
if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
pix_feat = current_vision_feats[-1].transpose(perm=[1, 2, 0])
pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
pix_feat = pix_feat.reshape([-1, self.hidden_dim, *feat_sizes[-1]])
sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
else:
pix_feat = self._prepare_memory_conditioned_features(
Expand Down
2 changes: 1 addition & 1 deletion paddlemix/models/sam2/sam2_image_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def _predict(
else:
concat_points = None
if boxes is not None:
box_coords = boxes.reshape(-1, 2, 2)
box_coords = boxes.reshape([-1, 2, 2])
box_labels = paddle.to_tensor(data=[[2, 3]], dtype="int32", place=boxes.place)
box_labels = box_labels.repeat(boxes.shape[0], 1)
if concat_points is not None:
Expand Down
5 changes: 3 additions & 2 deletions paddlemix/models/sam2/sam2_video_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,10 @@ def add_new_points_or_box(
)
if not isinstance(box, paddle.Tensor):
box = paddle.to_tensor(data=box, dtype="float32", place=points.place)
box_coords = box.reshape(1, 2, 2)

box_coords = box.reshape([1, 2, 2])
box_labels = paddle.to_tensor(data=[2, 3], dtype="int32", place=labels.place)
box_labels = box_labels.reshape(1, 2)
box_labels = box_labels.reshape([1, 2])
points = paddle.concat(x=[box_coords, points], axis=1)
labels = paddle.concat(x=[box_labels, labels], axis=1)
if normalize_coords:
Expand Down
3 changes: 2 additions & 1 deletion paddlemix/models/sam2/utils/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ def transform_boxes(self, boxes: paddle.Tensor, normalize=False, orig_hw=None) -
Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates.
If the coords are in absolute image coordinates, normalize should be set to True and the original image size is required.
"""
boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw)

boxes = self.transform_coords(boxes.reshape([-1, 2, 2]), normalize, orig_hw)
return boxes

def postprocess_masks(self, masks: paddle.Tensor, orig_hw) -> paddle.Tensor:
Expand Down

0 comments on commit 178bebe

Please sign in to comment.