blip_service_itm.py

import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

from BLIP.models.blip import blip_decoder


class BLIPDecoderService:
    # Singleton: the BLIP captioning model is loaded once and reused across calls.
    _instance = None
    _model = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(BLIPDecoderService, cls).__new__(cls)
            cls.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            cls.image_size = 384

            # Load the pretrained BLIP captioning checkpoint
            model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'
            cls._model = blip_decoder(pretrained=model_url, image_size=cls.image_size, vit='base')
            cls._model.eval()
            cls._model = cls._model.to(cls.device)

            # Preprocessing: resize, tensor conversion, and BLIP's normalization constants
            cls.transform = transforms.Compose([
                transforms.Resize((cls.image_size, cls.image_size), interpolation=InterpolationMode.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
            ])
        return cls._instance

    def generate_caption(self, pil_image):
        image = self.transform(pil_image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            # Deterministic beam search decoding
            caption = self._model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5)
        return caption[0]
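For reference, a minimal usage sketch. It assumes the BLIP repository is cloned next to this file (so the `BLIP.models.blip` import resolves), that the module is importable under the file name `blip_service_itm`, and that a local image such as `demo.jpg` exists; these names are placeholders, not part of the original code.

from PIL import Image
from blip_service_itm import BLIPDecoderService  # module name assumed from the file name above

service = BLIPDecoderService()                    # first call downloads/loads the model; later calls reuse the singleton
image = Image.open('demo.jpg').convert('RGB')     # 'demo.jpg' is a hypothetical local image
print(service.generate_caption(image))            # prints a single generated caption string

Because the model and transform live on the class, constructing `BLIPDecoderService()` repeatedly does not reload the checkpoint.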