Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PPMix No.07】 Support the Aria model #892

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions paddlemix/examples/aria/run_predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import paddle
from PIL import Image
from paddlemix.models.aria.modeling_aria import AriaPretrainedModel, AriaForConditionalGeneration
from paddlemix.processors.processing_aria import AriaProcessor
from paddlemix.processors.aria_vision_processor import AriaVisionProcessor
import json


"""Demo script: run single-image chat inference with the Aria model."""

model_id_or_path = 'rhymes-ai/Aria'

# Load and echo the model config first so that a malformed/missing config is
# diagnosed before the (expensive) model load below.
config_path = f"{model_id_or_path}/config.json"
with open(config_path, 'r') as f:
    config = json.load(f)
print("Config loaded successfully:")
print(json.dumps(config, indent=2))

try:
    model = AriaForConditionalGeneration.from_pretrained(model_id_or_path)
    processor = AriaProcessor.from_pretrained(model_id_or_path,
        trust_remote_code=True)
    image = Image.open('paddlemix/demo_images/examples_image1.jpg').convert('RGB')
    # Single-turn chat message: one image placeholder followed by the question.
    messages = [{'role': 'user', 'content': [
        {'text': None, 'type': 'image'},
        {'text': 'what is the image?', 'type': 'text'}]}]
    text = processor.apply_chat_template(messages)
    inputs = processor(text=text, images=image, return_tensors='pd')
    # Cast pixel values to the model dtype and move all tensors to its device.
    inputs['pixel_values'] = inputs['pixel_values'].to(model.dtype)
    inputs = {k: v.to(model.place) for k, v in inputs.items()}
    with paddle.no_grad(), paddle.amp.auto_cast(dtype='bfloat16'):
        output = model.generate(**inputs, max_new_tokens=500,
            stop_strings=['<|im_end|>'], tokenizer=processor.tokenizer,
            do_sample=True, temperature=0.9)
    # Decode only the newly generated continuation, skipping the prompt tokens.
    output_ids = output[0][inputs['input_ids'].shape[1]:]
    result = processor.decode(output_ids, skip_special_tokens=True)
    print(result)
except Exception as e:
    # Top-level boundary for this demo script. The previous message claimed
    # "Error loading model" even when preprocessing or generation failed.
    print(f"Error during Aria inference: {e}")
1 change: 1 addition & 0 deletions paddlemix/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# see the license for the specific language governing permissions and
# limitations under the license.

from .aria import *
from .audioldm2.configuration import *
from .audioldm2.modeling import *
from .blip2.modeling import *
Expand Down
7 changes: 7 additions & 0 deletions paddlemix/models/aria/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .configuration_aria import *
from .modeling_aria import *
from .moe_lm import *
from .projector import *
from .vision_encoder import *
from ...processors.processing_aria import *
from ...processors.aria_vision_processor import *
69 changes: 69 additions & 0 deletions paddlemix/models/aria/configuration_aria.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import paddlenlp
import logging
from .moe_lm import AriaMoELMConfig
from .vision_encoder import AriaVisionConfig
from paddlenlp.transformers.configuration_utils import PretrainedConfig
logger = logging.getLogger(__name__)


class AriaConfig(PretrainedConfig):
    """
    Configuration class for the Aria multimodal model.

    Combines the configuration of the vision encoder and the MoE language
    model, plus parameters for image-token handling and the projector's
    patch-count-to-query-dimension mapping.

    Args:
        vision_config (AriaVisionConfig or dict, optional): Configuration for
            the vision component. Defaults to a fresh ``AriaVisionConfig()``.
        text_config (AriaMoELMConfig or dict, optional): Configuration for
            the text component. Defaults to a fresh ``AriaMoELMConfig()``.
        projector_patch_to_query_dict (dict, optional): Mapping of patch
            counts to projector query dimensions. Defaults to
            ``{1225: 128, 4900: 256}``.
        ignore_index (int): Label index ignored in loss calculation.
        image_token_index (int): Vocabulary index used for image tokens.
        tie_word_embeddings (bool): Whether input/output embeddings are tied.
        **kwargs: Forwarded to ``PretrainedConfig``. ``attn_implementation``
            is also consumed here to select the attention backends for the
            vision and text sub-configs.

    Attributes:
        model_type (str): Type of the model, set to "aria".
        is_composition (bool): Whether the model is a composition of
            multiple components.
    """
    model_type = 'aria'
    is_composition = False

    def __init__(self, vision_config=None, text_config=None,
                 projector_patch_to_query_dict=None, ignore_index=-100,
                 image_token_index=32000, tie_word_embeddings=False, **kwargs):
        super().__init__(**kwargs)
        self.ignore_index = ignore_index
        self.image_token_index = image_token_index
        self.tie_word_embeddings = tie_word_embeddings
        # Defaults are created lazily here rather than in the signature so
        # that separate AriaConfig instances never share one config/dict
        # object (mutable-default-argument bug in the previous version).
        if vision_config is None:
            vision_config = AriaVisionConfig()
        if text_config is None:
            text_config = AriaMoELMConfig()
        if projector_patch_to_query_dict is None:
            projector_patch_to_query_dict = {1225: 128, 4900: 256}
        attn_implementation = kwargs.pop('attn_implementation', None)
        # The overall model defaults to flash_attention_2.
        self._attn_implementation = ('flash_attention_2' if
            attn_implementation is None else attn_implementation)
        # Keys may arrive as strings (e.g. when loaded from a JSON config).
        self.projector_patch_to_query_dict = {int(k): int(v) for k, v in
            projector_patch_to_query_dict.items()}
        # Convert any dict into a proper config object. The previous version
        # only converted dicts containing 'model_type', so a plain dict
        # without that key crashed on the attribute assignment below.
        if isinstance(vision_config, dict):
            vision_config = AriaVisionConfig(**vision_config)
        if attn_implementation is None:
            vision_attn_implementation = 'flash_attention_2'
        elif attn_implementation == 'sdpa':
            logger.warning(
                'SDPA is not supported for vit, using flash_attention_2 instead'
                )
            vision_attn_implementation = 'flash_attention_2'
        else:
            vision_attn_implementation = attn_implementation
        vision_config._attn_implementation = vision_attn_implementation
        self.vision_config = vision_config
        if isinstance(text_config, dict) and 'model_type' in text_config:
            # The language model defaults to sdpa when no implementation was
            # requested explicitly.
            text_attn_implementation = ('sdpa' if attn_implementation is
                None else attn_implementation)
            text_config = AriaMoELMConfig(**text_config)
            text_config._attn_implementation = text_attn_implementation
        self.text_config = text_config
        self.num_hidden_layers = self.text_config.num_hidden_layers
Loading