# blip_retrieval.py (13 KB)
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from transformers import BertTokenizer

from models.blip import create_vit, init_tokenizer, load_checkpoint
from models.med import BertConfig, BertModel
  7. class BLIP_Retrieval(nn.Module):
  8. def __init__(self,
  9. med_config = 'configs/med_config.json',
  10. image_size = 384,
  11. vit = 'base',
  12. vit_grad_ckpt = False,
  13. vit_ckpt_layer = 0,
  14. embed_dim = 256,
  15. queue_size = 57600,
  16. momentum = 0.995,
  17. negative_all_rank = False,
  18. ):
  19. """
  20. Args:
  21. med_config (str): path for the mixture of encoder-decoder model's configuration file
  22. image_size (int): input image size
  23. vit (str): model size of vision transformer
  24. """
  25. super().__init__()
  26. self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
  27. self.tokenizer = init_tokenizer()
  28. med_config = BertConfig.from_json_file(med_config)
  29. med_config.encoder_width = vision_width
  30. self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
  31. text_width = self.text_encoder.config.hidden_size
  32. self.vision_proj = nn.Linear(vision_width, embed_dim)
  33. self.text_proj = nn.Linear(text_width, embed_dim)
  34. self.itm_head = nn.Linear(text_width, 2)
  35. # create momentum encoders
  36. self.visual_encoder_m, vision_width = create_vit(vit,image_size)
  37. self.vision_proj_m = nn.Linear(vision_width, embed_dim)
  38. self.text_encoder_m = BertModel(config=med_config, add_pooling_layer=False)
  39. self.text_proj_m = nn.Linear(text_width, embed_dim)
  40. self.model_pairs = [[self.visual_encoder,self.visual_encoder_m],
  41. [self.vision_proj,self.vision_proj_m],
  42. [self.text_encoder,self.text_encoder_m],
  43. [self.text_proj,self.text_proj_m],
  44. ]
  45. self.copy_params()
  46. # create the queue
  47. self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
  48. self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
  49. self.register_buffer("idx_queue", torch.full((1,queue_size),-100))
  50. self.register_buffer("ptr_queue", torch.zeros(1, dtype=torch.long))
  51. self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
  52. self.text_queue = nn.functional.normalize(self.text_queue, dim=0)
  53. self.queue_size = queue_size
  54. self.momentum = momentum
  55. self.temp = nn.Parameter(0.07*torch.ones([]))
  56. self.negative_all_rank = negative_all_rank
  57. def forward(self, image, caption, alpha, idx):
  58. with torch.no_grad():
  59. self.temp.clamp_(0.001,0.5)
  60. image_embeds = self.visual_encoder(image)
  61. image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
  62. image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
  63. text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
  64. return_tensors="pt").to(image.device)
  65. text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
  66. return_dict = True, mode = 'text')
  67. text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
  68. ###============== Image-text Contrastive Learning ===================###
  69. idx = idx.view(-1,1)
  70. idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()],dim=1)
  71. pos_idx = torch.eq(idx, idx_all).float()
  72. sim_targets = pos_idx / pos_idx.sum(1,keepdim=True)
  73. # get momentum features
  74. with torch.no_grad():
  75. self._momentum_update()
  76. image_embeds_m = self.visual_encoder_m(image)
  77. image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)
  78. image_feat_m_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)
  79. text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,
  80. return_dict = True, mode = 'text')
  81. text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1)
  82. text_feat_m_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)
  83. sim_i2t_m = image_feat_m @ text_feat_m_all / self.temp
  84. sim_t2i_m = text_feat_m @ image_feat_m_all / self.temp
  85. sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
  86. sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets
  87. sim_i2t = image_feat @ text_feat_m_all / self.temp
  88. sim_t2i = text_feat @ image_feat_m_all / self.temp
  89. loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
  90. loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean()
  91. loss_ita = (loss_i2t+loss_t2i)/2
  92. idxs = concat_all_gather(idx)
  93. self._dequeue_and_enqueue(image_feat_m, text_feat_m, idxs)
  94. ###============== Image-text Matching ===================###
  95. encoder_input_ids = text.input_ids.clone()
  96. encoder_input_ids[:,0] = self.tokenizer.enc_token_id
  97. # forward the positve image-text pair
  98. bs = image.size(0)
  99. output_pos = self.text_encoder(encoder_input_ids,
  100. attention_mask = text.attention_mask,
  101. encoder_hidden_states = image_embeds,
  102. encoder_attention_mask = image_atts,
  103. return_dict = True,
  104. )
  105. if self.negative_all_rank:
  106. # compute sample similarity
  107. with torch.no_grad():
  108. mask = torch.eq(idx, idxs.t())
  109. image_feat_world = concat_all_gather(image_feat)
  110. text_feat_world = concat_all_gather(text_feat)
  111. sim_i2t = image_feat @ text_feat_world.t() / self.temp
  112. sim_t2i = text_feat @ image_feat_world.t() / self.temp
  113. weights_i2t = F.softmax(sim_i2t,dim=1)
  114. weights_i2t.masked_fill_(mask, 0)
  115. weights_t2i = F.softmax(sim_t2i,dim=1)
  116. weights_t2i.masked_fill_(mask, 0)
  117. image_embeds_world = all_gather_with_grad(image_embeds)
  118. # select a negative image (from all ranks) for each text
  119. image_embeds_neg = []
  120. for b in range(bs):
  121. neg_idx = torch.multinomial(weights_t2i[b], 1).item()
  122. image_embeds_neg.append(image_embeds_world[neg_idx])
  123. image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
  124. # select a negative text (from all ranks) for each image
  125. input_ids_world = concat_all_gather(encoder_input_ids)
  126. att_mask_world = concat_all_gather(text.attention_mask)
  127. text_ids_neg = []
  128. text_atts_neg = []
  129. for b in range(bs):
  130. neg_idx = torch.multinomial(weights_i2t[b], 1).item()
  131. text_ids_neg.append(input_ids_world[neg_idx])
  132. text_atts_neg.append(att_mask_world[neg_idx])
  133. else:
  134. with torch.no_grad():
  135. mask = torch.eq(idx, idx.t())
  136. sim_i2t = image_feat @ text_feat.t() / self.temp
  137. sim_t2i = text_feat @ image_feat.t() / self.temp
  138. weights_i2t = F.softmax(sim_i2t,dim=1)
  139. weights_i2t.masked_fill_(mask, 0)
  140. weights_t2i = F.softmax(sim_t2i,dim=1)
  141. weights_t2i.masked_fill_(mask, 0)
  142. # select a negative image (from same rank) for each text
  143. image_embeds_neg = []
  144. for b in range(bs):
  145. neg_idx = torch.multinomial(weights_t2i[b], 1).item()
  146. image_embeds_neg.append(image_embeds[neg_idx])
  147. image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
  148. # select a negative text (from same rank) for each image
  149. text_ids_neg = []
  150. text_atts_neg = []
  151. for b in range(bs):
  152. neg_idx = torch.multinomial(weights_i2t[b], 1).item()
  153. text_ids_neg.append(encoder_input_ids[neg_idx])
  154. text_atts_neg.append(text.attention_mask[neg_idx])
  155. text_ids_neg = torch.stack(text_ids_neg,dim=0)
  156. text_atts_neg = torch.stack(text_atts_neg,dim=0)
  157. text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0)
  158. text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0)
  159. image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
  160. image_atts_all = torch.cat([image_atts,image_atts],dim=0)
  161. output_neg = self.text_encoder(text_ids_all,
  162. attention_mask = text_atts_all,
  163. encoder_hidden_states = image_embeds_all,
  164. encoder_attention_mask = image_atts_all,
  165. return_dict = True,
  166. )
  167. vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0)
  168. vl_output = self.itm_head(vl_embeddings)
  169. itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)],
  170. dim=0).to(image.device)
  171. loss_itm = F.cross_entropy(vl_output, itm_labels)
  172. return loss_ita, loss_itm
  173. @torch.no_grad()
  174. def copy_params(self):
  175. for model_pair in self.model_pairs:
  176. for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
  177. param_m.data.copy_(param.data) # initialize
  178. param_m.requires_grad = False # not update by gradient
  179. @torch.no_grad()
  180. def _momentum_update(self):
  181. for model_pair in self.model_pairs:
  182. for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
  183. param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)
  184. @torch.no_grad()
  185. def _dequeue_and_enqueue(self, image_feat, text_feat, idxs):
  186. # gather keys before updating queue
  187. image_feats = concat_all_gather(image_feat)
  188. text_feats = concat_all_gather(text_feat)
  189. batch_size = image_feats.shape[0]
  190. ptr = int(self.ptr_queue)
  191. assert self.queue_size % batch_size == 0 # for simplicity
  192. # replace the keys at ptr (dequeue and enqueue)
  193. self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
  194. self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
  195. self.idx_queue[:, ptr:ptr + batch_size] = idxs.T
  196. ptr = (ptr + batch_size) % self.queue_size # move pointer
  197. self.ptr_queue[0] = ptr
  198. def blip_retrieval(pretrained='',**kwargs):
  199. model = BLIP_Retrieval(**kwargs)
  200. if pretrained:
  201. model,msg = load_checkpoint(model,pretrained)
  202. print("missing keys:")
  203. print(msg.missing_keys)
  204. return model
  205. @torch.no_grad()
  206. def concat_all_gather(tensor):
  207. """
  208. Performs all_gather operation on the provided tensors.
  209. *** Warning ***: torch.distributed.all_gather has no gradient.
  210. """
  211. tensors_gather = [torch.ones_like(tensor)
  212. for _ in range(torch.distributed.get_world_size())]
  213. torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
  214. output = torch.cat(tensors_gather, dim=0)
  215. return output
  216. class GatherLayer(torch.autograd.Function):
  217. """
  218. Gather tensors from all workers with support for backward propagation:
  219. This implementation does not cut the gradients as torch.distributed.all_gather does.
  220. """
  221. @staticmethod
  222. def forward(ctx, x):
  223. output = [torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())]
  224. torch.distributed.all_gather(output, x)
  225. return tuple(output)
  226. @staticmethod
  227. def backward(ctx, *grads):
  228. all_gradients = torch.stack(grads)
  229. torch.distributed.all_reduce(all_gradients)
  230. return all_gradients[torch.distributed.get_rank()]
  231. def all_gather_with_grad(tensors):
  232. """
  233. Performs all_gather operation on the provided tensors.
  234. Graph remains connected for backward grad computation.
  235. """
  236. # Queue the gathered tensors
  237. world_size = torch.distributed.get_world_size()
  238. # There is no need for reduction in the single-proc case
  239. if world_size == 1:
  240. return tensors
  241. tensor_all = GatherLayer.apply(tensors)
  242. return torch.cat(tensor_all, dim=0)