# ==================== visual_processing_service.py (WITH CACHE CONTROL) ====================
import logging
import os
import time
import warnings
from io import BytesIO
from typing import Dict, List, Optional, Tuple

import numpy as np
import requests
import torch
from PIL import Image
from sklearn.cluster import KMeans
from transformers import CLIPModel, CLIPProcessor

# ⚡ IMPORT CACHE CONFIGURATION
from .cache_config import ENABLE_CLIP_MODEL_CACHE

logger = logging.getLogger(__name__)

# Disable HuggingFace tokenizers' internal parallelism (presumably to avoid its
# fork-related warnings in multi-process servers — confirm if still needed).
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# NOTE(review): this silences *every* warning process-wide, not only those raised
# by this module; consider scoping it (e.g. warnings.catch_warnings) around CLIP/KMeans.
warnings.filterwarnings('ignore')


class VisualProcessingService:
    """Extract visual attributes from product images using CLIP zero-shot classification.

    Detection is hierarchical: main category -> (optional) subcategory -> product type.
    Only the attributes configured for the detected (sub)category are then classified.
    """

    # Class-level caches shared by every instance; populated lazily on first use.
    _clip_model = None
    _clip_processor = None
    _device = None

    # Pretrained CLIP checkpoint loaded from the HuggingFace hub.
    _CLIP_MODEL_NAME = "openai/clip-vit-base-patch32"

    # Category taxonomy. A category either has "products" + "attributes" directly, or
    # "subcategories", each with its own "products" + "attributes". Product names become
    # CLIP prompts ("a photo of <product>"); attribute value lists are the candidate
    # labels CLIP chooses between for that attribute.
    CATEGORY_ATTRIBUTES = {
        "clothing": {
            "subcategories": {
                "tops": {
                    "products": ["t-shirt", "shirt", "blouse", "top", "sweater", "hoodie", "tank top", "polo shirt"],
                    "attributes": {
                        "pattern": ["solid color", "striped", "checkered", "graphic print", "floral", "geometric", "plain", "logo print"],
                        "material": ["cotton", "polyester", "silk", "wool", "linen", "blend", "knit"],
                        "style": ["casual", "formal", "sporty", "streetwear", "elegant", "vintage", "minimalist"],
                        "fit": ["slim fit", "regular fit", "loose fit", "oversized", "fitted"],
                        "neckline": ["crew neck", "v-neck", "round neck", "collar", "scoop neck", "henley"],
                        "sleeve_type": ["short sleeve", "long sleeve", "sleeveless", "3/4 sleeve", "cap sleeve"],
                        "closure_type": ["button-up", "zipper", "pull-on", "snap button"],
                    },
                },
                "bottoms": {
                    "products": ["jeans", "pants", "trousers", "shorts", "chinos", "cargo pants", "leggings"],
                    "attributes": {
                        "pattern": ["solid color", "distressed", "faded", "plain", "washed", "dark wash", "light wash"],
                        "material": ["denim", "cotton", "polyester", "wool", "blend", "twill", "corduroy"],
                        "style": ["casual", "formal", "sporty", "vintage", "modern", "workwear"],
                        "fit": ["slim fit", "regular fit", "loose fit", "skinny", "bootcut", "straight leg", "relaxed fit"],
                        "rise": ["high rise", "mid rise", "low rise"],
                        "closure_type": ["button fly", "zipper fly", "elastic waist", "drawstring"],
                        "length": ["full length", "cropped", "ankle length", "capri"],
                    },
                },
                "dresses_skirts": {
                    "products": ["dress", "skirt", "gown", "sundress", "maxi dress", "mini skirt"],
                    "attributes": {
                        "pattern": ["solid color", "floral", "striped", "geometric", "plain", "printed", "polka dot"],
                        "material": ["cotton", "silk", "polyester", "linen", "blend", "chiffon", "satin"],
                        "style": ["casual", "formal", "cocktail", "bohemian", "vintage", "elegant", "party"],
                        "fit": ["fitted", "loose", "a-line", "bodycon", "flowy", "wrap"],
                        "neckline": ["crew neck", "v-neck", "scoop neck", "halter", "off-shoulder", "sweetheart"],
                        "sleeve_type": ["short sleeve", "long sleeve", "sleeveless", "3/4 sleeve", "flutter sleeve"],
                        "length": ["mini", "midi", "maxi", "knee-length", "floor-length"],
                    },
                },
                "outerwear": {
                    "products": ["jacket", "coat", "blazer", "windbreaker", "parka", "bomber jacket", "denim jacket"],
                    "attributes": {
                        "pattern": ["solid color", "plain", "quilted", "textured"],
                        "material": ["leather", "denim", "wool", "polyester", "cotton", "nylon", "fleece"],
                        "style": ["casual", "formal", "sporty", "vintage", "military", "biker"],
                        "fit": ["slim fit", "regular fit", "oversized", "cropped"],
                        "closure_type": ["zipper", "button", "snap button", "toggle"],
                        "length": ["cropped", "hip length", "thigh length", "knee length"],
                    },
                },
            },
        },
        "footwear": {
            "products": ["sneakers", "boots", "sandals", "heels", "loafers", "flats", "slippers"],
            "attributes": {
                "material": ["leather", "canvas", "suede", "synthetic", "rubber", "mesh"],
                "style": ["casual", "formal", "athletic", "vintage", "modern"],
                "closure_type": ["lace-up", "slip-on", "velcro", "buckle", "zipper"],
                "toe_style": ["round toe", "pointed toe", "square toe", "open toe", "closed toe"],
            },
        },
        "tools": {
            "products": ["screwdriver", "hammer", "wrench", "pliers", "drill", "saw", "measuring tape"],
            "attributes": {
                "material": ["steel", "aluminum", "plastic", "rubber", "chrome", "iron"],
                "type": ["manual", "electric", "pneumatic", "cordless", "corded"],
                "finish": ["chrome plated", "powder coated", "stainless steel", "painted"],
                "handle_type": ["rubber grip", "plastic", "wooden", "ergonomic", "cushioned"],
            },
        },
        "electronics": {
            "products": ["phone", "laptop", "tablet", "headphones", "speaker", "camera", "smartwatch", "earbuds"],
            "attributes": {
                "material": ["plastic", "metal", "glass", "aluminum", "rubber", "silicone"],
                "style": ["modern", "minimalist", "sleek", "industrial", "vintage"],
                "finish": ["matte", "glossy", "metallic", "textured", "transparent"],
                "connectivity": ["wireless", "wired", "bluetooth", "USB-C", "USB"],
            },
        },
        "furniture": {
            "products": ["chair", "table", "sofa", "bed", "desk", "shelf", "cabinet", "bench"],
            "attributes": {
                "material": ["wood", "metal", "glass", "plastic", "fabric", "leather", "rattan"],
                "style": ["modern", "traditional", "industrial", "rustic", "contemporary", "vintage", "scandinavian"],
                "finish": ["natural wood", "painted", "stained", "laminated", "upholstered", "polished"],
            },
        },
    }

    def __init__(self):
        # All heavy state (model, processor, device) lives at class level.
        pass

    @classmethod
    def _get_device(cls):
        """Return the cached torch device: CUDA when available, otherwise CPU."""
        if cls._device is None:
            cls._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info(f"Visual Processing using device: {cls._device}")
        return cls._device

    @classmethod
    def _get_clip_model(cls):
        """Return the ``(model, processor)`` pair, loading and caching it on first use.

        🔥 The CLIP model is ALWAYS cached at class level (ENABLE_CLIP_MODEL_CACHE is not
        consulted here) because it is large and slow to load.
        """
        if cls._clip_model is None:
            start = time.time()
            logger.info("📥 Loading CLIP model from HuggingFace...")
            model = CLIPModel.from_pretrained(cls._CLIP_MODEL_NAME)
            processor = CLIPProcessor.from_pretrained(cls._CLIP_MODEL_NAME)
            model.to(cls._get_device())
            model.eval()
            # Publish to the class only once fully initialised, so a failure above
            # cannot leave a model without a processor (or on the wrong device) cached.
            cls._clip_processor = processor
            cls._clip_model = model
            load_time = time.time() - start
            logger.info(f"✓ CLIP model loaded in {load_time:.1f}s and cached in memory")
        else:
            logger.debug("✓ Using cached CLIP model")
        return cls._clip_model, cls._clip_processor

    @classmethod
    def clear_clip_cache(cls):
        """Drop the cached CLIP model/processor and release cached CUDA memory."""
        if cls._clip_model is None:
            return
        cls._clip_model = None
        cls._clip_processor = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("✓ CLIP model cache cleared")

    def download_image(self, image_url: str) -> Optional[Image.Image]:
        """Download ``image_url`` and return it as an RGB PIL image, or ``None`` on any failure."""
        try:
            response = requests.get(image_url, timeout=10)
            response.raise_for_status()
            # convert() returns a new, fully loaded image, so the source can be closed.
            with Image.open(BytesIO(response.content)) as source:
                return source.convert('RGB')
        except Exception as e:
            logger.error(f"Error downloading image from {image_url}: {str(e)}")
            return None

    def extract_dominant_colors(self, image: Image.Image, n_colors: int = 3) -> List[Dict]:
        """Find the ``n_colors`` dominant colors of ``image`` via K-means clustering.

        Returns:
            A list of ``{"name", "rgb", "percentage"}`` dicts sorted by descending
            percentage of pixels, where ``rgb`` is a tuple of Python ints. Returns an
            empty list if clustering fails.
        """
        try:
            # Downsample for speed; force RGB so the reshape below always sees 3 channels.
            small = image.convert('RGB').resize((150, 150))
            pixels = np.asarray(small).reshape(-1, 3)

            kmeans = KMeans(n_clusters=n_colors, random_state=42, n_init=5)
            kmeans.fit(pixels)

            # minlength guards against a trailing empty cluster shortening the counts.
            counts = np.bincount(kmeans.labels_, minlength=n_colors)
            total = len(kmeans.labels_)

            colors = []
            for center, count in zip(kmeans.cluster_centers_, counts):
                # Plain ints (truncated like astype(int)) keep the result JSON-serialisable.
                rgb = tuple(int(channel) for channel in center)
                colors.append({
                    "name": self._get_color_name_simple(rgb),
                    "rgb": rgb,
                    "percentage": round(float(count / total * 100), 2),
                })

            colors.sort(key=lambda c: c['percentage'], reverse=True)
            return colors
        except Exception as e:
            logger.error(f"Error extracting colors: {str(e)}")
            return []

    def _get_color_name_simple(self, rgb: Tuple[int, int, int]) -> str:
        """Map an RGB triple to a coarse color name using fixed threshold rules.

        Rules are tried in insertion order and the first match wins. If none match,
        the strictly dominant channel decides ('red'/'green'/'blue'), else 'gray'.
        """
        # Coerce to Python ints so unsigned numpy channels cannot wrap around in abs(r - g).
        r, g, b = (int(channel) for channel in rgb)
        rules = {
            'black': (r < 50 and g < 50 and b < 50),
            'white': (r > 200 and g > 200 and b > 200),
            'gray': (abs(r - g) < 30 and abs(g - b) < 30 and abs(r - b) < 30 and 50 <= r <= 200),
            'red': (r > 150 and g < 100 and b < 100),
            'green': (g > 150 and r < 100 and b < 100),
            'blue': (b > 150 and r < 100 and g < 100),
            'yellow': (r > 200 and g > 200 and b < 100),
            'orange': (r > 200 and 100 < g < 200 and b < 100),
            'purple': (r > 100 and b > 100 and g < 100),
            # NOTE(review): 'pink' overlaps 'beige' (e.g. (230, 180, 150) matches both) and is
            # checked first, so many beige tones are reported as pink.
            'pink': (r > 200 and 100 < g < 200 and 100 < b < 200),
            'brown': (50 < r < 150 and 30 < g < 100 and b < 80),
            'cyan': (r < 100 and g > 150 and b > 150),
            'beige': (180 < r < 240 and 160 < g < 220 and 120 < b < 180),
        }
        for color_name, matched in rules.items():
            if matched:
                return color_name

        if r > g and r > b:
            return 'red'
        if g > r and g > b:
            return 'green'
        if b > r and b > g:
            return 'blue'
        return 'gray'

    def classify_with_clip(
        self,
        image: Image.Image,
        candidates: List[str],
        attribute_name: str,
        confidence_threshold: float = 0.15
    ) -> Dict:
        """Zero-shot classify ``image`` against ``candidates`` text labels with CLIP.

        Text candidates are encoded in chunks to bound memory, but the softmax is taken
        over the raw logits of ALL candidates together, so confidences are comparable
        regardless of how many chunks were needed.

        Returns:
            ``{"attribute": attribute_name, "predictions": [...]}`` where predictions holds
            at most 3 ``{"value", "confidence"}`` dicts with probability above
            ``confidence_threshold``, sorted by descending confidence. On error the
            prediction list is empty.
        """
        if not candidates:
            return {"attribute": attribute_name, "predictions": []}

        try:
            model, processor = self._get_clip_model()
            device = self._get_device()
            batch_size = 16

            logit_chunks = []
            with torch.no_grad():
                for start in range(0, len(candidates), batch_size):
                    batch = candidates[start:start + batch_size]
                    inputs = processor(
                        text=batch,
                        images=image,
                        return_tensors="pt",
                        padding=True
                    )
                    inputs = {k: v.to(device) for k, v in inputs.items()}
                    outputs = model(**inputs)
                    # One image, so row 0 holds this chunk's per-candidate logits.
                    logit_chunks.append(outputs.logits_per_image[0])

            # Single softmax across every candidate (per-chunk softmax is not comparable).
            probs = torch.cat(logit_chunks).softmax(dim=0).cpu().tolist()

            predictions = [
                {"value": label, "confidence": round(float(prob), 3)}
                for label, prob in zip(candidates, probs)
                if prob > confidence_threshold
            ]
            predictions.sort(key=lambda p: p['confidence'], reverse=True)
            return {"attribute": attribute_name, "predictions": predictions[:3]}
        except Exception as e:
            logger.exception(f"Error in CLIP classification for {attribute_name}: {str(e)}")
            return {"attribute": attribute_name, "predictions": []}

    def detect_category_and_subcategory(self, image: Image.Image) -> Tuple[str, str, str, float]:
        """Hierarchically detect category, subcategory and specific product.

        Returns:
            ``(category, subcategory, product_type, confidence)``. ``subcategory`` is
            ``"none"`` for categories without subcategories; ``"unknown"`` marks a level
            that could not be determined. ``("unknown", "unknown", "unknown", 0.0)`` is
            returned when no main category clears its threshold.
        """
        category_by_prompt = {f"a photo of {cat}": cat for cat in self.CATEGORY_ATTRIBUTES}
        result = self.classify_with_clip(
            image, list(category_by_prompt), "main_category", confidence_threshold=0.10
        )
        if not result["predictions"]:
            return "unknown", "unknown", "unknown", 0.0

        top = result["predictions"][0]
        detected_category = category_by_prompt[top["value"]]
        category_confidence = top["confidence"]
        logger.info(f"Step 1 - Main category detected: {detected_category} (confidence: {category_confidence:.3f})")

        category_data = self.CATEGORY_ATTRIBUTES[detected_category]

        if "subcategories" in category_data:
            # Classify against every product of every subcategory at once, then map back.
            product_prompts = []
            prompt_info = {}  # prompt -> (subcategory, product)
            for subcat, subcat_data in category_data["subcategories"].items():
                for product in subcat_data["products"]:
                    prompt = f"a photo of {product}"
                    product_prompts.append(prompt)
                    prompt_info[prompt] = (subcat, product)

            product_result = self.classify_with_clip(
                image, product_prompts, "product_type", confidence_threshold=0.12
            )
            if product_result["predictions"]:
                best = product_result["predictions"][0]
                subcategory, product_type = prompt_info[best["value"]]
                product_confidence = best["confidence"]
                logger.info(f"Step 2 - Detected: {subcategory} > {product_type} (confidence: {product_confidence:.3f})")
                return detected_category, subcategory, product_type, product_confidence

            logger.warning(f"Could not detect specific product type for {detected_category}")
            return detected_category, "unknown", "unknown", category_confidence

        if "products" in category_data:
            product_by_prompt = {f"a photo of {p}": p for p in category_data["products"]}
            product_result = self.classify_with_clip(
                image, list(product_by_prompt), "product_type", confidence_threshold=0.12
            )
            if product_result["predictions"]:
                best = product_result["predictions"][0]
                product_type = product_by_prompt[best["value"]]
                logger.info(f"Step 2 - Detected: {detected_category} > {product_type}")
                return detected_category, "none", product_type, best["confidence"]

        return detected_category, "unknown", "unknown", category_confidence

    def _resolve_attributes_config(self, category: str, subcategory: str) -> Optional[Dict[str, List[str]]]:
        """Return the attribute -> candidate-values mapping for a detection, or None if not configured."""
        category_data = self.CATEGORY_ATTRIBUTES.get(category)
        if category_data is None:
            return None

        if "subcategories" in category_data:
            subcategories = category_data["subcategories"]
            if subcategory in subcategories:
                logger.info(f"Using attributes for subcategory: {subcategory}")
                return subcategories[subcategory]["attributes"]
            logger.warning(f"Unknown subcategory: {subcategory}. Skipping attribute extraction.")
            return None

        if "attributes" in category_data:
            logger.info(f"Using attributes for category: {category}")
            return category_data["attributes"]
        return None

    def process_image(
        self,
        image_url: str,
        product_type_hint: Optional[str] = None
    ) -> Dict:
        """Download an image and extract the visual attributes relevant to its detected category.

        Args:
            image_url: URL of the product image.
            product_type_hint: Currently unused; accepted for interface compatibility.

        Returns:
            A dict with ``visual_attributes`` and, depending on the path taken,
            ``detailed_predictions``, ``detection_confidence``, ``processing_time``,
            ``cache_status``, ``warning`` (low-confidence detection) or ``error``.
        """
        start_time = time.time()
        try:
            image = self.download_image(image_url)
            if image is None:
                return {
                    "visual_attributes": {},
                    "error": "Failed to download image"
                }

            visual_attributes = {}
            detailed_predictions = {}

            category, subcategory, product_type, confidence = self.detect_category_and_subcategory(image)

            if confidence < 0.10:
                # Detection too uncertain: report colors only.
                logger.warning(f"Low confidence in detection ({confidence:.3f}). Returning basic attributes only.")
                colors = self.extract_dominant_colors(image, n_colors=3)
                if colors:
                    visual_attributes["primary_color"] = colors[0]["name"]
                    visual_attributes["color_palette"] = [c["name"] for c in colors]
                return {
                    "visual_attributes": visual_attributes,
                    "detection_confidence": confidence,
                    "warning": "Low confidence detection",
                    "processing_time": round(time.time() - start_time, 2)
                }

            visual_attributes["product_type"] = product_type
            visual_attributes["category"] = category
            if subcategory not in ("none", "unknown"):
                visual_attributes["subcategory"] = subcategory

            colors = self.extract_dominant_colors(image, n_colors=3)
            if colors:
                visual_attributes["primary_color"] = colors[0]["name"]
                visual_attributes["color_palette"] = [c["name"] for c in colors[:3]]
                visual_attributes["color_distribution"] = [
                    {"color": c["name"], "percentage": c["percentage"]}
                    for c in colors
                ]

            attributes_config = self._resolve_attributes_config(category, subcategory) or {}
            for attr_name, attr_values in attributes_config.items():
                result = self.classify_with_clip(
                    image, attr_values, attr_name, confidence_threshold=0.20
                )
                if result["predictions"]:
                    best_prediction = result["predictions"][0]
                    # Re-checked on the *rounded* confidence, so values that round to 0.200 are dropped.
                    if best_prediction["confidence"] > 0.20:
                        visual_attributes[attr_name] = best_prediction["value"]
                    detailed_predictions[attr_name] = result

            processing_time = time.time() - start_time
            logger.info(f"✓ Processing complete in {processing_time:.2f}s. Extracted {len(visual_attributes)} attributes.")

            return {
                "visual_attributes": visual_attributes,
                "detailed_predictions": detailed_predictions,
                "detection_confidence": confidence,
                "processing_time": round(processing_time, 2),
                # Reports the configured flag; _get_clip_model caches the model regardless.
                "cache_status": "enabled" if ENABLE_CLIP_MODEL_CACHE else "disabled"
            }
        except Exception as e:
            logger.exception(f"Error processing image: {str(e)}")
            return {
                "visual_attributes": {},
                "error": str(e),
                "processing_time": round(time.time() - start_time, 2)
            }