Shortcuts

Source code for mmpretrain.models.multimodal.minigpt4.minigpt4

# Copyright (c) OpenMMLab. All rights reserved.
import random
import re
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
from mmengine.logging import MMLogger
from mmengine.model import BaseModel

from mmpretrain.registry import MODELS, TOKENIZER
from mmpretrain.structures import DataSample


@MODELS.register_module()
class MiniGPT4(BaseModel):
    """The multi-modality model of MiniGPT-4.

    The implementation of `MiniGPT-4 <https://arxiv.org/abs/2304.10592>`_.
    Modified from https://github.com/Vision-CAIR/MiniGPT-4/blob/main/minigpt4/models/mini_gpt4.py

    Args:
        vision_encoder (dict): The config for vision encoder.
        q_former_model (dict): The config for Qformer.
        lang_encoder (dict): The config for language model.
        tokenizer (dict): The config for tokenizer.
        task (str): To define the task, which control the processing of text.
            Defaults to 'caption'.
        freeze_vit (bool): Freeze the training of ViT. Defaults to True.
        freeze_q_former (bool): Freeze the training of Qformer. Defaults to
            True.
        num_query_token (int): Number of query tokens of Qformer. Defaults to
            32.
        prompt_template (dict): Multi-language prompt template of the model.
            Defaults to dict([
            ('en', '###Ask: {} ###Answer: '), ('zh', '###问:{} ###答:')])
        raw_prompts (dict): Prompts for training. Only prompts containing
            the ``<ImageHere>`` placeholder are kept. Defaults to dict().
        max_txt_len (int): Max token length while doing tokenization. Defaults
            to 32.
        end_sym (str): Ended symbol of the sequence. Defaults to '###'.
        generation_cfg (dict): The config of text generation. Entries
            override the built-in defaults. Defaults to dict().
        data_preprocessor (:obj:`BaseDataPreprocessor`): Used for
            pre-processing data sampled by dataloader to the format accepted
            by :meth:`forward`. Defaults to None.
        init_cfg (dict): Initialization config dict. Defaults to None.
    """  # noqa

    def __init__(self,
                 vision_encoder: dict,
                 q_former_model: dict,
                 lang_encoder: dict,
                 tokenizer: dict,
                 task: str = 'caption',
                 freeze_vit: bool = True,
                 freeze_q_former: bool = True,
                 num_query_token: int = 32,
                 prompt_template: dict = dict([('en',
                                                '###Ask: {} ###Answer: '),
                                               ('zh', '###问:{} ###答:')]),
                 raw_prompts: dict = dict(),
                 max_txt_len: int = 32,
                 end_sym: str = '###',
                 generation_cfg: dict = dict(),
                 data_preprocessor: Optional[dict] = None,
                 init_cfg: Optional[dict] = None):
        # NOTE: the dict defaults above are shared between calls; they are
        # only ever read here, never mutated.

        # Work on shallow copies of the config dicts so that the
        # ``setdefault``/``pop`` calls below do not alter the caller's
        # configuration (which may be reused to build the model again).
        data_preprocessor = ({} if data_preprocessor is None else
                             data_preprocessor.copy())
        data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
        data_preprocessor = MODELS.build(data_preprocessor)

        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        self.task = task
        logger = MMLogger.get_current_instance()

        # build vision model
        vision_encoder = vision_encoder.copy()
        vision_encoder_weight = vision_encoder.pop('pretrained', None)
        self.vision_encoder = MODELS.build(vision_encoder)
        self.ln_vision = nn.LayerNorm(self.vision_encoder.embed_dims)

        if vision_encoder_weight is not None:
            from mmengine.runner.checkpoint import load_checkpoint
            load_checkpoint(self.vision_encoder, vision_encoder_weight)
            self.vision_encoder.is_init = True
        if freeze_vit:
            # Only the layer norm is frozen here; the encoder itself is
            # expected to be frozen through its own config.
            for param in self.ln_vision.parameters():
                param.requires_grad = False
            self.ln_vision = self.ln_vision.eval()
        else:
            logger.warning('Please check `frozen_stages` in the dict of '
                           '`vision_encoder`. Also set it to be -1 if do not '
                           'freeze ViT.')

        # build Qformer
        q_former_model = q_former_model.copy()
        q_former_model_weight = q_former_model.pop('pretrained', None)
        self.q_former = MODELS.build(q_former_model)
        # Drop the text-side sub-modules of the Q-Former; only the query
        # branch is used by ``encode_img``.
        self.q_former.cls = None
        self.q_former.bert.embeddings.word_embeddings = None
        self.q_former.bert.embeddings.position_embeddings = None
        for layer in self.q_former.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None

        self.query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, self.q_former.config.hidden_size))
        self.query_tokens.data.normal_(
            mean=0.0, std=self.q_former.config.initializer_range)

        if q_former_model_weight is not None:
            from mmengine.runner.checkpoint import CheckpointLoader
            state_dict = CheckpointLoader.load_checkpoint(
                q_former_model_weight)['state_dict']
            self.load_state_dict(state_dict, strict=False)
            # The ln_vision weights are also in the q-former checkpoint.
            setattr(self.ln_vision, 'is_init', True)
            setattr(self.q_former, 'is_init', True)

        if freeze_q_former:
            for param in self.q_former.parameters():
                param.requires_grad = False
            self.q_former.eval()
            self.query_tokens.requires_grad = False

        # build language model
        self.llama_tokenizer = TOKENIZER.build(tokenizer)
        self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token

        self.llama_model = MODELS.build(lang_encoder)
        for param in self.llama_model.parameters():
            param.requires_grad = False

        # build linear projection layer
        self.llama_proj = nn.Linear(self.q_former.config.hidden_size,
                                    self.llama_model.config.hidden_size)
        self.max_txt_len = max_txt_len
        self.end_sym = end_sym
        self.end_token_id = self.llama_tokenizer.encode(end_sym)[-1]

        # set prompts
        self.en_prompt_list = self._build_prompt_list(raw_prompts,
                                                      prompt_template, 'en')
        self.zh_prompt_list = self._build_prompt_list(raw_prompts,
                                                      prompt_template, 'zh')

        # update generation configs
        self.generation_cfg = dict(
            max_new_tokens=300,
            num_beams=1,
            do_sample=True,
            min_length=1,
            top_p=0.9,
            repetition_penalty=1.1,
            length_penalty=1.0,
            temperature=1.0)
        self.generation_cfg.update(**generation_cfg)

        if hasattr(self, 'register_load_state_dict_post_hook'):
            self.register_load_state_dict_post_hook(self._load_llama_proj_hook)

    @staticmethod
    def _build_prompt_list(raw_prompts: dict, prompt_template: dict,
                           lang: str) -> List[str]:
        """Format the raw prompts of ``lang`` that contain ``<ImageHere>``.

        Returns an empty list when ``raw_prompts`` has no entry for ``lang``.
        """
        candidates = raw_prompts.get(lang)
        if candidates is None:
            return []
        return [
            prompt_template[lang].format(raw_prompt)
            for raw_prompt in candidates if '<ImageHere>' in raw_prompt
        ]

    def _sample_prompt(self, lang: str) -> str:
        """Randomly pick one prepared prompt for ``lang`` ('zh' or English).

        Raises:
            ValueError: If no prompt is available for the language.
        """
        is_zh = lang == 'zh'
        prompt_list = self.zh_prompt_list if is_zh else self.en_prompt_list
        if not prompt_list:
            raise ValueError(
                f'No "{"zh" if is_zh else "en"}" prompt containing '
                '"<ImageHere>" is available. Please provide it through '
                '`raw_prompts`.')
        return random.choice(prompt_list)

    def _bos_embeds(self, batch_size: int,
                    device: torch.device) -> torch.Tensor:
        """Embed one BOS token per sample with the language model."""
        bos = torch.full((batch_size, 1),
                         self.llama_tokenizer.bos_token_id,
                         dtype=torch.long,
                         device=device)
        return self.llama_model.model.embed_tokens(bos)

    def half(self):
        """Convert only the language model to half precision."""
        self.llama_model = self.llama_model.half()
        return self

    def encode_img(self,
                   images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """The function to encode the images.

        Args:
            images (torch.Tensor): The input images.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The image embeddings projected
            by ``llama_proj`` and an all-ones attention mask covering
            every dimension of the embeddings except the last.
        """
        device = images.device
        x = self.vision_encoder(images)[0]
        image_embeds = self.ln_vision(x).to(device)

        image_atts = torch.ones(
            image_embeds.size()[:-1], dtype=torch.long, device=device)
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_output = self.q_former.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )

        inputs_llama = self.llama_proj(query_output.last_hidden_state)
        atts_llama = torch.ones(
            inputs_llama.size()[:-1], dtype=torch.long, device=device)
        return inputs_llama, atts_llama

    def prompt_wrap(self, img_embeds: torch.Tensor, atts_img: torch.Tensor,
                    prompt: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        """The function to wrap the image and prompt.

        Make sure that len(prompt) == img_embeds.shape[0].

        Args:
            img_embeds (torch.Tensor): The embedding of the input images.
            atts_img (torch.Tensor): Attention map of the image embeddings.
            prompt (List[str]): The prompt of the batch data. Each prompt
                must contain the ``<ImageHere>`` placeholder exactly once.
                An empty list leaves the inputs untouched.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The embedding and attention map.

        Raises:
            ValueError: If a prompt does not contain exactly one
                ``<ImageHere>`` placeholder.
        """
        if not prompt:
            return img_embeds, atts_img

        p_before_list, p_after_list = [], []
        for pro in prompt:
            if pro.count('<ImageHere>') != 1:
                raise ValueError(
                    'Each prompt must contain "<ImageHere>" exactly once, '
                    f'but got: {pro!r}')
            p_before, p_after = pro.split('<ImageHere>')
            p_before_list.append(p_before)
            p_after_list.append(p_after)

        p_before_tokens = self.llama_tokenizer(
            p_before_list,
            return_tensors='pt',
            padding='longest',
            add_special_tokens=False).to(img_embeds.device)
        p_after_tokens = self.llama_tokenizer(
            p_after_list,
            return_tensors='pt',
            padding='longest',
            add_special_tokens=False).to(img_embeds.device)

        embed_tokens = self.llama_model.model.embed_tokens
        p_before_embeds = embed_tokens(p_before_tokens.input_ids)
        p_after_embeds = embed_tokens(p_after_tokens.input_ids)

        wrapped_img_embeds = torch.cat(
            [p_before_embeds, img_embeds, p_after_embeds], dim=1)
        # The mask repeats the first column of ``atts_img`` for every wrapped
        # position, so padded prompt positions get the same mask value as
        # the image tokens.
        wrapped_atts_img = atts_img[:, :1].expand(
            -1, wrapped_img_embeds.shape[1])
        return wrapped_img_embeds, wrapped_atts_img

    def loss(self,
             images: torch.Tensor,
             data_samples: Optional[List[DataSample]] = None) -> dict:
        """The forward function in training.

        Args:
            images (torch.Tensor): The input images.
            data_samples (List[DataSample]): All elements required
                during the forward function. Every sample must provide
                ``chat_content``; ``lang`` is treated as 'en' when missing.
                Although the default is None, it must be given here.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components.
        """
        img_embeds, atts_img = self.encode_img(images)

        self.llama_tokenizer.padding_side = 'right'

        prompts, texts = [], []
        for sample in data_samples:
            lang = getattr(sample, 'lang', 'en')
            split_mark = '###Answer: ' if lang == 'en' else '###答:'
            # Split on the first mark only, so an answer that repeats the
            # mark does not break the unpacking.
            prompt, text = sample.chat_content.split(split_mark, 1)
            prompts.append(prompt + split_mark)
            texts.append(text + self.end_sym)

        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompts)

        to_regress_tokens = self.llama_tokenizer(
            texts,
            return_tensors='pt',
            padding='longest',
            truncation=True,
            max_length=self.max_txt_len,
            add_special_tokens=False).to(images.device)

        # Padding positions are ignored by the loss.
        targets = to_regress_tokens.input_ids.masked_fill(
            to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id,
            -100)

        # No loss on the BOS token (the "+ 1") and the wrapped image/prompt.
        empty_targets = torch.full(
            (atts_img.shape[0], atts_img.shape[1] + 1),
            -100,
            dtype=torch.long,
            device=images.device)
        targets = torch.cat([empty_targets, targets], dim=1)

        bos_embeds = self._bos_embeds(img_embeds.shape[0],
                                      to_regress_tokens.input_ids.device)
        atts_bos = atts_img[:, :1]

        to_regress_embeds = self.llama_model.model.embed_tokens(
            to_regress_tokens.input_ids)
        inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds],
                                  dim=1)
        attention_mask = torch.cat(
            [atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)

        outputs = self.llama_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
            labels=targets,
        )
        return dict(loss=outputs.loss)

    def predict(
            self,
            images: torch.Tensor,
            data_samples: Optional[List[DataSample]] = None
    ) -> List[DataSample]:
        """Generate text for the images.

        Args:
            images (torch.Tensor): The input images.
            data_samples (List[DataSample], optional): The annotation data of
                every samples. A sample whose ``lang`` is 'zh' gets a Chinese
                prompt, every other sample an English one. When None, English
                prompts are used for all images. Defaults to None.

        Returns:
            List[DataSample]: Return list of data samples.

        Raises:
            ValueError: If no prompt is available for a requested language.
        """
        with torch.no_grad():
            img_embeds, atts_img = self.encode_img(images)

        if data_samples is None:
            langs = ['en'] * img_embeds.shape[0]
        else:
            langs = [getattr(sample, 'lang', 'en') for sample in data_samples]
        prompts = [self._sample_prompt(lang) for lang in langs]
        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompts)

        bos_embeds = self._bos_embeds(img_embeds.shape[0], img_embeds.device)
        inputs_embeds = torch.cat([bos_embeds, img_embeds], dim=1)

        # Let ``generation_cfg`` override the default end token instead of
        # raising a duplicate keyword error.
        generate_kwargs = dict(eos_token_id=self.end_token_id)
        generate_kwargs.update(self.generation_cfg)
        outputs = self.llama_model.generate(
            inputs_embeds=inputs_embeds, **generate_kwargs)

        return self.post_process(outputs, data_samples)

    def post_process(
            self, outputs: torch.Tensor,
            data_samples: Optional[List[DataSample]]) -> List[DataSample]:
        """Perform post process for outputs for different task.

        Args:
            outputs (torch.Tensor): The generated outputs.
            data_samples (List[DataSample], optional): The annotation
                data of every samples. New samples are created when None.

        Returns:
            List[DataSample]: Return list of data samples.
        """
        decoded = self.llama_tokenizer.batch_decode(
            outputs, skip_special_tokens=True)

        if data_samples is None:
            data_samples = [DataSample() for _ in range(len(decoded))]

        is_caption = self.task == 'caption'
        for text, data_sample in zip(decoded, data_samples):
            if is_caption:
                # Keep only the text before the first '###' separator.
                data_sample.pred_caption = text.split('###')[0]
            else:
                # raw output
                data_sample.pred_output = text
        return data_samples

    def forward(
        self,
        images: torch.Tensor,
        data_samples: Optional[list] = None,
        mode: str = 'predict',
        **kwargs,
    ):
        """The unified entry for a forward process in both training and test.
        The method accepts the following modes:

        - "loss": Forward and return a dict of losses.
        - "predict": Forward and return a list of data samples contain the
          predict results.

        Args:
            images (torch.Tensor): the preprocessed image tensor of shape
                ``(N, C, H, W)``.
            data_samples (List[DataSample], optional): The annotation data
                of every samples. Defaults to None.
            mode (str): Return what kind of value. Defaults to 'predict'.
            **kwargs: Extra keyword arguments, forwarded to :meth:`predict`
                only.

        Raises:
            RuntimeError: If ``mode`` is neither 'loss' nor 'predict'.
        """
        if mode == 'loss':
            return self.loss(images, data_samples)
        elif mode == 'predict':
            return self.predict(images, data_samples, **kwargs)
        else:
            raise RuntimeError(f'Invalid mode "{mode}".')

    @staticmethod
    def _load_llama_proj_hook(module, incompatible_keys):
        """Avoid warning missing keys except LLaMA projection keys.

        Removes, in place, every missing key that matches one of the
        patterns below from ``incompatible_keys.missing_keys``.
        """
        proj_patterns = [
            'vision_encoder.*',
            'ln_vision.*',
            'q_former.*',
            'query_tokens',
            'llama_model.*',
        ]
        # Iterate over a copy because the list is modified in the loop.
        for key in list(incompatible_keys.missing_keys):
            if any(re.match(pattern, key) for pattern in proj_patterns):
                incompatible_keys.missing_keys.remove(key)
Read the Docs v: latest
Versions
latest
stable
mmcls-1.x
mmcls-0.x
dev
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.