# ==================== visual_processing_service.py (FIXED - Smart Subcategory Detection) ====================
import logging
import os
import warnings
from io import BytesIO
from typing import Dict, List, Optional, Tuple

import numpy as np
import requests
import torch
from PIL import Image
from sklearn.cluster import KMeans
from transformers import CLIPProcessor, CLIPModel

os.environ['TOKENIZERS_PARALLELISM'] = 'false'  # Disable tokenizer warnings
warnings.filterwarnings('ignore')  # Suppress all warnings

logger = logging.getLogger(__name__)


class VisualProcessingService:
    """Service for extracting visual attributes from product images using CLIP with smart subcategory detection."""

    # Class-level caching (shared across instances)
    _clip_model = None
    _clip_processor = None
    _device = None

    # Define hierarchical category structure with subcategories
    CATEGORY_ATTRIBUTES = {
        "clothing": {
            "subcategories": {
                "tops": {
                    "products": ["t-shirt", "shirt", "blouse", "top", "sweater", "hoodie", "tank top", "polo shirt"],
                    "attributes": {
                        "pattern": ["solid color", "striped", "checkered", "graphic print", "floral", "geometric", "plain", "logo print"],
                        "material": ["cotton", "polyester", "silk", "wool", "linen", "blend", "knit"],
                        "style": ["casual", "formal", "sporty", "streetwear", "elegant", "vintage", "minimalist"],
                        "fit": ["slim fit", "regular fit", "loose fit", "oversized", "fitted"],
                        "neckline": ["crew neck", "v-neck", "round neck", "collar", "scoop neck", "henley"],
                        "sleeve_type": ["short sleeve", "long sleeve", "sleeveless", "3/4 sleeve", "cap sleeve"],
                        "closure_type": ["button-up", "zipper", "pull-on", "snap button"]
                    }
                },
                "bottoms": {
                    "products": ["jeans", "pants", "trousers", "shorts", "chinos", "cargo pants", "leggings"],
                    "attributes": {
                        "pattern": ["solid color", "distressed", "faded", "plain", "washed", "dark wash", "light wash"],
                        "material": ["denim", "cotton", "polyester", "wool", "blend", "twill", "corduroy"],
                        "style": ["casual", "formal", "sporty", "vintage", "modern", "workwear"],
                        "fit": ["slim fit", "regular fit", "loose fit", "skinny", "bootcut", "straight leg", "relaxed fit"],
                        "rise": ["high rise", "mid rise", "low rise"],
                        "closure_type": ["button fly", "zipper fly", "elastic waist", "drawstring"],
                        "length": ["full length", "cropped", "ankle length", "capri"]
                    }
                },
                "dresses_skirts": {
                    "products": ["dress", "skirt", "gown", "sundress", "maxi dress", "mini skirt"],
                    "attributes": {
                        "pattern": ["solid color", "floral", "striped", "geometric", "plain", "printed", "polka dot"],
                        "material": ["cotton", "silk", "polyester", "linen", "blend", "chiffon", "satin"],
                        "style": ["casual", "formal", "cocktail", "bohemian", "vintage", "elegant", "party"],
                        "fit": ["fitted", "loose", "a-line", "bodycon", "flowy", "wrap"],
                        "neckline": ["crew neck", "v-neck", "scoop neck", "halter", "off-shoulder", "sweetheart"],
                        "sleeve_type": ["short sleeve", "long sleeve", "sleeveless", "3/4 sleeve", "flutter sleeve"],
                        "length": ["mini", "midi", "maxi", "knee-length", "floor-length"]
                    }
                },
                "outerwear": {
                    "products": ["jacket", "coat", "blazer", "windbreaker", "parka", "bomber jacket", "denim jacket"],
                    "attributes": {
                        "pattern": ["solid color", "plain", "quilted", "textured"],
                        "material": ["leather", "denim", "wool", "polyester", "cotton", "nylon", "fleece"],
                        "style": ["casual", "formal", "sporty", "vintage", "military", "biker"],
                        "fit": ["slim fit", "regular fit", "oversized", "cropped"],
                        "closure_type": ["zipper", "button", "snap button", "toggle"],
                        "length": ["cropped", "hip length", "thigh length", "knee length"]
                    }
                }
            }
        },
        "footwear": {
            "products": ["sneakers", "boots", "sandals", "heels", "loafers", "flats", "slippers"],
            "attributes": {
                "material": ["leather", "canvas", "suede", "synthetic", "rubber", "mesh"],
                "style": ["casual", "formal", "athletic", "vintage", "modern"],
                "closure_type": ["lace-up", "slip-on", "velcro", "buckle", "zipper"],
                "toe_style": ["round toe", "pointed toe", "square toe", "open toe", "closed toe"]
            }
        },
        "tools": {
            "products": ["screwdriver", "hammer", "wrench", "pliers", "drill", "saw", "measuring tape"],
            "attributes": {
                "material": ["steel", "aluminum", "plastic", "rubber", "chrome", "iron"],
                "type": ["manual", "electric", "pneumatic", "cordless", "corded"],
                "finish": ["chrome plated", "powder coated", "stainless steel", "painted"],
                "handle_type": ["rubber grip", "plastic", "wooden", "ergonomic", "cushioned"]
            }
        },
        "electronics": {
            "products": ["phone", "laptop", "tablet", "headphones", "speaker", "camera", "smartwatch", "earbuds"],
            "attributes": {
                "material": ["plastic", "metal", "glass", "aluminum", "rubber", "silicone"],
                "style": ["modern", "minimalist", "sleek", "industrial", "vintage"],
                "finish": ["matte", "glossy", "metallic", "textured", "transparent"],
                "connectivity": ["wireless", "wired", "bluetooth", "USB-C", "USB"]
            }
        },
        "furniture": {
            "products": ["chair", "table", "sofa", "bed", "desk", "shelf", "cabinet", "bench"],
            "attributes": {
                "material": ["wood", "metal", "glass", "plastic", "fabric", "leather", "rattan"],
                "style": ["modern", "traditional", "industrial", "rustic", "contemporary", "vintage", "scandinavian"],
                "finish": ["natural wood", "painted", "stained", "laminated", "upholstered", "polished"]
            }
        }
    }
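
    # Illustrative lookup path into the taxonomy above (values come from the dict itself):
    #   CATEGORY_ATTRIBUTES["clothing"]["subcategories"]["tops"]["attributes"]["fit"]
    #   -> ["slim fit", "regular fit", "loose fit", "oversized", "fitted"]
    # Non-clothing categories skip the "subcategories" level and hold "products"/"attributes" directly.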
    def __init__(self):
        pass

    @classmethod
    def _get_device(cls):
        """Get optimal device."""
        if cls._device is None:
            cls._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info(f"Visual Processing using device: {cls._device}")
        return cls._device

    @classmethod
    def _get_clip_model(cls):
        """Lazy load CLIP model with class-level caching."""
        if cls._clip_model is None:
            logger.info("Loading CLIP model (this may take a few minutes on first use)...")
            cls._clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
            cls._clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
            device = cls._get_device()
            cls._clip_model.to(device)
            cls._clip_model.eval()
            logger.info("✓ CLIP model loaded successfully")
        return cls._clip_model, cls._clip_processor

    def download_image(self, image_url: str) -> Optional[Image.Image]:
        """Download image from URL."""
        try:
            response = requests.get(image_url, timeout=10)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert('RGB')
            return image
        except Exception as e:
            logger.error(f"Error downloading image from {image_url}: {str(e)}")
            return None

    def extract_dominant_colors(self, image: Image.Image, n_colors: int = 3) -> List[Dict]:
        """Extract dominant colors using K-means clustering."""
        try:
            # Resize for faster processing
            img_small = image.resize((150, 150))
            img_array = np.array(img_small)
            pixels = img_array.reshape(-1, 3)
            # K-means clustering
            kmeans = KMeans(n_clusters=n_colors, random_state=42, n_init=5)
            kmeans.fit(pixels)
            colors = []
            labels_counts = np.bincount(kmeans.labels_)
            for i, center in enumerate(kmeans.cluster_centers_):
                rgb = tuple(center.astype(int))
                color_name = self._get_color_name_simple(rgb)
                percentage = float(labels_counts[i] / len(kmeans.labels_) * 100)
                colors.append({
                    "name": color_name,
                    "rgb": rgb,
                    "percentage": round(percentage, 2)
                })
            # Sort by percentage (most dominant first)
            colors.sort(key=lambda x: x['percentage'], reverse=True)
            return colors
        except Exception as e:
            logger.error(f"Error extracting colors: {str(e)}")
            return []
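
    # Illustrative return shape for extract_dominant_colors (values are hypothetical):
    #   [{"name": "blue",  "rgb": (34, 56, 180),   "percentage": 62.5},
    #    {"name": "white", "rgb": (241, 240, 238), "percentage": 25.1},
    #    {"name": "gray",  "rgb": (120, 118, 121), "percentage": 12.4}]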
    def _get_color_name_simple(self, rgb: Tuple[int, int, int]) -> str:
        """Map RGB values to basic color names."""
        r, g, b = rgb
        # Define color ranges with priorities
        colors = {
            'black': (r < 50 and g < 50 and b < 50),
            'white': (r > 200 and g > 200 and b > 200),
            'gray': (abs(r - g) < 30 and abs(g - b) < 30 and abs(r - b) < 30 and 50 <= r <= 200),
            'red': (r > 150 and g < 100 and b < 100),
            'green': (g > 150 and r < 100 and b < 100),
            'blue': (b > 150 and r < 100 and g < 100),
            'yellow': (r > 200 and g > 200 and b < 100),
            'orange': (r > 200 and 100 < g < 200 and b < 100),
            'purple': (r > 100 and b > 100 and g < 100),
            'pink': (r > 200 and 100 < g < 200 and 100 < b < 200),
            'brown': (50 < r < 150 and 30 < g < 100 and b < 80),
            'cyan': (r < 100 and g > 150 and b > 150),
            'beige': (180 < r < 240 and 160 < g < 220 and 120 < b < 180),
        }
        for color_name, condition in colors.items():
            if condition:
                return color_name
        # Fallback to dominant channel
        if r > g and r > b:
            return 'red'
        elif g > r and g > b:
            return 'green'
        elif b > r and b > g:
            return 'blue'
        else:
            return 'gray'
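
    # Example mappings implied by the thresholds above (first matching entry wins):
    #   (230, 60, 55) -> "red", (120, 120, 120) -> "gray", (195, 175, 150) -> "beige"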
    def classify_with_clip(
        self,
        image: Image.Image,
        candidates: List[str],
        attribute_name: str,
        confidence_threshold: float = 0.15
    ) -> Dict:
        """Use CLIP to classify image against candidate labels."""
        try:
            model, processor = self._get_clip_model()
            device = self._get_device()
            # ⚡ OPTIMIZATION: Process in smaller batches to avoid memory issues
            batch_size = 16  # Process 16 candidates at a time
            all_results = []
            for i in range(0, len(candidates), batch_size):
                batch_candidates = candidates[i:i + batch_size]
                # Prepare inputs WITHOUT progress bars
                inputs = processor(
                    text=batch_candidates,
                    images=image,
                    return_tensors="pt",
                    padding=True
                )
                # Move to device
                inputs = {k: v.to(device) for k, v in inputs.items()}
                # Get predictions
                with torch.no_grad():
                    outputs = model(**inputs)
                    logits_per_image = outputs.logits_per_image
                    probs = logits_per_image.softmax(dim=1).cpu()
                # Collect results from this batch
                for j, prob in enumerate(probs[0]):
                    if prob.item() > confidence_threshold:
                        all_results.append({
                            "value": batch_candidates[j],
                            "confidence": round(float(prob.item()), 3)
                        })
            # Sort by confidence and return top 3
            all_results.sort(key=lambda x: x['confidence'], reverse=True)
            return {
                "attribute": attribute_name,
                "predictions": all_results[:3]
            }
        except Exception as e:
            logger.error(f"Error in CLIP classification for {attribute_name}: {str(e)}")
            return {"attribute": attribute_name, "predictions": []}
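
    # Illustrative return shape for classify_with_clip (values are hypothetical):
    #   {"attribute": "pattern",
    #    "predictions": [{"value": "striped", "confidence": 0.612},
    #                    {"value": "graphic print", "confidence": 0.201}]}
    # Note: softmax is taken per batch of 16 candidates, so confidences from different
    # batches of a long candidate list are only roughly comparable.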
    def detect_category_and_subcategory(self, image: Image.Image) -> Tuple[str, str, str, float]:
        """
        Hierarchically detect category, subcategory, and specific product.
        Returns: (category, subcategory, product_type, confidence)
        """
        # Step 1: Detect if it's clothing or something else
        main_categories = list(self.CATEGORY_ATTRIBUTES.keys())
        category_prompts = [f"a photo of {cat}" for cat in main_categories]
        result = self.classify_with_clip(image, category_prompts, "main_category", confidence_threshold=0.10)
        if not result["predictions"]:
            return "unknown", "unknown", "unknown", 0.0
        detected_category = result["predictions"][0]["value"].replace("a photo of ", "")
        category_confidence = result["predictions"][0]["confidence"]
        logger.info(f"Step 1 - Main category detected: {detected_category} (confidence: {category_confidence:.3f})")

        # Step 2: For clothing, detect subcategory (tops/bottoms/dresses/outerwear)
        if detected_category == "clothing":
            subcategories = self.CATEGORY_ATTRIBUTES["clothing"]["subcategories"]
            # Collect all products grouped by subcategory
            all_products = []
            product_to_subcategory = {}
            for subcat, subcat_data in subcategories.items():
                for product in subcat_data["products"]:
                    prompt = f"a photo of {product}"
                    all_products.append(prompt)
                    product_to_subcategory[prompt] = subcat
            # Step 3: Detect specific product type
            product_result = self.classify_with_clip(
                image,
                all_products,
                "product_type",
                confidence_threshold=0.12
            )
            if product_result["predictions"]:
                best_match = product_result["predictions"][0]
                product_prompt = best_match["value"]
                product_type = product_prompt.replace("a photo of ", "")
                subcategory = product_to_subcategory[product_prompt]
                product_confidence = best_match["confidence"]
                logger.info(f"Step 2 - Detected: {subcategory} > {product_type} (confidence: {product_confidence:.3f})")
                return detected_category, subcategory, product_type, product_confidence
            else:
                logger.warning("Could not detect specific product type for clothing")
                return detected_category, "unknown", "unknown", category_confidence
        # Step 3: For non-clothing categories, just detect product type
        else:
            category_data = self.CATEGORY_ATTRIBUTES[detected_category]
            # Check if this category has subcategories or direct products
            if "products" in category_data:
                products = category_data["products"]
                product_prompts = [f"a photo of {p}" for p in products]
                product_result = self.classify_with_clip(
                    image,
                    product_prompts,
                    "product_type",
                    confidence_threshold=0.12
                )
                if product_result["predictions"]:
                    best_match = product_result["predictions"][0]
                    product_type = best_match["value"].replace("a photo of ", "")
                    logger.info(f"Step 2 - Detected: {detected_category} > {product_type}")
                    return detected_category, "none", product_type, best_match["confidence"]
            return detected_category, "unknown", "unknown", category_confidence
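
    # Illustrative return values (confidences are hypothetical):
    #   clothing item:      ("clothing", "tops", "t-shirt", 0.47)
    #   non-clothing item:  ("electronics", "none", "headphones", 0.38)
    #   nothing detected:   ("unknown", "unknown", "unknown", 0.0)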
    def process_image(
        self,
        image_url: str,
        product_type_hint: Optional[str] = None
    ) -> Dict:
        """
        Main method to process image and extract visual attributes.
        Uses hierarchical detection to extract only relevant attributes.
        """
        import time
        start_time = time.time()
        try:
            # Download image
            image = self.download_image(image_url)
            if image is None:
                return {
                    "visual_attributes": {},
                    "error": "Failed to download image"
                }
            visual_attributes = {}
            detailed_predictions = {}
            # Step 1: Detect category, subcategory, and product type
            category, subcategory, product_type, confidence = self.detect_category_and_subcategory(image)
            # Low confidence check
            if confidence < 0.10:
                logger.warning(f"Low confidence in detection ({confidence:.3f}). Returning basic attributes only.")
                colors = self.extract_dominant_colors(image, n_colors=3)
                if colors:
                    visual_attributes["primary_color"] = colors[0]["name"]
                    visual_attributes["color_palette"] = [c["name"] for c in colors]
                return {
                    "visual_attributes": visual_attributes,
                    "detection_confidence": confidence,
                    "warning": "Low confidence detection",
                    "processing_time": round(time.time() - start_time, 2)
                }
            # Add detected metadata
            visual_attributes["product_type"] = product_type
            visual_attributes["category"] = category
            if subcategory != "none" and subcategory != "unknown":
                visual_attributes["subcategory"] = subcategory
            # Step 2: Extract color information (universal)
            colors = self.extract_dominant_colors(image, n_colors=3)
            if colors:
                visual_attributes["primary_color"] = colors[0]["name"]
                visual_attributes["color_palette"] = [c["name"] for c in colors[:3]]
                visual_attributes["color_distribution"] = [
                    {"color": c["name"], "percentage": c["percentage"]}
                    for c in colors
                ]
            # Step 3: Get the right attribute configuration based on subcategory
            attributes_config = None
            if category == "clothing":
                if subcategory in self.CATEGORY_ATTRIBUTES["clothing"]["subcategories"]:
                    attributes_config = self.CATEGORY_ATTRIBUTES["clothing"]["subcategories"][subcategory]["attributes"]
                    logger.info(f"Using attributes for subcategory: {subcategory}")
                else:
                    logger.warning(f"Unknown subcategory: {subcategory}. Skipping attribute extraction.")
            elif category in self.CATEGORY_ATTRIBUTES:
                if "attributes" in self.CATEGORY_ATTRIBUTES[category]:
                    attributes_config = self.CATEGORY_ATTRIBUTES[category]["attributes"]
                    logger.info(f"Using attributes for category: {category}")
            # Step 4: Extract category-specific attributes
            if attributes_config:
                for attr_name, attr_values in attributes_config.items():
                    result = self.classify_with_clip(
                        image,
                        attr_values,
                        attr_name,
                        confidence_threshold=0.20
                    )
                    if result["predictions"]:
                        best_prediction = result["predictions"][0]
                        # Only add attributes with reasonable confidence
                        if best_prediction["confidence"] > 0.20:
                            visual_attributes[attr_name] = best_prediction["value"]
                        # Store detailed predictions for debugging
                        detailed_predictions[attr_name] = result
            processing_time = time.time() - start_time
            logger.info(f"✓ Processing complete in {processing_time:.2f}s. Extracted {len(visual_attributes)} attributes.")
            return {
                "visual_attributes": visual_attributes,
                "detailed_predictions": detailed_predictions,
                "detection_confidence": confidence,
                "processing_time": round(processing_time, 2)
            }
        except Exception as e:
            logger.error(f"Error processing image: {str(e)}")
            return {
                "visual_attributes": {},
                "error": str(e),
                "processing_time": round(time.time() - start_time, 2)
            }
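

# --- Illustrative usage sketch (not part of the original service; URL is a placeholder) ---
if __name__ == "__main__":
    import json

    logging.basicConfig(level=logging.INFO)

    service = VisualProcessingService()
    result = service.process_image("https://example.com/sample-product.jpg")  # hypothetical image URL
    # default=str handles non-JSON-native values such as numpy integers in the RGB tuples
    print(json.dumps(result, indent=2, default=str))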