blip_nlvr.py

from models.med import BertConfig
from models.nlvr_encoder import BertModel
from models.vit import interpolate_pos_embed
from models.blip import create_vit, init_tokenizer, is_url
from timm.models.hub import download_cached_file

import os
import torch
from torch import nn
import torch.nn.functional as F
from transformers import BertTokenizer
import numpy as np
class BLIP_NLVR(nn.Module):
    def __init__(self,
                 med_config = 'configs/med_config.json',
                 image_size = 480,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,
                 ):
        """
        Args:
            med_config (str): path to the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of the vision transformer
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
        self.tokenizer = init_tokenizer()
        med_config = BertConfig.from_json_file(med_config)
        med_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)

        # Two-way classification head: NLVR2 labels are True / False.
        self.cls_head = nn.Sequential(
            nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size),
            nn.ReLU(),
            nn.Linear(self.text_encoder.config.hidden_size, 2)
        )
    def forward(self, image, text, targets, train=True):
        # `image` is the two images of each example concatenated along the batch
        # dimension, so image.size(0) == 2 * targets.size(0).
        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
        image0_embeds, image1_embeds = torch.split(image_embeds, targets.size(0))

        text = self.tokenizer(text, padding='longest', return_tensors="pt").to(image.device)
        text.input_ids[:, 0] = self.tokenizer.enc_token_id

        output = self.text_encoder(text.input_ids,
                                   attention_mask = text.attention_mask,
                                   encoder_hidden_states = [image0_embeds, image1_embeds],
                                   encoder_attention_mask = [image_atts[:image0_embeds.size(0)],
                                                             image_atts[image0_embeds.size(0):]],
                                   return_dict = True,
                                   )
        hidden_state = output.last_hidden_state[:, 0, :]
        prediction = self.cls_head(hidden_state)

        if train:
            loss = F.cross_entropy(prediction, targets)
            return loss
        else:
            return prediction
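
# Illustrative inference sketch (not part of the original file): at eval time the two
# logits from cls_head can be read out as an NLVR2 prediction, assuming the targets
# encode True as 1 and False as 0. The names `model`, `image`, `text`, and `targets`
# below are placeholders for this example.
#
#   logits = model(image, text, targets, train=False)   # (batch_size, 2)
#   probs = F.softmax(logits, dim=-1)                    # per-class probabilities
#   pred = logits.argmax(dim=-1)                         # 1 = True, 0 = False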
def blip_nlvr(pretrained='', **kwargs):
    model = BLIP_NLVR(**kwargs)
    if pretrained:
        model, msg = load_checkpoint(model, pretrained)
        print("missing keys:")
        print(msg.missing_keys)
    return model
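
# Illustrative usage sketch (hypothetical checkpoint path, not from the original file):
# build the model and optionally load a finetuned NLVR checkpoint through blip_nlvr().
#
#   model = blip_nlvr(pretrained='path/to/model_base_nlvr.pth', image_size=480, vit='base')
#   model.eval()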
def load_checkpoint(model, url_or_filename):
    if is_url(url_or_filename):
        cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
        checkpoint = torch.load(cached_file, map_location='cpu')
    elif os.path.isfile(url_or_filename):
        checkpoint = torch.load(url_or_filename, map_location='cpu')
    else:
        raise RuntimeError('checkpoint url or path is invalid')

    state_dict = checkpoint['model']

    # Resize the position embedding in case the checkpoint was trained at a different image size.
    state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'], model.visual_encoder)

    # The NLVR text encoder has two cross-attention branches (one per image), so each
    # cross-attention weight in a standard BLIP checkpoint is copied into both branches.
    for key in list(state_dict.keys()):
        if 'crossattention.self.' in key:
            new_key0 = key.replace('self', 'self0')
            new_key1 = key.replace('self', 'self1')
            state_dict[new_key0] = state_dict[key]
            state_dict[new_key1] = state_dict[key]
        elif 'crossattention.output.dense.' in key:
            new_key0 = key.replace('dense', 'dense0')
            new_key1 = key.replace('dense', 'dense1')
            state_dict[new_key0] = state_dict[key]
            state_dict[new_key1] = state_dict[key]

    msg = model.load_state_dict(state_dict, strict=False)
    print('load checkpoint from %s' % url_or_filename)
    return model, msg
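

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original file). It assumes
    # 'configs/med_config.json' is available and that init_tokenizer() can download
    # the BERT tokenizer; the inputs are random placeholders rather than NLVR2 data.
    model = blip_nlvr(image_size=480, vit='base')

    batch_size = 2
    image0 = torch.randn(batch_size, 3, 480, 480)
    image1 = torch.randn(batch_size, 3, 480, 480)
    image = torch.cat([image0, image1], dim=0)          # both images stacked on the batch dim
    text = ['there are two dogs in total'] * batch_size
    targets = torch.tensor([1, 0])                      # assumed encoding: 1 = True, 0 = False

    loss = model(image, text, targets, train=True)
    print('cross-entropy loss:', loss.item())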