# blip_service.py
#
# Earlier implementation based on the original BLIP repository, kept commented
# out for reference; the active service below uses the Hugging Face
# `transformers` port instead.
#
# import torch
# from PIL import Image
# from torchvision import transforms
# from BLIP.models.blip import blip_decoder
# import os
#
#
# class BLIPService:
#     _model = None
#     _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#
#     def __init__(self):
#         if BLIPService._model is None:
#             BASE_DIR = os.path.dirname(os.path.abspath(__file__))
#             config_path = os.path.join(BASE_DIR, "../BLIP/configs/med_config.json")
#             checkpoint_path = os.path.join(BASE_DIR, "../BLIP/models/model_base_caption.pth")
#             model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
#             if not os.path.exists(config_path):
#                 raise FileNotFoundError(f"BLIP config not found at {config_path}")
#             # if not os.path.exists(checkpoint_path):
#             #     raise FileNotFoundError(f"BLIP checkpoint not found at {checkpoint_path}")
#             BLIPService._model = blip_decoder(
#                 # pretrained=checkpoint_path,
#                 med_config=config_path,
#                 pretrained=model_url,
#                 image_size=384,
#                 vit="base"
#             )
#             BLIPService._model.eval()
#             BLIPService._model.to(self._device)
#         self.transform = transforms.Compose([
#             transforms.Resize((384, 384)),
#             transforms.ToTensor()
#         ])
#
#     def generate_caption(self, image: Image.Image):
#         model = BLIPService._model
#         model.eval()
#         transform = transforms.Compose([
#             transforms.Resize((384, 384)),
#             transforms.ToTensor(),
#             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
#         ])
#         image_tensor = transform(image).unsqueeze(0)
#         device = "cuda" if torch.cuda.is_available() else "cpu"
#         image_tensor = image_tensor.to(device)
#         model.to(device)
#         with torch.no_grad():
#             caption = model.generate(image_tensor, sample=False, num_beams=3, max_length=20, min_length=5)
#         return caption[0]
#
#     # def generate_caption(self, image: Image.Image) -> str:
#     #     image = image.convert("RGB")
#     #     image = self.transform(image).unsqueeze(0).to(self._device)
#     #     with torch.no_grad():
#     #         caption = self._model.generate(
#     #             image,
#     #             sample=False,
#     #             num_beams=3
#     #         )
#     #     return caption[0]
from PIL import Image
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration


class BLIPServiceHF:
    """Image-captioning service backed by the Hugging Face BLIP base model."""

    _instance = None

    def __new__(cls, *args, **kwargs):
        # Singleton pattern to avoid reloading the model on every instantiation
        if cls._instance is None:
            cls._instance = super(BLIPServiceHF, cls).__new__(cls)
            cls._instance._init_model()
        return cls._instance

    def _init_model(self):
        # Load the pretrained processor and captioning model once, then move the
        # model to the GPU if one is available.
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
        self.model.eval()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def generate_caption(self, image: Image.Image) -> str:
        # Ensure a 3-channel input; grayscale or RGBA images are converted to RGB.
        image = image.convert("RGB")
        inputs = self.processor(image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            out = self.model.generate(**inputs)
        caption = self.processor.decode(out[0], skip_special_tokens=True)
        return caption
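

# A minimal usage sketch (an assumption, not part of the service itself):
# "sample.jpg" is a placeholder path for any local image file.
if __name__ == "__main__":
    service = BLIPServiceHF()  # first call loads the model; later calls reuse the singleton
    sample = Image.open("sample.jpg")
    print(service.generate_caption(sample))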