'''
 * Copyright (c) 2022, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Junnan Li
'''
import warnings
warnings.filterwarnings("ignore")

from models.vit import VisionTransformer, interpolate_pos_embed
from models.med import BertConfig, BertModel, BertLMHeadModel
from transformers import BertTokenizer

import torch
from torch import nn
import torch.nn.functional as F

import os
from urllib.parse import urlparse
from timm.models.hub import download_cached_file

class BLIP_Base(nn.Module):
    def __init__(self,
                 med_config = 'configs/med_config.json',
                 image_size = 224,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer)
        self.tokenizer = init_tokenizer()
        med_config = BertConfig.from_json_file(med_config)
        med_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)

    def forward(self, image, caption, mode):
        assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
        text = self.tokenizer(caption, return_tensors="pt").to(image.device)

        if mode == 'image':
            # return image features
            image_embeds = self.visual_encoder(image)
            return image_embeds

        elif mode == 'text':
            # return text features
            text_output = self.text_encoder(text.input_ids, attention_mask=text.attention_mask,
                                            return_dict=True, mode='text')
            return text_output.last_hidden_state

        elif mode == 'multimodal':
            # return multimodal features
            image_embeds = self.visual_encoder(image)
            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)

            # mark the sequence as an image-grounded text encoding with the [ENC] token
            text.input_ids[:, 0] = self.tokenizer.enc_token_id
            output = self.text_encoder(text.input_ids,
                                       attention_mask=text.attention_mask,
                                       encoder_hidden_states=image_embeds,
                                       encoder_attention_mask=image_atts,
                                       return_dict=True,
                                       )
            return output.last_hidden_state


current_dir = os.path.dirname(__file__)
default_config_path = os.path.join(current_dir, '..', 'configs', 'med_config.json')
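

# --- Illustrative usage sketch (not part of the original file) ---------------
# A minimal sketch of the three feature-extraction modes exposed by
# BLIP_Base.forward, assuming the med_config path above is valid and that a
# preprocessed 224x224 image batch is available. The function name, the dummy
# image tensor, and the caption string are placeholders for illustration only.
def _example_extract_features():
    # build an un-pretrained feature extractor; pass `pretrained=<checkpoint>` to load weights
    model = blip_feature_extractor(pretrained='', med_config=default_config_path,
                                   image_size=224, vit='base')
    model.eval()

    image = torch.randn(1, 3, 224, 224)  # placeholder for a normalized image batch
    caption = 'a dog running on the beach'

    with torch.no_grad():
        image_feat = model(image, caption, mode='image')            # patch embeddings, (B, N+1, vision_width)
        text_feat = model(image, caption, mode='text')              # text-only embeddings, (B, L, hidden)
        multimodal_feat = model(image, caption, mode='multimodal')  # image-grounded text embeddings
    return image_feat, text_feat, multimodal_feat
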
class BLIP_Decoder(nn.Module):
    def __init__(self,
                 med_config = default_config_path,
                 image_size = 384,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,
                 prompt = 'a picture of ',
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer)
        self.tokenizer = init_tokenizer()
        med_config = BertConfig.from_json_file(med_config)
        med_config.encoder_width = vision_width
        self.text_decoder = BertLMHeadModel(config=med_config)

        self.prompt = prompt
        self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1

    def forward(self, image, caption):
        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)

        text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)
        text.input_ids[:, 0] = self.tokenizer.bos_token_id

        # ignore padding and prompt tokens in the LM loss (-100 is the ignore index)
        decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
        decoder_targets[:, :self.prompt_length] = -100

        decoder_output = self.text_decoder(text.input_ids,
                                           attention_mask=text.attention_mask,
                                           encoder_hidden_states=image_embeds,
                                           encoder_attention_mask=image_atts,
                                           labels=decoder_targets,
                                           return_dict=True,
                                           )
        loss_lm = decoder_output.loss
        return loss_lm

    def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
        image_embeds = self.visual_encoder(image)

        if not sample:
            # tile the image embeddings num_beams times so each beam sees the encoder states
            image_embeds = image_embeds.repeat_interleave(num_beams, dim=0)

        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
        model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask": image_atts}

        prompt = [self.prompt] * image.size(0)
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
        input_ids[:, 0] = self.tokenizer.bos_token_id
        input_ids = input_ids[:, :-1]  # drop the trailing [SEP] token

        if sample:
            # nucleus sampling
            outputs = self.text_decoder.generate(input_ids=input_ids,
                                                 max_length=max_length,
                                                 min_length=min_length,
                                                 do_sample=True,
                                                 top_p=top_p,
                                                 num_return_sequences=1,
                                                 eos_token_id=self.tokenizer.sep_token_id,
                                                 pad_token_id=self.tokenizer.pad_token_id,
                                                 repetition_penalty=1.1,
                                                 **model_kwargs)
        else:
            # beam search
            outputs = self.text_decoder.generate(input_ids=input_ids,
                                                 max_length=max_length,
                                                 min_length=min_length,
                                                 num_beams=num_beams,
                                                 eos_token_id=self.tokenizer.sep_token_id,
                                                 pad_token_id=self.tokenizer.pad_token_id,
                                                 repetition_penalty=repetition_penalty,
                                                 **model_kwargs)

        captions = []
        for output in outputs:
            caption = self.tokenizer.decode(output, skip_special_tokens=True)
            captions.append(caption[len(self.prompt):])  # strip the prompt prefix
        return captions


def blip_decoder(pretrained='', **kwargs):
    model = BLIP_Decoder(**kwargs)
    if pretrained:
        model, msg = load_checkpoint(model, pretrained)
        print("Missing keys:", msg.missing_keys)
        assert len(msg.missing_keys) == 0
    return model


def blip_feature_extractor(pretrained='', **kwargs):
    model = BLIP_Base(**kwargs)
    if pretrained:
        model, msg = load_checkpoint(model, pretrained)
        assert len(msg.missing_keys) == 0
    return model


def init_tokenizer():
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    tokenizer.add_special_tokens({'bos_token': '[DEC]'})
    tokenizer.add_special_tokens({'additional_special_tokens': ['[ENC]']})
    tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
    return tokenizer


def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
    assert vit in ['base', 'large'], "vit parameter must be base or large"
    if vit == 'base':
        vision_width = 768
        visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
                                           num_heads=12, use_grad_checkpointing=use_grad_checkpointing,
                                           ckpt_layer=ckpt_layer,
                                           drop_path_rate=0 or drop_path_rate
                                           )
    elif vit == 'large':
        vision_width = 1024
        visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
                                           num_heads=16, use_grad_checkpointing=use_grad_checkpointing,
                                           ckpt_layer=ckpt_layer,
                                           # note: `0.1 or drop_path_rate` always evaluates to 0.1, so the
                                           # large model uses a fixed drop path rate regardless of the argument
                                           drop_path_rate=0.1 or drop_path_rate
                                           )
    return visual_encoder, vision_width


def is_url(url_or_filename):
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")


def load_checkpoint(model, url_or_filename):
    if is_url(url_or_filename):
        cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
        checkpoint = torch.load(cached_file, map_location='cpu')
    elif os.path.isfile(url_or_filename):
        checkpoint = torch.load(url_or_filename, map_location='cpu')
    else:
        raise RuntimeError('checkpoint url or path is invalid')

    state_dict = checkpoint['model']

    # resize positional embeddings in case the checkpoint was trained at a different image resolution
    state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'], model.visual_encoder)
    if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
        state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
                                                                         model.visual_encoder_m)
    # drop checkpoint tensors whose shapes do not match the current model
    for key in model.state_dict().keys():
        if key in state_dict.keys():
            if state_dict[key].shape != model.state_dict()[key].shape:
                del state_dict[key]

    msg = model.load_state_dict(state_dict, strict=False)
    print('load checkpoint from %s' % url_or_filename)
    return model, msg
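

# --- Illustrative usage sketch (not part of the original file) ---------------
# A minimal sketch of image captioning with BLIP_Decoder. The image path and
# checkpoint path are placeholders to be supplied by the user; the preprocessing
# (bicubic resize to image_size, CLIP-style normalization) is an assumption
# meant to mirror the repo's transforms, so verify it against the actual
# data pipeline before relying on it.
if __name__ == "__main__":
    from PIL import Image
    from torchvision import transforms
    from torchvision.transforms.functional import InterpolationMode

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    image_size = 384

    preprocess = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                             (0.26862954, 0.26130258, 0.27577711)),  # assumed CLIP-style statistics
    ])

    # placeholders: point these at a real image and a real BLIP captioning checkpoint
    image_path = 'path/to/image.jpg'
    checkpoint_path = 'path/to/blip_captioning_checkpoint.pth'

    image = preprocess(Image.open(image_path).convert('RGB')).unsqueeze(0).to(device)

    model = blip_decoder(pretrained=checkpoint_path, image_size=image_size, vit='base')
    model.eval()
    model = model.to(device)

    with torch.no_grad():
        # beam search (deterministic) vs. nucleus sampling (stochastic)
        beam_caption = model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5)
        sampled_caption = model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5)

    print('beam search:', beam_caption[0])
    print('nucleus sampling:', sampled_caption[0])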