# blip_vqa.py

from models.med import BertConfig, BertModel, BertLMHeadModel
from models.blip import create_vit, init_tokenizer, load_checkpoint

import torch
from torch import nn
import torch.nn.functional as F
from transformers import BertTokenizer
import numpy as np


class BLIP_VQA(nn.Module):
    def __init__(self,
                 med_config='configs/med_config.json',
                 image_size=480,
                 vit='base',
                 vit_grad_ckpt=False,
                 vit_ckpt_layer=0,
                 ):
        """
        Args:
            med_config (str): path to the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of the vision transformer
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer,
                                                       drop_path_rate=0.1)
        self.tokenizer = init_tokenizer()

        encoder_config = BertConfig.from_json_file(med_config)
        encoder_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)

        decoder_config = BertConfig.from_json_file(med_config)
        self.text_decoder = BertLMHeadModel(config=decoder_config)

    def forward(self, image, question, answer=None, n=None, weights=None, train=True, inference='rank', k_test=128):
        """
        train=True:  return the answer language-modeling loss for the given (question, answer) pairs.
        train=False: 'generate' decodes an answer with beam search; 'rank' scores a pre-tokenized
                     list of candidate answers and returns the index of the best one per question.
        """
        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)

        question = self.tokenizer(question, padding='longest', truncation=True, max_length=35,
                                  return_tensors="pt").to(image.device)
        question.input_ids[:, 0] = self.tokenizer.enc_token_id

        if train:
            # n: number of answers for each question
            # weights: weight for each answer
            answer = self.tokenizer(answer, padding='longest', return_tensors="pt").to(image.device)
            answer.input_ids[:, 0] = self.tokenizer.bos_token_id
            answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100)

            question_output = self.text_encoder(question.input_ids,
                                                attention_mask=question.attention_mask,
                                                encoder_hidden_states=image_embeds,
                                                encoder_attention_mask=image_atts,
                                                return_dict=True)

            # Repeat each question's encoder states once per ground-truth answer,
            # so question/answer pairs line up one-to-one for the decoder.
            question_states = []
            question_atts = []
            for b, num_ans in enumerate(n):
                question_states += [question_output.last_hidden_state[b]] * num_ans
                question_atts += [question.attention_mask[b]] * num_ans
            question_states = torch.stack(question_states, 0)
            question_atts = torch.stack(question_atts, 0)

            answer_output = self.text_decoder(answer.input_ids,
                                              attention_mask=answer.attention_mask,
                                              encoder_hidden_states=question_states,
                                              encoder_attention_mask=question_atts,
                                              labels=answer_targets,
                                              return_dict=True,
                                              reduction='none',
                                              )

            loss = weights * answer_output.loss
            loss = loss.sum() / image.size(0)

            return loss
        else:
            question_output = self.text_encoder(question.input_ids,
                                                attention_mask=question.attention_mask,
                                                encoder_hidden_states=image_embeds,
                                                encoder_attention_mask=image_atts,
                                                return_dict=True)

            if inference == 'generate':
                num_beams = 3
                question_states = question_output.last_hidden_state.repeat_interleave(num_beams, dim=0)
                question_atts = torch.ones(question_states.size()[:-1], dtype=torch.long).to(question_states.device)
                model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask": question_atts}

                bos_ids = torch.full((image.size(0), 1), fill_value=self.tokenizer.bos_token_id, device=image.device)

                outputs = self.text_decoder.generate(input_ids=bos_ids,
                                                     max_length=10,
                                                     min_length=1,
                                                     num_beams=num_beams,
                                                     eos_token_id=self.tokenizer.sep_token_id,
                                                     pad_token_id=self.tokenizer.pad_token_id,
                                                     **model_kwargs)

                answers = []
                for output in outputs:
                    answer = self.tokenizer.decode(output, skip_special_tokens=True)
                    answers.append(answer)
                return answers

            elif inference == 'rank':
                max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask,
                                           answer.input_ids, answer.attention_mask, k_test)
                return max_ids

    def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k):

        num_ques = question_states.size(0)
        start_ids = answer_ids[0, 0].repeat(num_ques, 1)  # bos token

        start_output = self.text_decoder(start_ids,
                                         encoder_hidden_states=question_states,
                                         encoder_attention_mask=question_atts,
                                         return_dict=True,
                                         reduction='none')
        logits = start_output.logits[:, 0, :]  # logits for the first generated token

        # topk_probs: top-k probability
        # topk_ids: [num_question, k]
        answer_first_token = answer_ids[:, 1]
        prob_first_token = F.softmax(logits, dim=1).index_select(dim=1, index=answer_first_token)
        topk_probs, topk_ids = prob_first_token.topk(k, dim=1)

        # answer input: [num_question*k, answer_len]
        input_ids = []
        input_atts = []
        for b, topk_id in enumerate(topk_ids):
            input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
            input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
        input_ids = torch.cat(input_ids, dim=0)
        input_atts = torch.cat(input_atts, dim=0)

        targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100)

        # repeat encoder's output for top-k answers
        question_states = tile(question_states, 0, k)
        question_atts = tile(question_atts, 0, k)

        output = self.text_decoder(input_ids,
                                   attention_mask=input_atts,
                                   encoder_hidden_states=question_states,
                                   encoder_attention_mask=question_atts,
                                   labels=targets_ids,
                                   return_dict=True,
                                   reduction='none')

        log_probs_sum = -output.loss
        log_probs_sum = log_probs_sum.view(num_ques, k)

        # for each question, pick the candidate with the highest summed log-probability
        max_topk_ids = log_probs_sum.argmax(dim=1)
        max_ids = topk_ids[max_topk_ids >= 0, max_topk_ids]

        return max_ids
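

# The first step of rank_answer prunes the candidate list using only the probability of
# each candidate's first token. The function below is an illustrative, standalone toy and
# not part of BLIP: the vocabulary size, tensor shapes, and random inputs are made up
# purely to show the softmax -> index_select -> topk pattern used above.
def _first_token_topk_demo():
    vocab_size, num_questions, num_candidates, k = 100, 2, 6, 3
    dummy_logits = torch.randn(num_questions, vocab_size)                     # decoder logits at position 0
    candidate_first_tokens = torch.randint(0, vocab_size, (num_candidates,))  # analogue of answer_ids[:, 1]
    probs = F.softmax(dummy_logits, dim=1)                                    # [num_questions, vocab_size]
    cand_probs = probs.index_select(dim=1, index=candidate_first_tokens)      # [num_questions, num_candidates]
    topk_probs, topk_ids = cand_probs.topk(k, dim=1)                          # keep the k most promising candidates
    return topk_probs, topk_ids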


def blip_vqa(pretrained='', **kwargs):
    model = BLIP_VQA(**kwargs)
    if pretrained:
        model, msg = load_checkpoint(model, pretrained)
        # assert(len(msg.missing_keys)==0)
    return model


def tile(x, dim, n_tile):
    """Repeat each slice of x along `dim` n_tile times, keeping repeats adjacent,
    e.g. [a, b] -> [a, a, b, b] (equivalent to repeat_interleave(n_tile, dim))."""
    init_dim = x.size(dim)
    repeat_idx = [1] * x.dim()
    repeat_idx[dim] = n_tile
    x = x.repeat(*(repeat_idx))
    order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
    return torch.index_select(x, dim, order_index.to(x.device))
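

# Minimal usage sketch (not part of the original file): the dummy image tensor, the example
# question, and the empty `pretrained` path below are assumptions made purely for illustration;
# pass a real checkpoint path/URL to load weights. Note that 'rank' inference additionally
# expects `answer` to be a pre-tokenized candidate list (with input_ids/attention_mask),
# which is prepared by the evaluation code rather than inside forward().
if __name__ == '__main__':
    model = blip_vqa(pretrained='', image_size=480, vit='base')
    model.eval()

    dummy_image = torch.randn(1, 3, 480, 480)      # placeholder for a preprocessed 480x480 image batch
    question = ['where is the dog sitting?']

    with torch.no_grad():
        answers = model(dummy_image, question, train=False, inference='generate')
    print(answers)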