blip_service.py
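# A legacy implementation based on the original BLIP repository (blip_decoder +
# torchvision transforms) is kept below, commented out, for reference. The
# active implementation, BLIPServiceHF, uses Hugging Face Transformers instead.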

# import torch
# from PIL import Image
# from torchvision import transforms
# from BLIP.models.blip import blip_decoder
# import os
#
#
# class BLIPService:
#     _model = None
#     _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#
#     def __init__(self):
#         if BLIPService._model is None:
#             BASE_DIR = os.path.dirname(os.path.abspath(__file__))
#             config_path = os.path.join(BASE_DIR, "../BLIP/configs/med_config.json")
#             checkpoint_path = os.path.join(BASE_DIR, "../BLIP/models/model_base_caption.pth")
#             model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
#
#             if not os.path.exists(config_path):
#                 raise FileNotFoundError(f"BLIP config not found at {config_path}")
#             # if not os.path.exists(checkpoint_path):
#             #     raise FileNotFoundError(f"BLIP checkpoint not found at {checkpoint_path}")
#
#             BLIPService._model = blip_decoder(
#                 # pretrained=checkpoint_path,
#                 med_config=config_path,
#                 pretrained=model_url,
#                 image_size=384,
#                 vit="base"
#             )
#             BLIPService._model.eval()
#             BLIPService._model.to(self._device)
#
#         self.transform = transforms.Compose([
#             transforms.Resize((384, 384)),
#             transforms.ToTensor()
#         ])
#
#     def generate_caption(self, image: Image.Image):
#         model = BLIPService._model
#         model.eval()
#
#         import torch
#         from torchvision import transforms
#
#         transform = transforms.Compose([
#             transforms.Resize((384, 384)),
#             transforms.ToTensor(),
#             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
#         ])
#         image_tensor = transform(image).unsqueeze(0)
#
#         device = "cuda" if torch.cuda.is_available() else "cpu"
#         image_tensor = image_tensor.to(device)
#         model.to(device)
#
#         with torch.no_grad():
#             caption = model.generate(image_tensor, sample=False, num_beams=3, max_length=20, min_length=5)
#         return caption[0]
#
#     # def generate_caption(self, image: Image.Image) -> str:
#     #     image = image.convert("RGB")
#     #     image = self.transform(image).unsqueeze(0).to(self._device)
#     #     with torch.no_grad():
#     #         caption = self._model.generate(
#     #             image,
#     #             sample=False,
#     #             num_beams=3
#     #         )
#     #     return caption[0]

# blip_service.py
from PIL import Image
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration


class BLIPServiceHF:
    """Image-captioning service backed by the Hugging Face BLIP base model."""

    _instance = None

    def __new__(cls, *args, **kwargs):
        # Singleton pattern to avoid reloading the model each time
        if cls._instance is None:
            cls._instance = super(BLIPServiceHF, cls).__new__(cls)
            cls._instance._init_model()
        return cls._instance

    def _init_model(self):
        # Load the pretrained processor and captioning model once, keep the
        # model in eval mode, and move it to the best available device.
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
        self.model.eval()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def generate_caption(self, image: Image.Image) -> str:
        # Preprocess the PIL image, generate caption token ids with the model's
        # default decoding settings, and decode them back into a string.
        inputs = self.processor(image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            out = self.model.generate(**inputs)
        caption = self.processor.decode(out[0], skip_special_tokens=True)
        return caption
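
# Minimal usage sketch: "example.jpg" is a hypothetical placeholder path, and the
# first instantiation of BLIPServiceHF downloads the pretrained weights from the
# Hugging Face Hub; later instantiations reuse the cached singleton.
if __name__ == "__main__":
    service = BLIPServiceHF()
    with Image.open("example.jpg") as img:
        print(service.generate_caption(img.convert("RGB")))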