| 1234567891011121314151617181920212223242526272829303132333435363738394041 |
- # background_remover/services.py
- import torch
- from PIL import Image
- from torchvision import transforms
- from transformers import AutoModelForImageSegmentation
class BiRefNetService:
    """Process-wide singleton wrapping the BiRefNet image-segmentation model.

    The model is loaded on the first ``BiRefNetService()`` call. Later calls
    return the same, fully initialised instance.
    """

    # Hugging Face repository the model is loaded from.
    MODEL_ID = "ZhengPeng7/BiRefNet"
    # Fixed (height, width) every input is resized to before inference.
    INPUT_SIZE = (1024, 1024)

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            # Publish the instance only after _init_model succeeds. A failed
            # load (e.g. a download error) is then retried on the next call,
            # instead of caching a half-built object without .model/.transform.
            # No caller can receive an instance whose model is still loading.
            # Concurrent first calls may each load the model, but every
            # returned instance is fully initialised.
            instance = super().__new__(cls)
            instance._init_model()
            cls._instance = instance
        return cls._instance

    def _init_model(self):
        """Load BiRefNet onto CUDA if available (else CPU) and build the input transform."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # NOTE(review): trust_remote_code=True executes Python code shipped in
        # the model repository; consider pinning a revision if that matters.
        self.model = AutoModelForImageSegmentation.from_pretrained(
            self.MODEL_ID, trust_remote_code=True
        )
        self.model.to(self.device)
        self.model.eval()  # evaluation mode for inference
        self.transform = transforms.Compose([
            transforms.Resize(self.INPUT_SIZE),
            transforms.ToTensor(),
            # ImageNet per-channel mean / std.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def remove_background(
        self,
        input_image: Image.Image,
        background_color: "tuple[int, int, int]" = (255, 255, 255),
    ) -> Image.Image:
        """Return a copy of *input_image* with its background replaced.

        The model predicts a foreground mask. Foreground pixels are kept and
        the rest of the image is filled with *background_color*.

        Args:
            input_image: Source image in any PIL mode; it is converted to RGB.
            background_color: RGB fill used for background pixels
                (default: white).

        Returns:
            A new RGB image with the same size as *input_image*.
        """
        image_size = input_image.size
        # Convert once, so the model input and the composite use the same pixels.
        rgb_image = input_image.convert("RGB")
        input_tensor = self.transform(rgb_image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            # The model returns a sequence of predictions, and the last one is
            # used (presumably the final-stage map; confirm against BiRefNet's
            # remote code). sigmoid maps logits to [0, 1].
            preds = self.model(input_tensor)[-1].sigmoid().cpu()
        # Assumes preds[0].squeeze() is an (H, W) map (TODO confirm). ToPILImage
        # scales a float map to an 8-bit mask, which is resized to the input size.
        mask = transforms.ToPILImage()(preds[0].squeeze()).resize(image_size)

        result = Image.new("RGB", image_size, background_color)
        # Where the mask is 255 the source pixel is copied; where it is 0 the
        # background stays; intermediate values blend the two.
        result.paste(rgb_image, (0, 0), mask)
        return result
|