# ==================== visual_processing_service.py (FIXED - Smart Subcategory Detection) ====================

import os

# Must be set before transformers is imported to take effect.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'  # Disable tokenizer parallelism warnings

import warnings
warnings.filterwarnings('ignore')  # Suppress all warnings

import logging
import time
from io import BytesIO
from typing import Dict, List, Optional, Tuple

import numpy as np
import requests
import torch
from PIL import Image
from sklearn.cluster import KMeans
from transformers import CLIPProcessor, CLIPModel

logger = logging.getLogger(__name__)


class VisualProcessingService:
    """Service for extracting visual attributes from product images using CLIP,
    with smart subcategory detection."""

    # Class-level caching (shared across instances)
    _clip_model = None
    _clip_processor = None
    _device = None

    # Hierarchical category structure with subcategories
    CATEGORY_ATTRIBUTES = {
        "clothing": {
            "subcategories": {
                "tops": {
                    "products": ["t-shirt", "shirt", "blouse", "top", "sweater", "hoodie", "tank top", "polo shirt"],
                    "attributes": {
                        "pattern": ["solid color", "striped", "checkered", "graphic print", "floral", "geometric", "plain", "logo print"],
                        "material": ["cotton", "polyester", "silk", "wool", "linen", "blend", "knit"],
                        "style": ["casual", "formal", "sporty", "streetwear", "elegant", "vintage", "minimalist"],
                        "fit": ["slim fit", "regular fit", "loose fit", "oversized", "fitted"],
                        "neckline": ["crew neck", "v-neck", "round neck", "collar", "scoop neck", "henley"],
                        "sleeve_type": ["short sleeve", "long sleeve", "sleeveless", "3/4 sleeve", "cap sleeve"],
                        "closure_type": ["button-up", "zipper", "pull-on", "snap button"]
                    }
                },
                "bottoms": {
                    "products": ["jeans", "pants", "trousers", "shorts", "chinos", "cargo pants", "leggings"],
                    "attributes": {
                        "pattern": ["solid color", "distressed", "faded", "plain", "washed", "dark wash", "light wash"],
                        "material": ["denim", "cotton", "polyester", "wool", "blend", "twill", "corduroy"],
                        "style": ["casual", "formal", "sporty", "vintage", "modern", "workwear"],
                        "fit": ["slim fit", "regular fit", "loose fit", "skinny", "bootcut", "straight leg", "relaxed fit"],
                        "rise": ["high rise", "mid rise", "low rise"],
                        "closure_type": ["button fly", "zipper fly", "elastic waist", "drawstring"],
                        "length": ["full length", "cropped", "ankle length", "capri"]
                    }
                },
                "dresses_skirts": {
                    "products": ["dress", "skirt", "gown", "sundress", "maxi dress", "mini skirt"],
                    "attributes": {
                        "pattern": ["solid color", "floral", "striped", "geometric", "plain", "printed", "polka dot"],
                        "material": ["cotton", "silk", "polyester", "linen", "blend", "chiffon", "satin"],
                        "style": ["casual", "formal", "cocktail", "bohemian", "vintage", "elegant", "party"],
                        "fit": ["fitted", "loose", "a-line", "bodycon", "flowy", "wrap"],
                        "neckline": ["crew neck", "v-neck", "scoop neck", "halter", "off-shoulder", "sweetheart"],
                        "sleeve_type": ["short sleeve", "long sleeve", "sleeveless", "3/4 sleeve", "flutter sleeve"],
                        "length": ["mini", "midi", "maxi", "knee-length", "floor-length"]
                    }
                },
                "outerwear": {
                    "products": ["jacket", "coat", "blazer", "windbreaker", "parka", "bomber jacket", "denim jacket"],
                    "attributes": {
                        "pattern": ["solid color", "plain", "quilted", "textured"],
                        "material": ["leather", "denim", "wool", "polyester", "cotton", "nylon", "fleece"],
                        "style": ["casual", "formal", "sporty", "vintage", "military", "biker"],
                        "fit": ["slim fit", "regular fit", "oversized", "cropped"],
                        "closure_type": ["zipper", "button", "snap button", "toggle"],
                        "length": ["cropped", "hip length", "thigh length", "knee length"]
                    }
                }
            }
        },
"footwear": { "products": ["sneakers", "boots", "sandals", "heels", "loafers", "flats", "slippers"], "attributes": { "material": ["leather", "canvas", "suede", "synthetic", "rubber", "mesh"], "style": ["casual", "formal", "athletic", "vintage", "modern"], "closure_type": ["lace-up", "slip-on", "velcro", "buckle", "zipper"], "toe_style": ["round toe", "pointed toe", "square toe", "open toe", "closed toe"] } }, "tools": { "products": ["screwdriver", "hammer", "wrench", "pliers", "drill", "saw", "measuring tape"], "attributes": { "material": ["steel", "aluminum", "plastic", "rubber", "chrome", "iron"], "type": ["manual", "electric", "pneumatic", "cordless", "corded"], "finish": ["chrome plated", "powder coated", "stainless steel", "painted"], "handle_type": ["rubber grip", "plastic", "wooden", "ergonomic", "cushioned"] } }, "electronics": { "products": ["phone", "laptop", "tablet", "headphones", "speaker", "camera", "smartwatch", "earbuds"], "attributes": { "material": ["plastic", "metal", "glass", "aluminum", "rubber", "silicone"], "style": ["modern", "minimalist", "sleek", "industrial", "vintage"], "finish": ["matte", "glossy", "metallic", "textured", "transparent"], "connectivity": ["wireless", "wired", "bluetooth", "USB-C", "USB"] } }, "furniture": { "products": ["chair", "table", "sofa", "bed", "desk", "shelf", "cabinet", "bench"], "attributes": { "material": ["wood", "metal", "glass", "plastic", "fabric", "leather", "rattan"], "style": ["modern", "traditional", "industrial", "rustic", "contemporary", "vintage", "scandinavian"], "finish": ["natural wood", "painted", "stained", "laminated", "upholstered", "polished"] } } } def __init__(self): pass @classmethod def _get_device(cls): """Get optimal device.""" if cls._device is None: cls._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") logger.info(f"Visual Processing using device: {cls._device}") return cls._device @classmethod def _get_clip_model(cls): """Lazy load CLIP model with class-level caching.""" if cls._clip_model is None: logger.info("Loading CLIP model (this may take a few minutes on first use)...") cls._clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") cls._clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") device = cls._get_device() cls._clip_model.to(device) cls._clip_model.eval() logger.info("✓ CLIP model loaded successfully") return cls._clip_model, cls._clip_processor def download_image(self, image_url: str) -> Optional[Image.Image]: """Download image from URL.""" try: response = requests.get(image_url, timeout=10) response.raise_for_status() image = Image.open(BytesIO(response.content)).convert('RGB') return image except Exception as e: logger.error(f"Error downloading image from {image_url}: {str(e)}") return None def extract_dominant_colors(self, image: Image.Image, n_colors: int = 3) -> List[Dict]: """Extract dominant colors using K-means clustering.""" try: # Resize for faster processing img_small = image.resize((150, 150)) img_array = np.array(img_small) pixels = img_array.reshape(-1, 3) # K-means clustering kmeans = KMeans(n_clusters=n_colors, random_state=42, n_init=5) kmeans.fit(pixels) colors = [] labels_counts = np.bincount(kmeans.labels_) for i, center in enumerate(kmeans.cluster_centers_): rgb = tuple(center.astype(int)) color_name = self._get_color_name_simple(rgb) percentage = float(labels_counts[i] / len(kmeans.labels_) * 100) colors.append({ "name": color_name, "rgb": rgb, "percentage": round(percentage, 2) }) # Sort by percentage (most 
    def _get_color_name_simple(self, rgb: Tuple[int, int, int]) -> str:
        """Map RGB values to basic color names."""
        r, g, b = rgb

        # Ordered rules: the first matching rule wins, so the more specific
        # checks (black/white/gray) come before the broad hue checks.
        colors = {
            'black': (r < 50 and g < 50 and b < 50),
            'white': (r > 200 and g > 200 and b > 200),
            'gray': (abs(r - g) < 30 and abs(g - b) < 30 and abs(r - b) < 30 and 50 <= r <= 200),
            'red': (r > 150 and g < 100 and b < 100),
            'green': (g > 150 and r < 100 and b < 100),
            'blue': (b > 150 and r < 100 and g < 100),
            'yellow': (r > 200 and g > 200 and b < 100),
            'orange': (r > 200 and 100 < g < 200 and b < 100),
            'purple': (r > 100 and b > 100 and g < 100),
            'pink': (r > 200 and 100 < g < 200 and 100 < b < 200),
            'brown': (50 < r < 150 and 30 < g < 100 and b < 80),
            'cyan': (r < 100 and g > 150 and b > 150),
            'beige': (180 < r < 240 and 160 < g < 220 and 120 < b < 180),
        }
        for color_name, condition in colors.items():
            if condition:
                return color_name

        # Fallback to the dominant channel
        if r > g and r > b:
            return 'red'
        elif g > r and g > b:
            return 'green'
        elif b > r and b > g:
            return 'blue'
        return 'gray'

    def classify_with_clip(
        self,
        image: Image.Image,
        candidates: List[str],
        attribute_name: str,
        confidence_threshold: float = 0.15
    ) -> Dict:
        """Use CLIP to classify an image against candidate text labels."""
        try:
            model, processor = self._get_clip_model()
            device = self._get_device()

            # OPTIMIZATION: process candidates in smaller batches to avoid
            # memory issues, but softmax once over ALL logits at the end.
            # A per-batch softmax would renormalize each batch of 16
            # independently, making confidences incomparable across batches.
            batch_size = 16
            all_logits = []
            for i in range(0, len(candidates), batch_size):
                batch_candidates = candidates[i:i + batch_size]

                inputs = processor(
                    text=batch_candidates,
                    images=image,
                    return_tensors="pt",
                    padding=True
                )
                # Move to device
                inputs = {k: v.to(device) for k, v in inputs.items()}

                with torch.no_grad():
                    outputs = model(**inputs)
                all_logits.append(outputs.logits_per_image.cpu())

            # Shape (1, len(candidates)) after concatenation
            probs = torch.cat(all_logits, dim=1).softmax(dim=1)[0]

            results = [
                {"value": candidates[j], "confidence": round(float(p), 3)}
                for j, p in enumerate(probs)
                if p.item() > confidence_threshold
            ]

            # Sort by confidence and return the top 3
            results.sort(key=lambda x: x['confidence'], reverse=True)
            return {"attribute": attribute_name, "predictions": results[:3]}
        except Exception as e:
            logger.error(f"Error in CLIP classification for {attribute_name}: {str(e)}")
            return {"attribute": attribute_name, "predictions": []}
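    # classify_with_clip returns a dict shaped like the following -- labels
    # and scores are illustrative, not from a real run:
    #   {
    #     "attribute": "pattern",
    #     "predictions": [
    #       {"value": "striped", "confidence": 0.512},
    #       {"value": "solid color", "confidence": 0.231},
    #     ]
    #   }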
    def detect_category_and_subcategory(self, image: Image.Image) -> Tuple[str, str, str, float]:
        """
        Hierarchically detect category, subcategory, and specific product.

        Returns:
            (category, subcategory, product_type, confidence)
        """
        # Step 1: Detect the main category (clothing, footwear, tools, ...)
        main_categories = list(self.CATEGORY_ATTRIBUTES.keys())
        category_prompts = [f"a photo of {cat}" for cat in main_categories]
        result = self.classify_with_clip(image, category_prompts, "main_category", confidence_threshold=0.10)

        if not result["predictions"]:
            return "unknown", "unknown", "unknown", 0.0

        detected_category = result["predictions"][0]["value"].replace("a photo of ", "")
        category_confidence = result["predictions"][0]["confidence"]
        logger.info(f"Step 1 - Main category detected: {detected_category} (confidence: {category_confidence:.3f})")

        # Step 2 (clothing): detect the specific product, which also determines
        # the subcategory (tops/bottoms/dresses_skirts/outerwear)
        if detected_category == "clothing":
            subcategories = self.CATEGORY_ATTRIBUTES["clothing"]["subcategories"]

            # Collect all products, remembering which subcategory each belongs to
            all_products = []
            product_to_subcategory = {}
            for subcat, subcat_data in subcategories.items():
                for product in subcat_data["products"]:
                    prompt = f"a photo of {product}"
                    all_products.append(prompt)
                    product_to_subcategory[prompt] = subcat

            product_result = self.classify_with_clip(
                image, all_products, "product_type", confidence_threshold=0.12
            )

            if product_result["predictions"]:
                best_match = product_result["predictions"][0]
                product_prompt = best_match["value"]
                product_type = product_prompt.replace("a photo of ", "")
                subcategory = product_to_subcategory[product_prompt]
                product_confidence = best_match["confidence"]
                logger.info(f"Step 2 - Detected: {subcategory} > {product_type} (confidence: {product_confidence:.3f})")
                return detected_category, subcategory, product_type, product_confidence

            logger.warning("Could not detect specific product type for clothing")
            return detected_category, "unknown", "unknown", category_confidence

        # Step 2 (non-clothing): detect the product type from the flat list
        category_data = self.CATEGORY_ATTRIBUTES[detected_category]
        if "products" in category_data:
            product_prompts = [f"a photo of {p}" for p in category_data["products"]]
            product_result = self.classify_with_clip(
                image, product_prompts, "product_type", confidence_threshold=0.12
            )

            if product_result["predictions"]:
                best_match = product_result["predictions"][0]
                product_type = best_match["value"].replace("a photo of ", "")
                logger.info(f"Step 2 - Detected: {detected_category} > {product_type}")
                return detected_category, "none", product_type, best_match["confidence"]

        return detected_category, "unknown", "unknown", category_confidence
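    # Example return values (illustrative only): ("clothing", "tops", "t-shirt", 0.312)
    # for a clothing hit, or ("furniture", "none", "chair", 0.428) for a flat category.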
    def process_image(
        self,
        image_url: str,
        product_type_hint: Optional[str] = None
    ) -> Dict:
        """
        Main entry point: process an image and extract visual attributes.
        Uses hierarchical detection to extract only the relevant attributes.

        Note: product_type_hint is accepted for API compatibility but is
        not currently used.
        """
        start_time = time.time()
        try:
            # Download image
            image = self.download_image(image_url)
            if image is None:
                return {
                    "visual_attributes": {},
                    "error": "Failed to download image"
                }

            visual_attributes = {}
            detailed_predictions = {}

            # Step 1: Detect category, subcategory, and product type
            category, subcategory, product_type, confidence = self.detect_category_and_subcategory(image)

            # Low-confidence detections fall back to color-only attributes
            if confidence < 0.10:
                logger.warning(f"Low confidence in detection ({confidence:.3f}). Returning basic attributes only.")
                colors = self.extract_dominant_colors(image, n_colors=3)
                if colors:
                    visual_attributes["primary_color"] = colors[0]["name"]
                    visual_attributes["color_palette"] = [c["name"] for c in colors]
                return {
                    "visual_attributes": visual_attributes,
                    "detection_confidence": confidence,
                    "warning": "Low confidence detection",
                    "processing_time": round(time.time() - start_time, 2)
                }

            # Add detected metadata
            visual_attributes["product_type"] = product_type
            visual_attributes["category"] = category
            if subcategory not in ("none", "unknown"):
                visual_attributes["subcategory"] = subcategory

            # Step 2: Extract color information (universal)
            colors = self.extract_dominant_colors(image, n_colors=3)
            if colors:
                visual_attributes["primary_color"] = colors[0]["name"]
                visual_attributes["color_palette"] = [c["name"] for c in colors[:3]]
                visual_attributes["color_distribution"] = [
                    {"color": c["name"], "percentage": c["percentage"]} for c in colors
                ]

            # Step 3: Pick the attribute configuration for the detected subcategory
            attributes_config = None
            if category == "clothing":
                if subcategory in self.CATEGORY_ATTRIBUTES["clothing"]["subcategories"]:
                    attributes_config = self.CATEGORY_ATTRIBUTES["clothing"]["subcategories"][subcategory]["attributes"]
                    logger.info(f"Using attributes for subcategory: {subcategory}")
                else:
                    logger.warning(f"Unknown subcategory: {subcategory}. Skipping attribute extraction.")
            elif category in self.CATEGORY_ATTRIBUTES:
                if "attributes" in self.CATEGORY_ATTRIBUTES[category]:
                    attributes_config = self.CATEGORY_ATTRIBUTES[category]["attributes"]
                    logger.info(f"Using attributes for category: {category}")

            # Step 4: Extract category-specific attributes
            if attributes_config:
                for attr_name, attr_values in attributes_config.items():
                    result = self.classify_with_clip(
                        image, attr_values, attr_name, confidence_threshold=0.20
                    )
                    if result["predictions"]:
                        best_prediction = result["predictions"][0]
                        # Only keep attributes with reasonable confidence
                        if best_prediction["confidence"] > 0.20:
                            visual_attributes[attr_name] = best_prediction["value"]
                    # Store detailed predictions for debugging
                    detailed_predictions[attr_name] = result

            processing_time = time.time() - start_time
            logger.info(f"✓ Processing complete in {processing_time:.2f}s. Extracted {len(visual_attributes)} attributes.")

            return {
                "visual_attributes": visual_attributes,
                "detailed_predictions": detailed_predictions,
                "detection_confidence": confidence,
                "processing_time": round(processing_time, 2)
            }
        except Exception as e:
            logger.error(f"Error processing image: {str(e)}")
            return {
                "visual_attributes": {},
                "error": str(e),
                "processing_time": round(time.time() - start_time, 2)
            }
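

# Minimal usage sketch. Assumptions: the URL below is a hypothetical
# placeholder, and the CLIP weights are downloaded on first run; this block
# is illustrative and not part of the service itself.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    service = VisualProcessingService()
    result = service.process_image("https://example.com/product.jpg")  # hypothetical URL

    if "error" in result:
        print(f"Failed: {result['error']}")
    else:
        print(f"Detected attributes ({result['processing_time']}s, "
              f"confidence {result['detection_confidence']}):")
        for name, value in result["visual_attributes"].items():
            print(f"  {name}: {value}")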