services.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. # background_remover/services.py
  2. import torch
  3. from PIL import Image
  4. from torchvision import transforms
  5. from transformers import AutoModelForImageSegmentation
  6. class BiRefNetService:
  7. _instance = None
  8. def __new__(cls):
  9. if cls._instance is None:
  10. cls._instance = super(BiRefNetService, cls).__new__(cls)
  11. cls._instance._init_model()
  12. return cls._instance
  13. def _init_model(self):
  14. self.device = "cuda" if torch.cuda.is_available() else "cpu"
  15. self.model = AutoModelForImageSegmentation.from_pretrained(
  16. "ZhengPeng7/BiRefNet", trust_remote_code=True
  17. )
  18. self.model.to(self.device)
  19. self.model.eval()
  20. self.transform = transforms.Compose([
  21. transforms.Resize((1024, 1024)),
  22. transforms.ToTensor(),
  23. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
  24. ])
  25. def remove_background(self, input_image: Image.Image) -> Image.Image:
  26. image_size = input_image.size
  27. input_tensor = self.transform(input_image.convert("RGB")).unsqueeze(0).to(self.device)
  28. with torch.no_grad():
  29. preds = self.model(input_tensor)[-1].sigmoid().cpu()
  30. mask = transforms.ToPILImage()(preds[0].squeeze()).resize(image_size)
  31. # Create white background
  32. white_bg = Image.new("RGB", image_size, (255, 255, 255))
  33. white_bg.paste(input_image, (0, 0), mask)
  34. return white_bg