forked from PaddlePaddle/PaddleMIX
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinternvl_chat_finetune.py
894 lines (787 loc) · 38.9 KB
/
internvl_chat_finetune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import json
import logging
import math
import os
import random
import sys
import traceback
import warnings
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Dict, Optional
import numpy as np
import paddle
import paddle.distributed as dist
from paddlemix.models.internvl2.internlm2.modeling_internlm2 import InternLM2ForCausalLM
from paddlemix.models.internvl2.internvl_chat import (InternVisionConfig,
InternVisionModel,
InternVLChatConfig,
InternVLChatModel)
from paddlemix.models.internvl2.patch import concat_pad_data_collator
from paddlemix.models.internvl2.constants import (BOX_END_TOKEN, BOX_START_TOKEN,
IMG_CONTEXT_TOKEN, IMG_END_TOKEN,
IMG_START_TOKEN, QUAD_END_TOKEN,
QUAD_START_TOKEN, REF_END_TOKEN,
REF_START_TOKEN)
from paddlemix.datasets.internvl_dataset import (ConcatDataset,
WeightedConcatDataset, build_transform,
dynamic_preprocess, preprocess,
preprocess_internlm, preprocess_mpt,
preprocess_phi3)
from PIL import Image, ImageFile, PngImagePlugin, UnidentifiedImageError
from paddle.io import Dataset
from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from paddlenlp.trainer import TrainingArguments, PdArgumentParser, set_seed
from paddlenlp.trainer.trainer_utils import get_last_checkpoint, is_main_process
from paddlenlp.transformers import AutoTokenizer, Qwen2Tokenizer, LlamaTokenizer, Llama3Tokenizer
from paddlenlp.trainer.trainer import Trainer
# Constants for image processing and logging.
IGNORE_INDEX = -100

# Relax PIL's safety limits: disable the decompression-bomb pixel limit and
# allow truncated image files to be loaded.
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Allow PNG text chunks of up to MaximumDecompressedSize MiB.
MegaByte = 1 << 20
MaximumDecompressedSize = 1024
PngImagePlugin.MAX_TEXT_CHUNK = MegaByte * MaximumDecompressedSize

# Silence all warnings process-wide and create the module-level logger.
warnings.filterwarnings('ignore')
logger = logging.getLogger(__name__)
@dataclass
class ModelArguments:
    """
    Arguments for specifying model, tokenizer, and configurations.

    Either ``model_name_or_path`` (a complete InternVL chat checkpoint) is
    given, or the model is assembled from ``vision_path`` + ``llm_path``
    (optionally with a pretrained projector from ``mlp_path``).
    """
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={'help': 'Path to pretrained model or model identifier from huggingface.co/models'}
    )
    vision_path: Optional[str] = field(
        default=None,
        metadata={'help': 'Path to the pretrained vision model. Used only when `model_name_or_path` is not set.'}
    )
    llm_path: Optional[str] = field(
        default=None,
        metadata={'help': 'Path to the pretrained LLM. Used only when `model_name_or_path` is not set; '
                          'the tokenizer is then also loaded from this path.'}
    )
    mlp_path: Optional[str] = field(
        default=None,
        metadata={'help': 'Path to pretrained MLP projector weights (a state dict loaded into `mlp1`).'}
    )
    freeze_llm: bool = field(
        default=False,
        metadata={'help': 'Set to True to freeze the LLM decoder.'},
    )
    freeze_backbone: bool = field(
        default=False,
        metadata={'help': 'Set to True to freeze the vision backbone of the model.'},
    )
    freeze_mlp: bool = field(
        default=False,
        metadata={'help': 'Set to True to freeze the MLP layers of the model.'},
    )
    unfreeze_vit_layers: int = field(
        default=0,
        metadata={'help': 'Specify the number of ViT layers to unfreeze. Default is 0.'},
    )
    vision_select_layer: int = field(
        default=-1,
        metadata={'help': 'Specify the layer of ViT feature map to use. Default is last layer.'},
    )
    # LoRA ranks; 0 disables LoRA for the corresponding sub-model.
    use_backbone_lora: int = field(
        default=0,
        metadata={'help': 'Set the LoRA adapter rank for the backbone model. Default is 0.'}
    )
    use_llm_lora: int = field(
        default=0,
        metadata={'help': 'Set the LoRA adapter rank for the LLM. Default is 0.'}
    )
    unfreeze_lm_head: bool = field(
        default=False,
        metadata={'help': "Set to True to unfreeze the language model's head."},
    )
    use_custom_trainer: bool = field(
        default=False,
        metadata={'help': 'Set to True to enable the use of a custom trainer.'},
    )
    grad_checkpoint: Optional[bool] = field(
        default=False,
        metadata={'help': 'Set to True to use gradient checkpointing.'},
    )
    drop_path_rate: float = field(
        default=0.0,
        metadata={'help': 'Set the drop path rate for the ViT model. Default is 0.'},
    )
    # The help text previously claimed the default was `v1`; it is `v2`.
    ps_version: str = field(
        default='v2',
        metadata={'help': 'Specify the version of pixel shuffle implementation. Default is `v2`, '
                          'which fixes the transposed-image bug of `v1`.'}
    )
@dataclass
class DataTrainingArguments:
    """
    Arguments for specifying data input for training and evaluation.

    Every ``Default is ...`` statement in a help string below matches the
    field's actual ``default`` value.
    """
    max_seq_length: Optional[int] = field(
        default=2048,
        metadata={
            'help': (
                'The maximum total input sequence length after tokenization. Sequences longer '
                'than this will be truncated, sequences shorter will be padded.'
            )
        },
    )
    force_image_size: Optional[int] = field(
        default=448,
        metadata={'help': 'Set the desired size for the image. Default is 448.'},
    )
    down_sample_ratio: Optional[float] = field(
        default=0.5,
        metadata={'help': 'Set the desired down-sampling ratio for the image. Default is 0.5.'},
    )
    pad2square: Optional[bool] = field(
        default=False,
        metadata={'help': 'Pad the image to a square shape if set to True.'},
    )
    conv_style: Optional[str] = field(
        default='internlm2-chat',
        metadata={'help': 'Prompt style for a conversation (e.g. `internlm2-chat`, `Hermes-2`, `phi3-chat`).'}
    )
    meta_path: Optional[str] = field(
        default=None,
        metadata={'help': 'The path of the meta file of datasets.'},
    )
    use_data_resampling: Optional[bool] = field(
        default=False,
        metadata={'help': 'Set to True to sample each dataset with a weight proportional to the '
                          'square root of its size.'},
    )
    dynamic_image_size: Optional[bool] = field(
        default=False,
        metadata={'help': 'Set to True to use dynamic image size.'},
    )
    use_thumbnail: Optional[bool] = field(
        default=False,
        metadata={'help': 'Set to True to add a thumbnail image.'},
    )
    min_dynamic_patch: Optional[int] = field(
        default=1,
        metadata={'help': 'The minimum number of dynamic patches. Default is 1.'},
    )
    max_dynamic_patch: Optional[int] = field(
        default=12,
        metadata={'help': 'The maximum number of dynamic patches. Default is 12.'},
    )
    normalize_type: Optional[str] = field(
        default='imagenet',
        metadata={'help': 'The normalize type for the image. Default is imagenet.'},
    )
@dataclass
class PreTrainingArguments(TrainingArguments):
    """
    Arguments pertaining to what training options we are going to use during pretraining.

    Extends ``TrainingArguments`` and declares the fields below with ``True``
    defaults for this fine-tuning script.
    """
    group_by_length: bool = field(
        default=True,
        metadata={"help": "Whether to group samples of roughly the same length together when "
                          "batching, to reduce padding."},
    )
    save_safetensors: bool = field(
        default=True,
        metadata={"help": "Whether to save checkpoints in the safetensors format."},
    )
class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    Annotation lines are read eagerly from a ``.jsonl`` file but are decoded
    and turned into model inputs lazily in ``__getitem__``. Each record is
    dispatched on its keys:

    * a list under ``'image'`` -> multi-image sample,
    * any other non-empty ``'image'`` -> single-image sample,
    * a non-empty ``'video'`` -> video sample,
    * otherwise -> text-only sample, paired with a blank placeholder image
      whose ``image_flags`` are 0.

    Args:
        template_name: conversation template; selects the preprocess function.
        meta: per-dataset meta entry; ``meta['annotation']`` is the ``.jsonl``
            path and ``meta['root']`` is the prefix for media paths.
        tokenizer: tokenizer forwarded to the preprocess functions.
        tcs_loader: optional callable used for ``s3://`` images and required
            for videos; when ``None`` images are opened with PIL.
        ds_name: dataset name used in logs and forwarded to preprocessing.
        num_image_token: number of image tokens per image patch.
        image_size: side length passed to the transform and the tiler.
        is_train, pad2square, normalize_type: forwarded to ``build_transform``.
        group_by_length: if True, precompute ``self.length`` (an estimated token
            count per sample) and forward the flag to preprocessing.
        dynamic_image_size: if True, split images into tiles with
            ``dynamic_preprocess``; otherwise each image is a single patch.
        use_thumbnail: forwarded to ``dynamic_preprocess``.
        min_dynamic_patch, max_dynamic_patch: bounds on the tile count.
        min_num_frame, max_num_frame, sampling_method: video frame-sampling
            options forwarded to ``tcs_loader``.
        repeat_time: a value < 1 keeps that fraction of the lines; an int > 1
            repeats the lines that many times.
        random_seed: seed for the one-time shuffle of the annotation lines.
    """

    def __init__(
        self,
        template_name,
        meta,
        tokenizer,
        tcs_loader,  # None
        ds_name,
        num_image_token,
        image_size=224,
        is_train=True,
        pad2square=False,
        group_by_length=False,
        dynamic_image_size=False,
        use_thumbnail=False,
        min_dynamic_patch=1,
        max_dynamic_patch=6,
        min_num_frame=4,  # for video data
        max_num_frame=12,  # for video data
        sampling_method='rand',  # for video data
        repeat_time=1,
        normalize_type='imagenet',
        random_seed=0,
    ):
        super().__init__()
        self.ds_name = ds_name
        self.tokenizer = tokenizer
        self.template_name = template_name
        self.num_image_token = num_image_token
        logger.info(f'[Dataset] num_image_token: {num_image_token}')
        logger.info(f'[Dataset] dynamic_image_size: {dynamic_image_size}')
        logger.info(f'[Dataset] use_thumbnail: {use_thumbnail}')
        logger.info(f'[Dataset] min_dynamic_patch: {min_dynamic_patch}, max_dynamic_patch: {max_dynamic_patch}')
        self.image_size = image_size
        self.is_train = is_train
        self.pad2square = pad2square
        self.max_num_frame = max_num_frame
        self.min_num_frame = min_num_frame
        self.sampling_method = sampling_method
        logger.info('Formatting inputs...Skip in lazy mode')
        assert meta['annotation'].endswith('jsonl'), f'annotation must be jsonl, but got {meta["annotation"]}'
        # Explicit UTF-8 so non-ASCII conversations do not depend on the locale.
        with open(meta['annotation'], 'r', encoding='utf-8') as f:
            self.raw_data = f.readlines()
        if repeat_time < 1:
            # Keep only the leading fraction of the data.
            self.raw_data = self.raw_data[:int(len(self.raw_data) * repeat_time)]
        if repeat_time > 1:
            assert isinstance(repeat_time, int)
            self.raw_data = self.raw_data * repeat_time
        self.rng = np.random.default_rng(seed=random_seed)
        self.rng.shuffle(self.raw_data)
        gc.collect()
        self.root = meta['root']
        self.cached_data_dict = {}
        self.tcs_loader = tcs_loader
        self.group_by_length = group_by_length
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail = use_thumbnail
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch
        self.normalize_type = normalize_type
        # If the precomputed length does not exist, roughly estimate the length of
        # each sample to improve the efficiency of group_by_length.
        if self.group_by_length:
            # Cache keyed by character count: a deliberate approximation so that
            # conversations of equal string length reuse one tokenizer call.
            self.conv2length = {}
            self.length = []
            # Every estimate is budgeted for the maximum number of image patches.
            image_token_budget = num_image_token * (max_dynamic_patch + int(use_thumbnail))
            for line in self.raw_data:
                data_item = json.loads(line)
                if 'length' in data_item:
                    token_length = data_item['length']  # Use precomputed length if available
                else:
                    conversations = '\n'.join([temp['value'] for temp in data_item['conversations']])
                    str_length = len(conversations)
                    if str_length not in self.conv2length:
                        text_length = tokenizer(
                            conversations, return_tensors='pd', padding=False, truncation=False,
                        ).input_ids.shape[1]
                        self.conv2length[str_length] = text_length + image_token_budget
                    # Always read from the cache so the first and later samples of
                    # the same length get the same (image-inclusive) estimate.
                    token_length = self.conv2length[str_length]
                self.length.append(token_length)
        gc.collect()

    def __len__(self):
        return len(self.raw_data)

    def get_preprocess_function(self):
        """Return the conversation preprocess function for ``self.template_name``."""
        template_specific = {
            'Hermes-2': preprocess_mpt,
            'internlm2-chat': preprocess_internlm,
            'phi3-chat': preprocess_phi3,
        }
        return template_specific.get(self.template_name, preprocess)

    def load_image(self, image_path):
        """Load an image as RGB, via ``tcs_loader`` for ``s3://`` paths when available."""
        if self.tcs_loader is not None and 's3://' in image_path:
            return self.tcs_loader(image_path)
        # The context manager releases the file handle; ``convert`` returns an
        # independent image, so it stays valid after the source is closed.
        with Image.open(image_path) as img:
            return img.convert('RGB')

    def get_image_path(self, image_path):
        """Resolve a media path relative to ``self.root``."""
        if image_path.startswith('s3://'):  # for ceph
            return self.root + image_path
        return os.path.join(self.root, image_path)  # for local image

    def get_transform(self):
        """Build the per-patch image transform."""
        return build_transform(is_train=self.is_train, input_size=self.image_size,
                               pad2square=self.pad2square, normalize_type=self.normalize_type)

    def _to_pixel_values(self, images):
        """Transform every image/tile and stack them along a new leading axis."""
        transform = self.get_transform()
        return paddle.stack([transform(img) for img in images])

    @staticmethod
    def _build_sample(ret, pixel_values, num_patches, image_flag=1):
        """Pack the first (only) preprocessed conversation with its pixel values."""
        return dict(
            input_ids=ret['input_ids'][0],
            labels=ret['labels'][0],
            attention_mask=ret['attention_mask'][0],
            pixel_values=pixel_values,
            image_flags=paddle.to_tensor([image_flag] * num_patches, dtype=paddle.int64),
        )

    def multi_modal_get_item(self, data_item):
        """Build a sample from a record with a single image path."""
        # Ensure the first conversation contains an image placeholder
        first_turn = data_item['conversations'][0]
        if '<image>' not in first_turn['value']:
            first_turn['value'] = '<image>\n' + first_turn['value']
        image = self.load_image(self.get_image_path(data_item['image']))
        if self.dynamic_image_size:
            images = dynamic_preprocess(image, min_num=self.min_dynamic_patch, max_num=self.max_dynamic_patch,
                                        image_size=self.image_size, use_thumbnail=self.use_thumbnail)
        else:
            images = [image]
        pixel_values = self._to_pixel_values(images)
        num_patches = pixel_values.shape[0]
        if not self.dynamic_image_size:
            assert num_patches == 1, f'The number of patches should be 1, but got {num_patches}.'
        preprocess_function = self.get_preprocess_function()
        ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
                                  self.tokenizer, [self.num_image_token * num_patches],
                                  group_by_length=self.group_by_length, ds_name=self.ds_name)
        return self._build_sample(ret, pixel_values, num_patches)

    def multi_modal_multi_image_get_item(self, data_item):
        """Build a sample from a record whose ``'image'`` is a list of paths."""
        images, num_tiles = [], []
        num_image = len(data_item['image'])
        # Split the tile budget across images, but give every image at least one tile.
        max_tiles_per_image = max(1, self.max_dynamic_patch // num_image)
        for image_path in data_item['image']:
            image = self.load_image(self.get_image_path(image_path))
            if self.dynamic_image_size:
                tiles = dynamic_preprocess(image, min_num=self.min_dynamic_patch,
                                           max_num=max_tiles_per_image,
                                           image_size=self.image_size, use_thumbnail=self.use_thumbnail)
                images += tiles
                num_tiles.append(len(tiles))
            else:
                images.append(image)
                num_tiles.append(1)
        pixel_values = self._to_pixel_values(images)
        num_patches = pixel_values.shape[0]
        preprocess_function = self.get_preprocess_function()
        num_image_tokens = [self.num_image_token * num_tile for num_tile in num_tiles]
        ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
                                  self.tokenizer, num_image_tokens, group_by_length=self.group_by_length,
                                  ds_name=self.ds_name, num_image=num_image)
        return self._build_sample(ret, pixel_values, num_patches)

    def video_get_item(self, data_item):
        """Build a sample from a video record; each sampled frame is one patch."""
        # TODO: Load videos without using tcsloader.
        if self.tcs_loader is None:
            raise RuntimeError(f'Loading videos requires tcs_loader, but it is None (dataset: {self.ds_name}).')
        # Ensure the first conversation contains a video placeholder
        first_turn = data_item['conversations'][0]
        if '<video>' not in first_turn['value']:
            first_turn['value'] = '<video>\n' + first_turn['value']
        video_path = os.path.join(self.root, data_item['video'])
        image_list = self.tcs_loader(
            video_path,
            image_type='video',
            max_num_frames=self.max_num_frame,
            min_num_frames=self.min_num_frame,
            sample=self.sampling_method,
            clip=data_item.get('clip', None))
        # Replace the video placeholder with one "FrameN: <image>" line per frame.
        special_tokens = '\n'.join(['Frame{}: <image>'.format(i + 1) for i in range(len(image_list))])
        first_turn['value'] = first_turn['value'].replace('<video>\n', special_tokens)
        pixel_values = self._to_pixel_values(image_list)
        num_patches = pixel_values.shape[0]
        preprocess_function = self.get_preprocess_function()
        num_image_tokens = [self.num_image_token] * num_patches
        ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
                                  self.tokenizer, num_image_tokens, group_by_length=self.group_by_length,
                                  ds_name=self.ds_name, num_image=num_patches)
        return self._build_sample(ret, pixel_values, num_patches)

    def pure_text_get_item(self, data_item):
        """Build a text-only sample paired with one blank placeholder patch (flag 0)."""
        image = Image.new('RGB', (224, 224), (255, 255, 255))
        images = dynamic_preprocess(image, min_num=self.min_dynamic_patch, max_num=1,
                                    image_size=self.image_size, use_thumbnail=self.use_thumbnail)
        pixel_values = self._to_pixel_values(images)
        num_patches = pixel_values.shape[0]
        assert num_patches == 1, f'The number of patches should be 1, but got {num_patches}.'
        preprocess_function = self.get_preprocess_function()
        ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
                                  self.tokenizer, [self.num_image_token * num_patches], text_only=True,
                                  group_by_length=self.group_by_length, ds_name=self.ds_name)
        return self._build_sample(ret, pixel_values, num_patches, image_flag=0)

    def _report_failed_item(self, i):
        """Best-effort print of the media path of sample ``i`` after a load failure.

        Never raises: a malformed annotation line is reported instead of
        propagating out of ``__getitem__``.
        """
        try:
            data_item = json.loads(self.raw_data[i])
        except ValueError:
            print(f'Failed to parse annotation line {i}, the dataset is: {self.ds_name}')
            return
        if not isinstance(data_item, dict):
            return
        image = data_item.get('image')
        video = data_item.get('video')
        if image:
            if isinstance(image, list):
                # Resolve exactly like the loaders do, so the logged path is the one opened.
                images = [self.get_image_path(item) for item in image]
                print(f'Failed to load image: {images}, the dataset is: {self.ds_name}')
            else:
                data_path = self.get_image_path(image)
                print(f'Failed to load image: {data_path}, the dataset is: {self.ds_name}')
        elif video:
            data_path = os.path.join(self.root, video)
            print(f'Failed to load video: {data_path}, the dataset is: {self.ds_name}')

    def __getitem__(self, i) -> Dict[str, paddle.Tensor]:
        """Return sample ``i`` (taken modulo the dataset length).

        If building the sample raises, the error is printed and a random other
        index is tried instead. Note that this loops forever if every sample fails.
        """
        i = i % len(self.raw_data)
        while True:
            try:
                data_item = json.loads(self.raw_data[i])
                image = data_item.get('image')
                if image:  # non-empty str or list; None/''/[] fall through
                    if isinstance(image, list):
                        return self.multi_modal_multi_image_get_item(data_item)
                    return self.multi_modal_get_item(data_item)
                if data_item.get('video') not in (None, ''):
                    return self.video_get_item(data_item)
                return self.pure_text_get_item(data_item)
            except Exception as e:
                print(e, self.ds_name, flush=True)
                if not isinstance(e, UnidentifiedImageError):
                    traceback.print_exc()
                self._report_failed_item(i)
                i = random.randint(0, len(self.raw_data) - 1)
def build_datasets(
    data_args,
    tokenizer,
    tcs_loader,
    model,
    group_by_length=False,
    dynamic_image_size=False,
    use_thumbnail=False,
    min_dynamic_patch=1,
    max_dynamic_patch=12,
    normalize_type='imagenet',
):
    """Build the training dataset described by the meta file ``data_args.meta_path``.

    The meta file is a JSON object mapping dataset names to per-dataset
    entries. Each entry must provide ``repeat_time`` and ``data_augment``,
    plus the keys ``LazySupervisedDataset`` reads from it. An optional
    ``max_dynamic_patch`` overrides the global value for that dataset.

    Args:
        data_args: ``DataTrainingArguments`` (meta path, conversation style,
            image size, padding and resampling options).
        tokenizer: tokenizer forwarded to every dataset.
        tcs_loader: optional media loader forwarded to every dataset.
        model: only ``model.num_image_token`` is read.
        Remaining keyword args are forwarded to ``LazySupervisedDataset``.

    Returns:
        A ``WeightedConcatDataset`` if ``data_args.use_data_resampling`` is set,
        with weights proportional to sqrt(len(dataset)). Otherwise a plain
        ``ConcatDataset``.

    Raises:
        ValueError: if resampling is enabled but all datasets are empty.
    """
    # Context manager closes the file; JSON is UTF-8 by specification.
    with open(data_args.meta_path, 'r', encoding='utf-8') as f:
        ds_collections = json.load(f)

    datasets = []
    lengths = []
    for ds_idx, (ds_name, ds_meta) in enumerate(ds_collections.items()):
        if 'max_dynamic_patch' in ds_meta:
            max_num = ds_meta['max_dynamic_patch']
            logger.info(f'max_dynamic_patch is set to {max_num} according to the meta file')
        else:
            max_num = max_dynamic_patch
        dataset = LazySupervisedDataset(
            data_args.conv_style,
            ds_meta,
            tokenizer,
            tcs_loader,
            ds_name=ds_name,
            num_image_token=model.num_image_token,
            image_size=data_args.force_image_size,
            is_train=ds_meta['data_augment'],
            pad2square=data_args.pad2square,
            group_by_length=group_by_length,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_num,
            repeat_time=ds_meta['repeat_time'],
            normalize_type=normalize_type,
            random_seed=ds_idx,  # distinct shuffle seed per dataset
        )
        logger.info(f'Add dataset: {ds_name} with length: {len(dataset)}')
        datasets.append(dataset)
        # With resampling, the square root damps the dominance of large datasets.
        lengths.append(math.sqrt(len(dataset)) if data_args.use_data_resampling else len(dataset))

    if not data_args.use_data_resampling:
        return ConcatDataset(datasets)

    total_length = sum(lengths)
    if total_length == 0:
        raise ValueError(f'All datasets listed in {data_args.meta_path} are empty; cannot compute sampling weights.')
    weights = [length / total_length for length in lengths]
    return WeightedConcatDataset(datasets, weights)
def main():
parser = PdArgumentParser((ModelArguments, DataTrainingArguments, PreTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
# If we pass only one argument to the script, and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
training_args.print_config(model_args, "Model")
training_args.print_config(data_args, "Data")
# Log on each process the small summary:
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16 or training_args.bf16}"
)
logger.info(f'Training/evaluation parameters:\n {training_args}')
# Detecting last checkpoint and eventually continue from last checkpoint.
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError(
f'Output directory ({training_args.output_dir}) already exists and is not empty. '
'Use --overwrite_output_dir to overcome.'
)
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(
f'Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change '
'the `--output_dir` or add `--overwrite_output_dir` to train from scratch.'
)
# Load model
if training_args.fp16_opt_level == "O2":
if training_args.fp16:
dtype = "float16"
elif training_args.bf16 and paddle.amp.is_bfloat16_supported():
dtype = "bfloat16"
else:
raise ValueError("Please specific dtype: --fp16 or --bf16")
else:
dtype = "float32"
# Set seed before initializing model.
set_seed(training_args.seed)
# Load pretrained model, tokenizer, and image processor
tokenizer_path = model_args.model_name_or_path or model_args.llm_path
print(f'Loading Tokenizer: {tokenizer_path}')
if 'qwen' in tokenizer_path.lower() or '1B' in tokenizer_path:
tokenizer = Qwen2Tokenizer.from_pretrained(
tokenizer_path, add_eos_token=False, trust_remote_code=True)
else:
from paddlemix.models.internvl2.internlm2 import InternLM2Tokenizer
tokenizer = InternLM2Tokenizer.from_pretrained(
tokenizer_path, add_eos_token=False, trust_remote_code=True)
# tokenizer = AutoTokenizer.from_pretrained(
# tokenizer_path, add_eos_token=False, trust_remote_code=True)
tokenizer.tokenizer_path = tokenizer_path
tokenizer.model_max_length = data_args.max_seq_length
# token_list = [IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN,
# QUAD_START_TOKEN, QUAD_END_TOKEN, REF_START_TOKEN,
# REF_END_TOKEN, BOX_START_TOKEN, BOX_END_TOKEN]
# num_new_tokens = tokenizer.add_tokens(token_list, special_tokens=True)
# img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN) ###
tcs_loader = None #TCSLoader('~/petreloss.conf') if has_tcs_loader else None
model_size = tokenizer_path.split('-')[-1]
if 'qwen' in tokenizer_path.lower() or model_size in ['1B']:
# TODO:
tokenizer.added_tokens_encoder = {'<|endoftext|>': 151643, '<|im_start|>': 151644, '<|im_end|>': 151645, '<img>': 151646, '</img>': 151647, '<IMG_CONTEXT>': 151648, '<quad>': 151649, '</quad>': 151650, '<ref>': 151651, '</ref>': 151652, '<box>': 151653, '</box>': 151654}
tokenizer.added_tokens_decoder = {v: k for k, v in tokenizer.added_tokens_encoder.items()}
elif model_size in ['2B', '8B', '26B']:
# TODO:
tokenizer.added_tokens_encoder = {'<unk>': 0, '<s>': 1, '</s>': 2, '<|plugin|>': 92538, '<|interpreter|>': 92539, '<|action_end|>': 92540, '<|action_start|>': 92541, '<|im_end|>': 92542, '<|im_start|>': 92543, '<img>': 92544, '</img>': 92545, '<IMG_CONTEXT>': 92546, '<quad>': 92547, '</quad>': 92548, '<ref>': 92549, '</ref>': 92550, '<box>': 92551, '</box>': 92552}
tokenizer.added_tokens_decoder = {v: k for k, v in tokenizer.added_tokens_encoder.items()}
elif model_size in ['4B', '40B', '76B']:
raise NotImplementedError
else:
raise ValueError
num_new_tokens = 0 #tokenizer.add_tokens(token_list, special_tokens=True)
img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN) ###
print('tokenizer', tokenizer)
print('num_new_tokens', num_new_tokens) # 0
print('img_context_token_id', img_context_token_id) # 92546
print('len(tokenizer)', len(tokenizer)) # 92553
print('tokenizer.added_tokens_encoder', tokenizer.added_tokens_encoder)
print('tokenizer.added_tokens_decoder', tokenizer.added_tokens_decoder)
# tokenizer.added_tokens_encoder {'<unk>': 0, '<s>': 1, '</s>': 2, '<|plugin|>': 92538, '<|interpreter|>': 92539, '<|action_end|>': 92540, '<|action_start|>': 92541, '<|im_end|>': 92542, '<|im_start|>': 92543, '<img>': 92544, '</img>': 92545, '<IMG_CONTEXT>': 92546, '<quad>': 92547, '</quad>': 92548, '<ref>': 92549, '</ref>': 92550, '<box>': 92551, '</box>': 92552}
if model_args.model_name_or_path is not None:
logger.info('Loading InternVLChatModel...')
config = InternVLChatConfig.from_pretrained(model_args.model_name_or_path)
config.vision_config.drop_path_rate = model_args.drop_path_rate
if config.llm_config.model_type == 'internlm2':
config.llm_config.attn_implementation = 'flash_attention_2' # for InternLM
logger.info('Using flash_attention_2 for InternLM')
else:
config.llm_config._attn_implementation = 'flash_attention_2' # for LLaMA
logger.info('Using flash_attention_2 for LLaMA')
config.template = data_args.conv_style
config.select_layer = model_args.vision_select_layer
config.dynamic_image_size = data_args.dynamic_image_size
config.use_thumbnail = data_args.use_thumbnail
config.ps_version = model_args.ps_version
config.min_dynamic_patch = data_args.min_dynamic_patch
config.max_dynamic_patch = data_args.max_dynamic_patch
model = InternVLChatModel.from_pretrained(
model_args.model_name_or_path, config=config, dtype=dtype) # TODO: fix dtype for layernorm
else:
logger.info('Loading ViT-6B...')
vision_config = InternVisionConfig.from_pretrained(model_args.vision_path)
vision_config.drop_path_rate = model_args.drop_path_rate
vision_model = InternVisionModel.from_pretrained(
model_args.vision_path, config=vision_config, dtype=dtype)
logger.info('Loading LLaMA...')
#llm_config = AutoConfig.from_pretrained(model_args.llm_path, trust_remote_code=True)
from paddlemix.models.internvl2.internlm2 import InternLM2Tokenizer
llm_config = InternLM2Tokenizer.from_pretrained(model_args.llm_path, trust_remote_code=True)
if llm_config.model_type == 'internlm2':
model_type = InternLM2ForCausalLM
llm_config.attn_implementation = 'flash_attention_2' # for InternLM
logger.info('Using flash_attention_2 for InternLM')
else:
model_type = AutoModelForCausalLM
llm_config._attn_implementation = 'flash_attention_2' # for LLaMA
logger.info('Using flash_attention_2 for LLaMA')
llm = model_type.from_pretrained(
model_args.llm_path, dtype=dtype,
config=llm_config, trust_remote_code=True)
logger.info('Building InternVLChatConfig...')
internvl_chat_config = InternVLChatConfig(
vision_config.to_dict(),
llm_config.to_dict(),
downsample_ratio=data_args.down_sample_ratio,
pad2square=data_args.pad2square,
template=data_args.conv_style,
select_layer=model_args.vision_select_layer,
dynamic_image_size=data_args.dynamic_image_size,
use_thumbnail=data_args.use_thumbnail,
ps_version=model_args.ps_version,
min_dynamic_patch=data_args.min_dynamic_patch,
max_dynamic_patch=data_args.max_dynamic_patch,
)
internvl_chat_config.force_image_size = data_args.force_image_size
logger.info('Building InternVLChatModel...')
model = InternVLChatModel(internvl_chat_config, vision_model, llm)
model.img_context_token_id = img_context_token_id
assert model.config.downsample_ratio == data_args.down_sample_ratio
if model_args.mlp_path is not None:
logger.info('Loading pretrained MLP projector...')
state_dict = paddle.load(model_args.mlp_path)
message = model.mlp1.load_state_dict(state_dict)
logger.info(message)
logger.info('Finished')
patch_size = model.config.vision_config.patch_size
logger.info(f'model.config.force_image_size: {model.config.force_image_size}')
logger.info(f'data_args.force_image_size: {data_args.force_image_size}')
logger.info(f'model.config.vision_config.image_size: {model.config.vision_config.image_size}')
if model.config.vision_config.image_size != data_args.force_image_size:
logger.info(f'Resizing position embedding from '
f'{model.config.vision_config.image_size} '
f'to {data_args.force_image_size}...')
model.vision_model.resize_pos_embeddings(old_size=model.config.vision_config.image_size,
new_size=data_args.force_image_size,
patch_size=patch_size)
model.config.vision_config.image_size = data_args.force_image_size
model.config.force_image_size = data_args.force_image_size
model.num_image_token = int((data_args.force_image_size // patch_size) ** 2 * (data_args.down_sample_ratio ** 2))
if num_new_tokens > 0: # false
model.language_model.resize_token_embeddings(len(tokenizer))
output_embeddings = model.language_model.get_output_embeddings().weight
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(axis=0, keepdim=True)
output_embeddings[-num_new_tokens:] = output_embeddings_avg
model.config.llm_config.vocab_size = len(tokenizer)
model.language_model.config.vocab_size = len(tokenizer)
model.language_model.config.use_cache = False
model.vision_model.gradient_checkpointing = True
model.vision_model.encoder.gradient_checkpointing = True
# if model_args.grad_checkpoint:
# model.language_model._set_gradient_checkpointing()
train_dataset = build_datasets(
data_args, tokenizer, tcs_loader, model, group_by_length=training_args.group_by_length,
dynamic_image_size=data_args.dynamic_image_size, use_thumbnail=data_args.use_thumbnail,
min_dynamic_patch=data_args.min_dynamic_patch, max_dynamic_patch=data_args.max_dynamic_patch,
normalize_type=data_args.normalize_type)
def _freeze_params(module):
for param in module.parameters():
param.stop_gradient = not False
if model_args.freeze_backbone:
_freeze_params(model.vision_model)
if model_args.freeze_llm:
model.language_model = model.language_model.eval()
_freeze_params(model.language_model)
if model_args.unfreeze_lm_head:
model.language_model.lm_head.stop_gradient = not True
if model_args.use_backbone_lora:
model.wrap_backbone_lora(r=model_args.use_backbone_lora, lora_alpha=2 * model_args.use_backbone_lora)
model.config.use_backbone_lora = model_args.use_backbone_lora
if model_args.use_llm_lora:
model.wrap_llm_lora(r=model_args.use_llm_lora, lora_alpha=2 * model_args.use_llm_lora)
model.config.use_llm_lora = model_args.use_llm_lora
if model_args.freeze_mlp:
_freeze_params(model.mlp1)
if model_args.unfreeze_vit_layers != 0:
layers = model.vision_model.encoder.layers[model_args.unfreeze_vit_layers:]
for k, v in layers.named_parameters():
logger.info(f'Unfreezing ViT layer: {k}')
v.stop_gradient = not True
# print trainable parameters
if dist.get_rank() == 0:
for name, param in model.named_parameters():
if not param.stop_gradient:
logger.info(name)
# set seed for torch dataloaders
set_seed(training_args.seed)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=None,
tokenizer=tokenizer,
data_collator=concat_pad_data_collator,
)
# Training
if training_args.do_train:
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
trainer.save_model() # Saves the tokenizer too for easy upload
metrics = train_result.metrics
try:
metrics['train_samples'] = len(train_dataset)
except:
metrics['train_samples'] = -1
trainer.log_metrics('train', metrics)
trainer.save_metrics('train', metrics)
trainer.save_state()
if __name__ == "__main__":
    # Launch fine-tuning only when this file is executed directly, not on import.
    main()