# ==================== visual_processing_service.py (WITH CACHE CONTROL) ====================
import logging
import os
import time
import warnings
from io import BytesIO
from typing import Dict, List, Optional, Tuple

import numpy as np
import requests
import torch
from PIL import Image
from sklearn.cluster import KMeans
from transformers import CLIPProcessor, CLIPModel

# ⚡ IMPORT CACHE CONFIGURATION
from .cache_config import ENABLE_CLIP_MODEL_CACHE

os.environ['TOKENIZERS_PARALLELISM'] = 'false'
warnings.filterwarnings('ignore')

logger = logging.getLogger(__name__)

class VisualProcessingService:
    """Service for extracting visual attributes from product images using CLIP with smart subcategory detection."""

    # ⚡ Class-level cache for the CLIP model. The model is always cached once loaded
    # (see _get_clip_model); ENABLE_CLIP_MODEL_CACHE is only surfaced as cache_status in results.
    _clip_model = None
    _clip_processor = None
    _device = None

    # Define hierarchical category structure with subcategories
    CATEGORY_ATTRIBUTES = {
        "clothing": {
            "subcategories": {
                "tops": {
                    "products": ["t-shirt", "shirt", "blouse", "top", "sweater", "hoodie", "tank top", "polo shirt"],
                    "attributes": {
                        "pattern": ["solid color", "striped", "checkered", "graphic print", "floral", "geometric", "plain", "logo print"],
                        "material": ["cotton", "polyester", "silk", "wool", "linen", "blend", "knit"],
                        "style": ["casual", "formal", "sporty", "streetwear", "elegant", "vintage", "minimalist"],
                        "fit": ["slim fit", "regular fit", "loose fit", "oversized", "fitted"],
                        "neckline": ["crew neck", "v-neck", "round neck", "collar", "scoop neck", "henley"],
                        "sleeve_type": ["short sleeve", "long sleeve", "sleeveless", "3/4 sleeve", "cap sleeve"],
                        "closure_type": ["button-up", "zipper", "pull-on", "snap button"]
                    }
                },
                "bottoms": {
                    "products": ["jeans", "pants", "trousers", "shorts", "chinos", "cargo pants", "leggings"],
                    "attributes": {
                        "pattern": ["solid color", "distressed", "faded", "plain", "washed", "dark wash", "light wash"],
                        "material": ["denim", "cotton", "polyester", "wool", "blend", "twill", "corduroy"],
                        "style": ["casual", "formal", "sporty", "vintage", "modern", "workwear"],
                        "fit": ["slim fit", "regular fit", "loose fit", "skinny", "bootcut", "straight leg", "relaxed fit"],
                        "rise": ["high rise", "mid rise", "low rise"],
                        "closure_type": ["button fly", "zipper fly", "elastic waist", "drawstring"],
                        "length": ["full length", "cropped", "ankle length", "capri"]
                    }
                },
                "dresses_skirts": {
                    "products": ["dress", "skirt", "gown", "sundress", "maxi dress", "mini skirt"],
                    "attributes": {
                        "pattern": ["solid color", "floral", "striped", "geometric", "plain", "printed", "polka dot"],
                        "material": ["cotton", "silk", "polyester", "linen", "blend", "chiffon", "satin"],
                        "style": ["casual", "formal", "cocktail", "bohemian", "vintage", "elegant", "party"],
                        "fit": ["fitted", "loose", "a-line", "bodycon", "flowy", "wrap"],
                        "neckline": ["crew neck", "v-neck", "scoop neck", "halter", "off-shoulder", "sweetheart"],
                        "sleeve_type": ["short sleeve", "long sleeve", "sleeveless", "3/4 sleeve", "flutter sleeve"],
                        "length": ["mini", "midi", "maxi", "knee-length", "floor-length"]
                    }
                },
                "outerwear": {
                    "products": ["jacket", "coat", "blazer", "windbreaker", "parka", "bomber jacket", "denim jacket"],
                    "attributes": {
                        "pattern": ["solid color", "plain", "quilted", "textured"],
                        "material": ["leather", "denim", "wool", "polyester", "cotton", "nylon", "fleece"],
                        "style": ["casual", "formal", "sporty", "vintage", "military", "biker"],
                        "fit": ["slim fit", "regular fit", "oversized", "cropped"],
                        "closure_type": ["zipper", "button", "snap button", "toggle"],
                        "length": ["cropped", "hip length", "thigh length", "knee length"]
                    }
                }
            }
        },
        "footwear": {
            "products": ["sneakers", "boots", "sandals", "heels", "loafers", "flats", "slippers"],
            "attributes": {
                "material": ["leather", "canvas", "suede", "synthetic", "rubber", "mesh"],
                "style": ["casual", "formal", "athletic", "vintage", "modern"],
                "closure_type": ["lace-up", "slip-on", "velcro", "buckle", "zipper"],
                "toe_style": ["round toe", "pointed toe", "square toe", "open toe", "closed toe"]
            }
        },
        "tools": {
            "products": ["screwdriver", "hammer", "wrench", "pliers", "drill", "saw", "measuring tape"],
            "attributes": {
                "material": ["steel", "aluminum", "plastic", "rubber", "chrome", "iron"],
                "type": ["manual", "electric", "pneumatic", "cordless", "corded"],
                "finish": ["chrome plated", "powder coated", "stainless steel", "painted"],
                "handle_type": ["rubber grip", "plastic", "wooden", "ergonomic", "cushioned"]
            }
        },
        "electronics": {
            "products": ["phone", "laptop", "tablet", "headphones", "speaker", "camera", "smartwatch", "earbuds"],
            "attributes": {
                "material": ["plastic", "metal", "glass", "aluminum", "rubber", "silicone"],
                "style": ["modern", "minimalist", "sleek", "industrial", "vintage"],
                "finish": ["matte", "glossy", "metallic", "textured", "transparent"],
                "connectivity": ["wireless", "wired", "bluetooth", "USB-C", "USB"]
            }
        },
        "furniture": {
            "products": ["chair", "table", "sofa", "bed", "desk", "shelf", "cabinet", "bench"],
            "attributes": {
                "material": ["wood", "metal", "glass", "plastic", "fabric", "leather", "rattan"],
                "style": ["modern", "traditional", "industrial", "rustic", "contemporary", "vintage", "scandinavian"],
                "finish": ["natural wood", "painted", "stained", "laminated", "upholstered", "polished"]
            }
        }
    }
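    # Note on the layout above: "clothing" is nested (subcategories, each with its own
    # products and attributes), while the other categories list products and attributes
    # directly. The attribute value lists double as the candidate labels passed to CLIP.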
    def __init__(self):
        pass

    @classmethod
    def _get_device(cls):
        """Get optimal device."""
        if cls._device is None:
            cls._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info(f"Visual Processing using device: {cls._device}")
        return cls._device
    @classmethod
    def _get_clip_model(cls):
        """
        🔥 ALWAYS cache CLIP model (ignores global cache setting).
        This is a 400MB model that takes 30-60s to load.
        """
        if cls._clip_model is None:
            start = time.time()
            logger.info("📥 Loading CLIP model from HuggingFace...")
            cls._clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
            cls._clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
            device = cls._get_device()
            cls._clip_model.to(device)
            cls._clip_model.eval()
            load_time = time.time() - start
            logger.info(f"✓ CLIP model loaded in {load_time:.1f}s and cached in memory")
        else:
            logger.debug("✓ Using cached CLIP model")
        return cls._clip_model, cls._clip_processor
    @classmethod
    def clear_clip_cache(cls):
        """Clear the cached CLIP model to free memory."""
        if cls._clip_model is not None:
            del cls._clip_model
            del cls._clip_processor
            cls._clip_model = None
            cls._clip_processor = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("✓ CLIP model cache cleared")
    def download_image(self, image_url: str) -> Optional[Image.Image]:
        """Download image from URL."""
        try:
            response = requests.get(image_url, timeout=10)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert('RGB')
            return image
        except Exception as e:
            logger.error(f"Error downloading image from {image_url}: {str(e)}")
            return None
    def extract_dominant_colors(self, image: Image.Image, n_colors: int = 3) -> List[Dict]:
        """Extract dominant colors using K-means clustering."""
        try:
            img_small = image.resize((150, 150))
            img_array = np.array(img_small)
            pixels = img_array.reshape(-1, 3)
            kmeans = KMeans(n_clusters=n_colors, random_state=42, n_init=5)
            kmeans.fit(pixels)
            colors = []
            labels_counts = np.bincount(kmeans.labels_)
            for i, center in enumerate(kmeans.cluster_centers_):
                # Cast numpy ints to plain Python ints so the result stays JSON-serializable.
                rgb = tuple(int(c) for c in center)
                color_name = self._get_color_name_simple(rgb)
                percentage = float(labels_counts[i] / len(kmeans.labels_) * 100)
                colors.append({
                    "name": color_name,
                    "rgb": rgb,
                    "percentage": round(percentage, 2)
                })
            colors.sort(key=lambda x: x['percentage'], reverse=True)
            return colors
        except Exception as e:
            logger.error(f"Error extracting colors: {str(e)}")
            return []
    def _get_color_name_simple(self, rgb: Tuple[int, int, int]) -> str:
        """Map RGB values to basic color names."""
        r, g, b = rgb
        colors = {
            'black': (r < 50 and g < 50 and b < 50),
            'white': (r > 200 and g > 200 and b > 200),
            'gray': (abs(r - g) < 30 and abs(g - b) < 30 and abs(r - b) < 30 and 50 <= r <= 200),
            'red': (r > 150 and g < 100 and b < 100),
            'green': (g > 150 and r < 100 and b < 100),
            'blue': (b > 150 and r < 100 and g < 100),
            'yellow': (r > 200 and g > 200 and b < 100),
            'orange': (r > 200 and 100 < g < 200 and b < 100),
            'purple': (r > 100 and b > 100 and g < 100),
            'pink': (r > 200 and 100 < g < 200 and 100 < b < 200),
            'brown': (50 < r < 150 and 30 < g < 100 and b < 80),
            'cyan': (r < 100 and g > 150 and b > 150),
            'beige': (180 < r < 240 and 160 < g < 220 and 120 < b < 180),
        }
        for color_name, condition in colors.items():
            if condition:
                return color_name
        if r > g and r > b:
            return 'red'
        elif g > r and g > b:
            return 'green'
        elif b > r and b > g:
            return 'blue'
        else:
            return 'gray'
    def classify_with_clip(
        self,
        image: Image.Image,
        candidates: List[str],
        attribute_name: str,
        confidence_threshold: float = 0.15
    ) -> Dict:
        """Use CLIP to classify image against candidate labels."""
        try:
            model, processor = self._get_clip_model()
            device = self._get_device()
            batch_size = 16
            all_logits = []
            for i in range(0, len(candidates), batch_size):
                batch_candidates = candidates[i:i + batch_size]
                inputs = processor(
                    text=batch_candidates,
                    images=image,
                    return_tensors="pt",
                    padding=True
                )
                inputs = {k: v.to(device) for k, v in inputs.items()}
                with torch.no_grad():
                    outputs = model(**inputs)
                all_logits.append(outputs.logits_per_image.cpu())
            # Apply softmax across the full candidate set (not per batch) so that
            # confidences from different batches are directly comparable.
            probs = torch.cat(all_logits, dim=1).softmax(dim=1)[0]
            all_results = [
                {
                    "value": candidates[j],
                    "confidence": round(float(prob.item()), 3)
                }
                for j, prob in enumerate(probs)
                if prob.item() > confidence_threshold
            ]
            all_results.sort(key=lambda x: x['confidence'], reverse=True)
            return {
                "attribute": attribute_name,
                "predictions": all_results[:3]
            }
        except Exception as e:
            logger.error(f"Error in CLIP classification for {attribute_name}: {str(e)}")
            return {"attribute": attribute_name, "predictions": []}
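    # Example return shape from classify_with_clip (values illustrative only):
    #   {"attribute": "pattern",
    #    "predictions": [{"value": "striped", "confidence": 0.412},
    #                    {"value": "solid color", "confidence": 0.236}]}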
    def detect_category_and_subcategory(self, image: Image.Image) -> Tuple[str, str, str, float]:
        """
        Hierarchically detect category, subcategory, and specific product.
        Returns: (category, subcategory, product_type, confidence)
        """
        main_categories = list(self.CATEGORY_ATTRIBUTES.keys())
        category_prompts = [f"a photo of {cat}" for cat in main_categories]
        result = self.classify_with_clip(image, category_prompts, "main_category", confidence_threshold=0.10)
        if not result["predictions"]:
            return "unknown", "unknown", "unknown", 0.0
        detected_category = result["predictions"][0]["value"].replace("a photo of ", "")
        category_confidence = result["predictions"][0]["confidence"]
        logger.info(f"Step 1 - Main category detected: {detected_category} (confidence: {category_confidence:.3f})")
        if detected_category == "clothing":
            subcategories = self.CATEGORY_ATTRIBUTES["clothing"]["subcategories"]
            all_products = []
            product_to_subcategory = {}
            for subcat, subcat_data in subcategories.items():
                for product in subcat_data["products"]:
                    prompt = f"a photo of {product}"
                    all_products.append(prompt)
                    product_to_subcategory[prompt] = subcat
            product_result = self.classify_with_clip(
                image,
                all_products,
                "product_type",
                confidence_threshold=0.12
            )
            if product_result["predictions"]:
                best_match = product_result["predictions"][0]
                product_prompt = best_match["value"]
                product_type = product_prompt.replace("a photo of ", "")
                subcategory = product_to_subcategory[product_prompt]
                product_confidence = best_match["confidence"]
                logger.info(f"Step 2 - Detected: {subcategory} > {product_type} (confidence: {product_confidence:.3f})")
                return detected_category, subcategory, product_type, product_confidence
            else:
                logger.warning("Could not detect specific product type for clothing")
                return detected_category, "unknown", "unknown", category_confidence
        else:
            category_data = self.CATEGORY_ATTRIBUTES[detected_category]
            if "products" in category_data:
                products = category_data["products"]
                product_prompts = [f"a photo of {p}" for p in products]
                product_result = self.classify_with_clip(
                    image,
                    product_prompts,
                    "product_type",
                    confidence_threshold=0.12
                )
                if product_result["predictions"]:
                    best_match = product_result["predictions"][0]
                    product_type = best_match["value"].replace("a photo of ", "")
                    logger.info(f"Step 2 - Detected: {detected_category} > {product_type}")
                    return detected_category, "none", product_type, best_match["confidence"]
            return detected_category, "unknown", "unknown", category_confidence
    def process_image(
        self,
        image_url: str,
        product_type_hint: Optional[str] = None
    ) -> Dict:
        """
        Main method to process image and extract visual attributes.
        Uses hierarchical detection to extract only relevant attributes.
        Note: product_type_hint is currently unused.
        """
        start_time = time.time()
        try:
            image = self.download_image(image_url)
            if image is None:
                return {
                    "visual_attributes": {},
                    "error": "Failed to download image"
                }
            visual_attributes = {}
            detailed_predictions = {}
            category, subcategory, product_type, confidence = self.detect_category_and_subcategory(image)
            if confidence < 0.10:
                logger.warning(f"Low confidence in detection ({confidence:.3f}). Returning basic attributes only.")
                colors = self.extract_dominant_colors(image, n_colors=3)
                if colors:
                    visual_attributes["primary_color"] = colors[0]["name"]
                    visual_attributes["color_palette"] = [c["name"] for c in colors]
                return {
                    "visual_attributes": visual_attributes,
                    "detection_confidence": confidence,
                    "warning": "Low confidence detection",
                    "processing_time": round(time.time() - start_time, 2)
                }
            visual_attributes["product_type"] = product_type
            visual_attributes["category"] = category
            if subcategory != "none" and subcategory != "unknown":
                visual_attributes["subcategory"] = subcategory
            colors = self.extract_dominant_colors(image, n_colors=3)
            if colors:
                visual_attributes["primary_color"] = colors[0]["name"]
                visual_attributes["color_palette"] = [c["name"] for c in colors[:3]]
                visual_attributes["color_distribution"] = [
                    {"color": c["name"], "percentage": c["percentage"]}
                    for c in colors
                ]
            attributes_config = None
            if category == "clothing":
                if subcategory in self.CATEGORY_ATTRIBUTES["clothing"]["subcategories"]:
                    attributes_config = self.CATEGORY_ATTRIBUTES["clothing"]["subcategories"][subcategory]["attributes"]
                    logger.info(f"Using attributes for subcategory: {subcategory}")
                else:
                    logger.warning(f"Unknown subcategory: {subcategory}. Skipping attribute extraction.")
            elif category in self.CATEGORY_ATTRIBUTES:
                if "attributes" in self.CATEGORY_ATTRIBUTES[category]:
                    attributes_config = self.CATEGORY_ATTRIBUTES[category]["attributes"]
                    logger.info(f"Using attributes for category: {category}")
            if attributes_config:
                for attr_name, attr_values in attributes_config.items():
                    result = self.classify_with_clip(
                        image,
                        attr_values,
                        attr_name,
                        confidence_threshold=0.20
                    )
                    if result["predictions"]:
                        best_prediction = result["predictions"][0]
                        if best_prediction["confidence"] > 0.20:
                            visual_attributes[attr_name] = best_prediction["value"]
                        detailed_predictions[attr_name] = result
            processing_time = time.time() - start_time
            logger.info(f"✓ Processing complete in {processing_time:.2f}s. Extracted {len(visual_attributes)} attributes.")
            return {
                "visual_attributes": visual_attributes,
                "detailed_predictions": detailed_predictions,
                "detection_confidence": confidence,
                "processing_time": round(processing_time, 2),
                "cache_status": "enabled" if ENABLE_CLIP_MODEL_CACHE else "disabled"
            }
        except Exception as e:
            logger.error(f"Error processing image: {str(e)}")
            return {
                "visual_attributes": {},
                "error": str(e),
                "processing_time": round(time.time() - start_time, 2)
            }
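
# ---------------------------------------------------------------------------
# Usage sketch (added for illustration, not part of the original service).
# Assumptions: the package layout supports the relative `.cache_config` import,
# so run this via `python -m <package>.visual_processing_service`, and the URL
# below is a placeholder to swap for a real product image. The first call
# downloads the ~400MB CLIP checkpoint from HuggingFace.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import json

    service = VisualProcessingService()
    example_url = "https://example.com/images/sample-product.jpg"  # placeholder
    output = service.process_image(example_url)
    print(json.dumps(output, indent=2, default=str))

    # Release the cached CLIP model once processing is done.
    VisualProcessingService.clear_clip_cache()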