blip_pretrain.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. '''
  2. * Copyright (c) 2022, salesforce.com, inc.
  3. * All rights reserved.
  4. * SPDX-License-Identifier: BSD-3-Clause
  5. * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
  6. * By Junnan Li
  7. '''
  8. from models.med import BertConfig, BertModel, BertLMHeadModel
  9. from transformers import BertTokenizer
  10. import transformers
  11. transformers.logging.set_verbosity_error()
  12. import torch
  13. from torch import nn
  14. import torch.nn.functional as F
  15. from models.blip import create_vit, init_tokenizer, load_checkpoint
  16. class BLIP_Pretrain(nn.Module):
  17. def __init__(self,
  18. med_config = 'configs/bert_config.json',
  19. image_size = 224,
  20. vit = 'base',
  21. vit_grad_ckpt = False,
  22. vit_ckpt_layer = 0,
  23. embed_dim = 256,
  24. queue_size = 57600,
  25. momentum = 0.995,
  26. ):
  27. """
  28. Args:
  29. med_config (str): path for the mixture of encoder-decoder model's configuration file
  30. image_size (int): input image size
  31. vit (str): model size of vision transformer
  32. """
  33. super().__init__()
  34. self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
  35. if vit=='base':
  36. checkpoint = torch.hub.load_state_dict_from_url(
  37. url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
  38. map_location="cpu", check_hash=True)
  39. state_dict = checkpoint["model"]
  40. msg = self.visual_encoder.load_state_dict(state_dict,strict=False)
  41. elif vit=='large':
  42. from timm.models.helpers import load_custom_pretrained
  43. from timm.models.vision_transformer import default_cfgs
  44. load_custom_pretrained(self.visual_encoder,default_cfgs['vit_large_patch16_224_in21k'])
  45. self.tokenizer = init_tokenizer()
  46. encoder_config = BertConfig.from_json_file(med_config)
  47. encoder_config.encoder_width = vision_width
  48. self.text_encoder = BertModel.from_pretrained('bert-base-uncased',config=encoder_config, add_pooling_layer=False)
  49. self.text_encoder.resize_token_embeddings(len(self.tokenizer))
  50. text_width = self.text_encoder.config.hidden_size
  51. self.vision_proj = nn.Linear(vision_width, embed_dim)
  52. self.text_proj = nn.Linear(text_width, embed_dim)
  53. self.itm_head = nn.Linear(text_width, 2)
  54. # create momentum encoders
  55. self.visual_encoder_m, vision_width = create_vit(vit,image_size)
  56. self.vision_proj_m = nn.Linear(vision_width, embed_dim)
  57. self.text_encoder_m = BertModel(config=encoder_config, add_pooling_layer=False)
  58. self.text_proj_m = nn.Linear(text_width, embed_dim)
  59. self.model_pairs = [[self.visual_encoder,self.visual_encoder_m],
  60. [self.vision_proj,self.vision_proj_m],
  61. [self.text_encoder,self.text_encoder_m],
  62. [self.text_proj,self.text_proj_m],
  63. ]
  64. self.copy_params()
  65. # create the queue
  66. self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
  67. self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
  68. self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
  69. self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
  70. self.text_queue = nn.functional.normalize(self.text_queue, dim=0)
  71. self.queue_size = queue_size
  72. self.momentum = momentum
  73. self.temp = nn.Parameter(0.07*torch.ones([]))
  74. # create the decoder
  75. decoder_config = BertConfig.from_json_file(med_config)
  76. decoder_config.encoder_width = vision_width
  77. self.text_decoder = BertLMHeadModel.from_pretrained('bert-base-uncased',config=decoder_config)
  78. self.text_decoder.resize_token_embeddings(len(self.tokenizer))
  79. tie_encoder_decoder_weights(self.text_encoder,self.text_decoder.bert,'','/attention')
  80. def forward(self, image, caption, alpha):
  81. with torch.no_grad():
  82. self.temp.clamp_(0.001,0.5)
  83. image_embeds = self.visual_encoder(image)
  84. image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
  85. image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
  86. text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=30,
  87. return_tensors="pt").to(image.device)
  88. text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
  89. return_dict = True, mode = 'text')
  90. text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
  91. # get momentum features
  92. with torch.no_grad():
  93. self._momentum_update()
  94. image_embeds_m = self.visual_encoder_m(image)
  95. image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)
  96. image_feat_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)
  97. text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,
  98. return_dict = True, mode = 'text')
  99. text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1)
  100. text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)
  101. sim_i2t_m = image_feat_m @ text_feat_all / self.temp
  102. sim_t2i_m = text_feat_m @ image_feat_all / self.temp
  103. sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device)
  104. sim_targets.fill_diagonal_(1)
  105. sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
  106. sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets
  107. sim_i2t = image_feat @ text_feat_all / self.temp
  108. sim_t2i = text_feat @ image_feat_all / self.temp
  109. loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
  110. loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean()
  111. loss_ita = (loss_i2t+loss_t2i)/2
  112. self._dequeue_and_enqueue(image_feat_m, text_feat_m)
  113. ###============== Image-text Matching ===================###
  114. encoder_input_ids = text.input_ids.clone()
  115. encoder_input_ids[:,0] = self.tokenizer.enc_token_id
  116. # forward the positve image-text pair
  117. bs = image.size(0)
  118. output_pos = self.text_encoder(encoder_input_ids,
  119. attention_mask = text.attention_mask,
  120. encoder_hidden_states = image_embeds,
  121. encoder_attention_mask = image_atts,
  122. return_dict = True,
  123. )
  124. with torch.no_grad():
  125. weights_t2i = F.softmax(sim_t2i[:,:bs],dim=1)+1e-4
  126. weights_t2i.fill_diagonal_(0)
  127. weights_i2t = F.softmax(sim_i2t[:,:bs],dim=1)+1e-4
  128. weights_i2t.fill_diagonal_(0)
  129. # select a negative image for each text
  130. image_embeds_neg = []
  131. for b in range(bs):
  132. neg_idx = torch.multinomial(weights_t2i[b], 1).item()
  133. image_embeds_neg.append(image_embeds[neg_idx])
  134. image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
  135. # select a negative text for each image
  136. text_ids_neg = []
  137. text_atts_neg = []
  138. for b in range(bs):
  139. neg_idx = torch.multinomial(weights_i2t[b], 1).item()
  140. text_ids_neg.append(encoder_input_ids[neg_idx])
  141. text_atts_neg.append(text.attention_mask[neg_idx])
  142. text_ids_neg = torch.stack(text_ids_neg,dim=0)
  143. text_atts_neg = torch.stack(text_atts_neg,dim=0)
  144. text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0)
  145. text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0)
  146. image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
  147. image_atts_all = torch.cat([image_atts,image_atts],dim=0)
  148. output_neg = self.text_encoder(text_ids_all,
  149. attention_mask = text_atts_all,
  150. encoder_hidden_states = image_embeds_all,
  151. encoder_attention_mask = image_atts_all,
  152. return_dict = True,
  153. )
  154. vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0)
  155. vl_output = self.itm_head(vl_embeddings)
  156. itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)],
  157. dim=0).to(image.device)
  158. loss_itm = F.cross_entropy(vl_output, itm_labels)
  159. ##================= LM ========================##
  160. decoder_input_ids = text.input_ids.clone()
  161. decoder_input_ids[:,0] = self.tokenizer.bos_token_id
  162. decoder_targets = decoder_input_ids.masked_fill(decoder_input_ids == self.tokenizer.pad_token_id, -100)
  163. decoder_output = self.text_decoder(decoder_input_ids,
  164. attention_mask = text.attention_mask,
  165. encoder_hidden_states = image_embeds,
  166. encoder_attention_mask = image_atts,
  167. labels = decoder_targets,
  168. return_dict = True,
  169. )
  170. loss_lm = decoder_output.loss
  171. return loss_ita, loss_itm, loss_lm
  172. @torch.no_grad()
  173. def copy_params(self):
  174. for model_pair in self.model_pairs:
  175. for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
  176. param_m.data.copy_(param.data) # initialize
  177. param_m.requires_grad = False # not update by gradient
  178. @torch.no_grad()
  179. def _momentum_update(self):
  180. for model_pair in self.model_pairs:
  181. for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
  182. param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)
  183. @torch.no_grad()
  184. def _dequeue_and_enqueue(self, image_feat, text_feat):
  185. # gather keys before updating queue
  186. image_feats = concat_all_gather(image_feat)
  187. text_feats = concat_all_gather(text_feat)
  188. batch_size = image_feats.shape[0]
  189. ptr = int(self.queue_ptr)
  190. assert self.queue_size % batch_size == 0 # for simplicity
  191. # replace the keys at ptr (dequeue and enqueue)
  192. self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
  193. self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
  194. ptr = (ptr + batch_size) % self.queue_size # move pointer
  195. self.queue_ptr[0] = ptr
  196. def blip_pretrain(**kwargs):
  197. model = BLIP_Pretrain(**kwargs)
  198. return model
  199. @torch.no_grad()
  200. def concat_all_gather(tensor):
  201. """
  202. Performs all_gather operation on the provided tensors.
  203. *** Warning ***: torch.distributed.all_gather has no gradient.
  204. """
  205. tensors_gather = [torch.ones_like(tensor)
  206. for _ in range(torch.distributed.get_world_size())]
  207. torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
  208. output = torch.cat(tensors_gather, dim=0)
  209. return output
  210. from typing import List
  211. def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key:str):
  212. uninitialized_encoder_weights: List[str] = []
  213. if decoder.__class__ != encoder.__class__:
  214. logger.info(
  215. f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized."
  216. )
  217. def tie_encoder_to_decoder_recursively(
  218. decoder_pointer: nn.Module,
  219. encoder_pointer: nn.Module,
  220. module_name: str,
  221. uninitialized_encoder_weights: List[str],
  222. skip_key: str,
  223. depth=0,
  224. ):
  225. assert isinstance(decoder_pointer, nn.Module) and isinstance(
  226. encoder_pointer, nn.Module
  227. ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"
  228. if hasattr(decoder_pointer, "weight") and skip_key not in module_name:
  229. assert hasattr(encoder_pointer, "weight")
  230. encoder_pointer.weight = decoder_pointer.weight
  231. if hasattr(decoder_pointer, "bias"):
  232. assert hasattr(encoder_pointer, "bias")
  233. encoder_pointer.bias = decoder_pointer.bias
  234. print(module_name+' is tied')
  235. return
  236. encoder_modules = encoder_pointer._modules
  237. decoder_modules = decoder_pointer._modules
  238. if len(decoder_modules) > 0:
  239. assert (
  240. len(encoder_modules) > 0
  241. ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
  242. all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()])
  243. encoder_layer_pos = 0
  244. for name, module in decoder_modules.items():
  245. if name.isdigit():
  246. encoder_name = str(int(name) + encoder_layer_pos)
  247. decoder_name = name
  248. if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len(
  249. encoder_modules
  250. ) != len(decoder_modules):
  251. # this can happen if the name corresponds to the position in a list module list of layers
  252. # in this case the decoder has added a cross-attention that the encoder does not have
  253. # thus skip this step and subtract one layer pos from encoder
  254. encoder_layer_pos -= 1
  255. continue
  256. elif name not in encoder_modules:
  257. continue
  258. elif depth > 500:
  259. raise ValueError(
  260. "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
  261. )
  262. else:
  263. decoder_name = encoder_name = name
  264. tie_encoder_to_decoder_recursively(
  265. decoder_modules[decoder_name],
  266. encoder_modules[encoder_name],
  267. module_name + "/" + name,
  268. uninitialized_encoder_weights,
  269. skip_key,
  270. depth=depth + 1,
  271. )
  272. all_encoder_weights.remove(module_name + "/" + encoder_name)
  273. uninitialized_encoder_weights += list(all_encoder_weights)
  274. # tie weights recursively
  275. tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights, skip_key)