Skip to content

Commit

Permalink
feat(BeitV2):add finetune and pretrain code
Browse files Browse the repository at this point in the history
  • Loading branch information
oozhuzaioo committed May 17, 2023
1 parent 5d06a88 commit d9b56ca
Show file tree
Hide file tree
Showing 20 changed files with 2,862 additions and 10 deletions.
22 changes: 21 additions & 1 deletion ppcls/arch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import copy
import importlib
import paddle
import paddle.nn as nn
from paddle.jit import to_static
from paddle.static import InputSpec
Expand All @@ -28,7 +29,7 @@
from .slim import prune_model, quantize_model
from .distill.afd_attention import LinearTransformStudent, LinearTransformTeacher

__all__ = ["build_model", "RecModel", "DistillationModel", "AttentionModel"]
__all__ = ["build_model", "RecModel", "DistillationModel", "AttentionModel", "Beitv2Model"]


def build_model(config, mode="train"):
Expand Down Expand Up @@ -168,4 +169,23 @@ def forward(self, x, label=None):
else:
out = self.model_list[idx](out, label)
result_dict.update(out)
return result_dict

class Beitv2Model(DistillationModel):
def __init__(self,
models=None,
pretrained_list=None,
freeze_params_list=None,
**kargs):
super().__init__(models, pretrained_list, freeze_params_list, **kargs)
def forward(self, samples, images, bool_masked):
result_dict = dict()
for idx, model_name in enumerate(self.model_name_list):
bool_masked_pos = bool_masked.flatten(1).astype(paddle.bool)
if model_name == "Teacher":
with paddle.no_grad():
input_ids = self.model_list[idx].get_codebook_indices(images)
result_dict[model_name] = input_ids[bool_masked_pos]
else:
result_dict[model_name] = self.model_list[idx](samples, bool_masked_pos)
return result_dict
3 changes: 3 additions & 0 deletions ppcls/arch/backbone/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@
from .model_zoo.convnext import ConvNeXt_tiny, ConvNeXt_small, ConvNeXt_base_224, ConvNeXt_base_384, ConvNeXt_large_224, ConvNeXt_large_384
from .model_zoo.nextvit import NextViT_small_224, NextViT_base_224, NextViT_large_224, NextViT_small_384, NextViT_base_384, NextViT_large_384
from .model_zoo.cae import cae_base_patch16_224, cae_large_patch16_224
from .model_zoo.vqkd import vqkd_encoder_base_decoder_3x768x12_clip
from .model_zoo.modeling_pretrain import beit_base_patch16_224_8k_vocab_cls_pt
from .model_zoo.modeling_finetune import beit_base_patch16_224

from .variant_models.resnet_variant import ResNet50_last_stage_stride1
from .variant_models.resnet_variant import ResNet50_adaptive_max_pool2d
Expand Down
Loading

0 comments on commit d9b56ca

Please sign in to comment.