
# ==================== visual_processing_service.py (WITH CACHE CONTROL) ====================
import os
import time
import logging
import warnings
from io import BytesIO
from typing import Dict, List, Optional, Tuple

# Silence tokenizer fork warnings before transformers is imported
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
warnings.filterwarnings('ignore')

import torch
import numpy as np
import requests
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from sklearn.cluster import KMeans

# ⚡ IMPORT CACHE CONFIGURATION (the sibling cache_config module defines the boolean ENABLE_CLIP_MODEL_CACHE)
from .cache_config import ENABLE_CLIP_MODEL_CACHE

logger = logging.getLogger(__name__)

class VisualProcessingService:
    """Service for extracting visual attributes from product images using CLIP with smart subcategory detection."""

    # ⚡ Class-level caching (controlled by cache_config)
    _clip_model = None
    _clip_processor = None
    _device = None

    # Define hierarchical category structure with subcategories
    CATEGORY_ATTRIBUTES = {
        "clothing": {
            "subcategories": {
                "tops": {
                    "products": ["t-shirt", "shirt", "blouse", "top", "sweater", "hoodie", "tank top", "polo shirt"],
                    "attributes": {
                        "pattern": ["solid color", "striped", "checkered", "graphic print", "floral", "geometric", "plain", "logo print"],
                        "material": ["cotton", "polyester", "silk", "wool", "linen", "blend", "knit"],
                        "style": ["casual", "formal", "sporty", "streetwear", "elegant", "vintage", "minimalist"],
                        "fit": ["slim fit", "regular fit", "loose fit", "oversized", "fitted"],
                        "neckline": ["crew neck", "v-neck", "round neck", "collar", "scoop neck", "henley"],
                        "sleeve_type": ["short sleeve", "long sleeve", "sleeveless", "3/4 sleeve", "cap sleeve"],
                        "closure_type": ["button-up", "zipper", "pull-on", "snap button"]
                    }
                },
                "bottoms": {
                    "products": ["jeans", "pants", "trousers", "shorts", "chinos", "cargo pants", "leggings"],
                    "attributes": {
                        "pattern": ["solid color", "distressed", "faded", "plain", "washed", "dark wash", "light wash"],
                        "material": ["denim", "cotton", "polyester", "wool", "blend", "twill", "corduroy"],
                        "style": ["casual", "formal", "sporty", "vintage", "modern", "workwear"],
                        "fit": ["slim fit", "regular fit", "loose fit", "skinny", "bootcut", "straight leg", "relaxed fit"],
                        "rise": ["high rise", "mid rise", "low rise"],
                        "closure_type": ["button fly", "zipper fly", "elastic waist", "drawstring"],
                        "length": ["full length", "cropped", "ankle length", "capri"]
                    }
                },
                "dresses_skirts": {
                    "products": ["dress", "skirt", "gown", "sundress", "maxi dress", "mini skirt"],
                    "attributes": {
                        "pattern": ["solid color", "floral", "striped", "geometric", "plain", "printed", "polka dot"],
                        "material": ["cotton", "silk", "polyester", "linen", "blend", "chiffon", "satin"],
                        "style": ["casual", "formal", "cocktail", "bohemian", "vintage", "elegant", "party"],
                        "fit": ["fitted", "loose", "a-line", "bodycon", "flowy", "wrap"],
                        "neckline": ["crew neck", "v-neck", "scoop neck", "halter", "off-shoulder", "sweetheart"],
                        "sleeve_type": ["short sleeve", "long sleeve", "sleeveless", "3/4 sleeve", "flutter sleeve"],
                        "length": ["mini", "midi", "maxi", "knee-length", "floor-length"]
                    }
                },
                "outerwear": {
                    "products": ["jacket", "coat", "blazer", "windbreaker", "parka", "bomber jacket", "denim jacket"],
                    "attributes": {
                        "pattern": ["solid color", "plain", "quilted", "textured"],
                        "material": ["leather", "denim", "wool", "polyester", "cotton", "nylon", "fleece"],
                        "style": ["casual", "formal", "sporty", "vintage", "military", "biker"],
                        "fit": ["slim fit", "regular fit", "oversized", "cropped"],
                        "closure_type": ["zipper", "button", "snap button", "toggle"],
                        "length": ["cropped", "hip length", "thigh length", "knee length"]
                    }
                }
            }
        },
        "footwear": {
            "products": ["sneakers", "boots", "sandals", "heels", "loafers", "flats", "slippers"],
            "attributes": {
                "material": ["leather", "canvas", "suede", "synthetic", "rubber", "mesh"],
                "style": ["casual", "formal", "athletic", "vintage", "modern"],
                "closure_type": ["lace-up", "slip-on", "velcro", "buckle", "zipper"],
                "toe_style": ["round toe", "pointed toe", "square toe", "open toe", "closed toe"]
            }
        },
        "tools": {
            "products": ["screwdriver", "hammer", "wrench", "pliers", "drill", "saw", "measuring tape"],
            "attributes": {
                "material": ["steel", "aluminum", "plastic", "rubber", "chrome", "iron"],
                "type": ["manual", "electric", "pneumatic", "cordless", "corded"],
                "finish": ["chrome plated", "powder coated", "stainless steel", "painted"],
                "handle_type": ["rubber grip", "plastic", "wooden", "ergonomic", "cushioned"]
            }
        },
        "electronics": {
            "products": ["phone", "laptop", "tablet", "headphones", "speaker", "camera", "smartwatch", "earbuds"],
            "attributes": {
                "material": ["plastic", "metal", "glass", "aluminum", "rubber", "silicone"],
                "style": ["modern", "minimalist", "sleek", "industrial", "vintage"],
                "finish": ["matte", "glossy", "metallic", "textured", "transparent"],
                "connectivity": ["wireless", "wired", "bluetooth", "USB-C", "USB"]
            }
        },
        "furniture": {
            "products": ["chair", "table", "sofa", "bed", "desk", "shelf", "cabinet", "bench"],
            "attributes": {
                "material": ["wood", "metal", "glass", "plastic", "fabric", "leather", "rattan"],
                "style": ["modern", "traditional", "industrial", "rustic", "contemporary", "vintage", "scandinavian"],
                "finish": ["natural wood", "painted", "stained", "laminated", "upholstered", "polished"]
            }
        }
    }
    def __init__(self):
        pass

    @classmethod
    def _get_device(cls):
        """Get optimal device."""
        if cls._device is None:
            cls._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info(f"Visual Processing using device: {cls._device}")
        return cls._device
    @classmethod
    def _get_clip_model(cls):
        """
        Lazy load CLIP model with optional class-level caching.
        ⚡ If caching is disabled, the model is still loaded but not persisted at class level.
        """
        # ⚡ CACHE CONTROL: If caching is disabled, always reload (no persistence)
        if not ENABLE_CLIP_MODEL_CACHE:
            logger.info("⚠ CLIP model caching is DISABLED - loading fresh instance")
            model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
            processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
            device = cls._get_device()
            model.to(device)
            model.eval()
            logger.info("✓ CLIP model loaded (no caching)")
            return model, processor

        # Caching is enabled - use class-level cache
        if cls._clip_model is None:
            logger.info("Loading CLIP model (this may take a few minutes on first use)...")
            cls._clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
            cls._clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
            device = cls._get_device()
            cls._clip_model.to(device)
            cls._clip_model.eval()
            logger.info("✓ CLIP model loaded and cached successfully")
        else:
            logger.info("✓ Using cached CLIP model")
        return cls._clip_model, cls._clip_processor
    @classmethod
    def clear_clip_cache(cls):
        """Clear the cached CLIP model to free memory."""
        if cls._clip_model is not None:
            del cls._clip_model
            del cls._clip_processor
            cls._clip_model = None
            cls._clip_processor = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("✓ CLIP model cache cleared")
    def download_image(self, image_url: str) -> Optional[Image.Image]:
        """Download image from URL."""
        try:
            response = requests.get(image_url, timeout=10)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert('RGB')
            return image
        except Exception as e:
            logger.error(f"Error downloading image from {image_url}: {str(e)}")
            return None
    def extract_dominant_colors(self, image: Image.Image, n_colors: int = 3) -> List[Dict]:
        """Extract dominant colors using K-means clustering."""
        try:
            img_small = image.resize((150, 150))
            img_array = np.array(img_small)
            pixels = img_array.reshape(-1, 3)
            kmeans = KMeans(n_clusters=n_colors, random_state=42, n_init=5)
            kmeans.fit(pixels)
            colors = []
            labels_counts = np.bincount(kmeans.labels_)
            for i, center in enumerate(kmeans.cluster_centers_):
                # Cast to plain Python ints so the result stays JSON-serializable
                rgb = tuple(int(c) for c in center)
                color_name = self._get_color_name_simple(rgb)
                percentage = float(labels_counts[i] / len(kmeans.labels_) * 100)
                colors.append({
                    "name": color_name,
                    "rgb": rgb,
                    "percentage": round(percentage, 2)
                })
            colors.sort(key=lambda x: x['percentage'], reverse=True)
            return colors
        except Exception as e:
            logger.error(f"Error extracting colors: {str(e)}")
            return []
    def _get_color_name_simple(self, rgb: Tuple[int, int, int]) -> str:
        """Map RGB values to basic color names."""
        r, g, b = rgb
        colors = {
            'black': (r < 50 and g < 50 and b < 50),
            'white': (r > 200 and g > 200 and b > 200),
            'gray': (abs(r - g) < 30 and abs(g - b) < 30 and abs(r - b) < 30 and 50 <= r <= 200),
            'red': (r > 150 and g < 100 and b < 100),
            'green': (g > 150 and r < 100 and b < 100),
            'blue': (b > 150 and r < 100 and g < 100),
            'yellow': (r > 200 and g > 200 and b < 100),
            'orange': (r > 200 and 100 < g < 200 and b < 100),
            'purple': (r > 100 and b > 100 and g < 100),
            'pink': (r > 200 and 100 < g < 200 and 100 < b < 200),
            'brown': (50 < r < 150 and 30 < g < 100 and b < 80),
            'cyan': (r < 100 and g > 150 and b > 150),
            'beige': (180 < r < 240 and 160 < g < 220 and 120 < b < 180),
        }
        for color_name, condition in colors.items():
            if condition:
                return color_name
        if r > g and r > b:
            return 'red'
        elif g > r and g > b:
            return 'green'
        elif b > r and b > g:
            return 'blue'
        else:
            return 'gray'
    def classify_with_clip(
        self,
        image: Image.Image,
        candidates: List[str],
        attribute_name: str,
        confidence_threshold: float = 0.15
    ) -> Dict:
        """Use CLIP to classify image against candidate labels."""
        try:
            model, processor = self._get_clip_model()
            device = self._get_device()
            batch_size = 16
            all_results = []
            # Note: softmax runs per batch, so confidences are normalized within
            # each batch of candidates rather than across the full candidate list.
            for i in range(0, len(candidates), batch_size):
                batch_candidates = candidates[i:i + batch_size]
                inputs = processor(
                    text=batch_candidates,
                    images=image,
                    return_tensors="pt",
                    padding=True
                )
                inputs = {k: v.to(device) for k, v in inputs.items()}
                with torch.no_grad():
                    outputs = model(**inputs)
                logits_per_image = outputs.logits_per_image
                probs = logits_per_image.softmax(dim=1).cpu()
                for j, prob in enumerate(probs[0]):
                    if prob.item() > confidence_threshold:
                        all_results.append({
                            "value": batch_candidates[j],
                            "confidence": round(float(prob.item()), 3)
                        })
            all_results.sort(key=lambda x: x['confidence'], reverse=True)
            return {
                "attribute": attribute_name,
                "predictions": all_results[:3]
            }
        except Exception as e:
            logger.error(f"Error in CLIP classification for {attribute_name}: {str(e)}")
            return {"attribute": attribute_name, "predictions": []}
    def detect_category_and_subcategory(self, image: Image.Image) -> Tuple[str, str, str, float]:
        """
        Hierarchically detect category, subcategory, and specific product.
        Returns: (category, subcategory, product_type, confidence)
        """
        main_categories = list(self.CATEGORY_ATTRIBUTES.keys())
        category_prompts = [f"a photo of {cat}" for cat in main_categories]
        result = self.classify_with_clip(image, category_prompts, "main_category", confidence_threshold=0.10)
        if not result["predictions"]:
            return "unknown", "unknown", "unknown", 0.0

        detected_category = result["predictions"][0]["value"].replace("a photo of ", "")
        category_confidence = result["predictions"][0]["confidence"]
        logger.info(f"Step 1 - Main category detected: {detected_category} (confidence: {category_confidence:.3f})")

        if detected_category == "clothing":
            subcategories = self.CATEGORY_ATTRIBUTES["clothing"]["subcategories"]
            all_products = []
            product_to_subcategory = {}
            for subcat, subcat_data in subcategories.items():
                for product in subcat_data["products"]:
                    prompt = f"a photo of {product}"
                    all_products.append(prompt)
                    product_to_subcategory[prompt] = subcat
            product_result = self.classify_with_clip(
                image,
                all_products,
                "product_type",
                confidence_threshold=0.12
            )
            if product_result["predictions"]:
                best_match = product_result["predictions"][0]
                product_prompt = best_match["value"]
                product_type = product_prompt.replace("a photo of ", "")
                subcategory = product_to_subcategory[product_prompt]
                product_confidence = best_match["confidence"]
                logger.info(f"Step 2 - Detected: {subcategory} > {product_type} (confidence: {product_confidence:.3f})")
                return detected_category, subcategory, product_type, product_confidence
            else:
                logger.warning("Could not detect specific product type for clothing")
                return detected_category, "unknown", "unknown", category_confidence
        else:
            category_data = self.CATEGORY_ATTRIBUTES[detected_category]
            if "products" in category_data:
                products = category_data["products"]
                product_prompts = [f"a photo of {p}" for p in products]
                product_result = self.classify_with_clip(
                    image,
                    product_prompts,
                    "product_type",
                    confidence_threshold=0.12
                )
                if product_result["predictions"]:
                    best_match = product_result["predictions"][0]
                    product_type = best_match["value"].replace("a photo of ", "")
                    logger.info(f"Step 2 - Detected: {detected_category} > {product_type}")
                    return detected_category, "none", product_type, best_match["confidence"]
            return detected_category, "unknown", "unknown", category_confidence
    def process_image(
        self,
        image_url: str,
        product_type_hint: Optional[str] = None
    ) -> Dict:
        """
        Main method to process image and extract visual attributes.
        Uses hierarchical detection to extract only relevant attributes.
        (product_type_hint is accepted for API compatibility but not currently used.)
        """
        start_time = time.time()
        try:
            image = self.download_image(image_url)
            if image is None:
                return {
                    "visual_attributes": {},
                    "error": "Failed to download image"
                }

            visual_attributes = {}
            detailed_predictions = {}
            category, subcategory, product_type, confidence = self.detect_category_and_subcategory(image)

            if confidence < 0.10:
                logger.warning(f"Low confidence in detection ({confidence:.3f}). Returning basic attributes only.")
                colors = self.extract_dominant_colors(image, n_colors=3)
                if colors:
                    visual_attributes["primary_color"] = colors[0]["name"]
                    visual_attributes["color_palette"] = [c["name"] for c in colors]
                return {
                    "visual_attributes": visual_attributes,
                    "detection_confidence": confidence,
                    "warning": "Low confidence detection",
                    "processing_time": round(time.time() - start_time, 2)
                }

            visual_attributes["product_type"] = product_type
            visual_attributes["category"] = category
            if subcategory != "none" and subcategory != "unknown":
                visual_attributes["subcategory"] = subcategory

            colors = self.extract_dominant_colors(image, n_colors=3)
            if colors:
                visual_attributes["primary_color"] = colors[0]["name"]
                visual_attributes["color_palette"] = [c["name"] for c in colors[:3]]
                visual_attributes["color_distribution"] = [
                    {"color": c["name"], "percentage": c["percentage"]}
                    for c in colors
                ]

            attributes_config = None
            if category == "clothing":
                if subcategory in self.CATEGORY_ATTRIBUTES["clothing"]["subcategories"]:
                    attributes_config = self.CATEGORY_ATTRIBUTES["clothing"]["subcategories"][subcategory]["attributes"]
                    logger.info(f"Using attributes for subcategory: {subcategory}")
                else:
                    logger.warning(f"Unknown subcategory: {subcategory}. Skipping attribute extraction.")
            elif category in self.CATEGORY_ATTRIBUTES:
                if "attributes" in self.CATEGORY_ATTRIBUTES[category]:
                    attributes_config = self.CATEGORY_ATTRIBUTES[category]["attributes"]
                    logger.info(f"Using attributes for category: {category}")

            if attributes_config:
                for attr_name, attr_values in attributes_config.items():
                    result = self.classify_with_clip(
                        image,
                        attr_values,
                        attr_name,
                        confidence_threshold=0.20
                    )
                    if result["predictions"]:
                        best_prediction = result["predictions"][0]
                        if best_prediction["confidence"] > 0.20:
                            visual_attributes[attr_name] = best_prediction["value"]
                        detailed_predictions[attr_name] = result

            processing_time = time.time() - start_time
            logger.info(f"✓ Processing complete in {processing_time:.2f}s. Extracted {len(visual_attributes)} attributes.")
            return {
                "visual_attributes": visual_attributes,
                "detailed_predictions": detailed_predictions,
                "detection_confidence": confidence,
                "processing_time": round(processing_time, 2),
                "cache_status": "enabled" if ENABLE_CLIP_MODEL_CACHE else "disabled"
            }
        except Exception as e:
            logger.error(f"Error processing image: {str(e)}")
            return {
                "visual_attributes": {},
                "error": str(e),
                "processing_time": round(time.time() - start_time, 2)
            }
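
# --- Usage sketch (not part of the original service) ---
# A minimal, hedged example of how this service might be driven end to end.
# Assumptions: the demo URL below is a placeholder, and because of the relative
# cache_config import this module must be run as part of its package, e.g.
# `python -m <your_package>.visual_processing_service <image_url>`.
if __name__ == "__main__":
    import sys

    logging.basicConfig(level=logging.INFO)

    service = VisualProcessingService()
    # Hypothetical product image URL - replace with a real one
    demo_url = sys.argv[1] if len(sys.argv) > 1 else "https://example.com/product.jpg"

    result = service.process_image(demo_url)
    print("Visual attributes:", result.get("visual_attributes", {}))
    print("Detection confidence:", result.get("detection_confidence"))

    # Release the cached CLIP model (and GPU memory, if any) when finished
    VisualProcessingService.clear_clip_cache()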