# # ==================== visual_processing_service.py ==================== # import torch # import cv2 # import numpy as np # import requests # from io import BytesIO # from PIL import Image # from typing import Dict, List, Optional, Tuple # import logging # from transformers import CLIPProcessor, CLIPModel, AutoImageProcessor, AutoModelForImageClassification # from sklearn.cluster import KMeans # import webcolors # logger = logging.getLogger(__name__) # class VisualProcessingService: # """Service for extracting visual attributes from product images using CLIP and computer vision.""" # def __init__(self): # self.clip_model = None # self.clip_processor = None # self.classification_model = None # self.classification_processor = None # def _get_clip_model(self): # """Lazy load CLIP model.""" # if self.clip_model is None: # logger.info("Loading CLIP model...") # self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") # self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") # self.clip_model.eval() # return self.clip_model, self.clip_processor # def _get_classification_model(self): # """Lazy load image classification model for product categories.""" # if self.classification_model is None: # logger.info("Loading classification model...") # # Using Google's ViT model fine-tuned on fashion/products # self.classification_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") # self.classification_model = AutoModelForImageClassification.from_pretrained( # "google/vit-base-patch16-224" # ) # self.classification_model.eval() # return self.classification_model, self.classification_processor # def download_image(self, image_url: str) -> Optional[Image.Image]: # """Download image from URL.""" # try: # response = requests.get(image_url, timeout=10) # response.raise_for_status() # image = Image.open(BytesIO(response.content)).convert('RGB') # return image # except Exception as e: # logger.error(f"Error downloading image from {image_url}: {str(e)}") # return None # def extract_dominant_colors(self, image: Image.Image, n_colors: int = 3) -> List[Dict]: # """Extract dominant colors from image using K-means clustering.""" # try: # # Resize image for faster processing # img_small = image.resize((150, 150)) # img_array = np.array(img_small) # # Reshape to pixels # pixels = img_array.reshape(-1, 3) # # Apply K-means # kmeans = KMeans(n_clusters=n_colors, random_state=42, n_init=10) # kmeans.fit(pixels) # colors = [] # for center in kmeans.cluster_centers_: # rgb = tuple(center.astype(int)) # color_name = self._get_color_name(rgb) # colors.append({ # "name": color_name, # "rgb": rgb, # "percentage": float(np.sum(kmeans.labels_ == len(colors)) / len(kmeans.labels_) * 100) # }) # # Sort by percentage # colors.sort(key=lambda x: x['percentage'], reverse=True) # return colors # except Exception as e: # logger.error(f"Error extracting colors: {str(e)}") # return [] # def _get_color_name(self, rgb: Tuple[int, int, int]) -> str: # """Convert RGB to closest color name.""" # try: # # Try to get exact match # color_name = webcolors.rgb_to_name(rgb) # return color_name # except ValueError: # # Find closest color # min_distance = float('inf') # closest_name = 'unknown' # for name in webcolors.CSS3_NAMES_TO_HEX: # hex_color = webcolors.CSS3_NAMES_TO_HEX[name] # r, g, b = webcolors.hex_to_rgb(hex_color) # distance = sum((c1 - c2) ** 2 for c1, c2 in zip(rgb, (r, g, b))) # if distance < min_distance: # min_distance = distance # closest_name = name # return closest_name # def classify_with_clip(self, image: Image.Image, candidates: List[str], attribute_name: str) -> Dict: # """Use CLIP to classify image against candidate labels.""" # try: # model, processor = self._get_clip_model() # # Prepare inputs # inputs = processor( # text=candidates, # images=image, # return_tensors="pt", # padding=True # ) # # Get predictions # with torch.no_grad(): # outputs = model(**inputs) # logits_per_image = outputs.logits_per_image # probs = logits_per_image.softmax(dim=1) # # Get top predictions # top_probs, top_indices = torch.topk(probs[0], k=min(3, len(candidates))) # results = [] # for prob, idx in zip(top_probs, top_indices): # if prob.item() > 0.15: # Confidence threshold # results.append({ # "value": candidates[idx.item()], # "confidence": float(prob.item()) # }) # return { # "attribute": attribute_name, # "predictions": results # } # except Exception as e: # logger.error(f"Error in CLIP classification: {str(e)}") # return {"attribute": attribute_name, "predictions": []} # def detect_patterns(self, image: Image.Image) -> Dict: # """Detect patterns in the image using edge detection and texture analysis.""" # try: # # Convert to OpenCV format # img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) # gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY) # # Calculate edge density # edges = cv2.Canny(gray, 50, 150) # edge_density = np.sum(edges > 0) / (edges.shape[0] * edges.shape[1]) # # Calculate texture variance # laplacian = cv2.Laplacian(gray, cv2.CV_64F) # texture_variance = laplacian.var() # # Determine pattern type based on metrics # pattern_candidates = [] # if edge_density > 0.15: # pattern_candidates.append("geometric") # if texture_variance > 500: # pattern_candidates.append("textured") # if edge_density < 0.05 and texture_variance < 200: # pattern_candidates.append("solid") # # Use CLIP for more detailed pattern detection # pattern_types = [ # "solid color", "striped", "checkered", "polka dot", "floral", # "geometric", "abstract", "graphic print", "camouflage", "paisley" # ] # clip_result = self.classify_with_clip(image, pattern_types, "pattern") # return clip_result # except Exception as e: # logger.error(f"Error detecting patterns: {str(e)}") # return {"attribute": "pattern", "predictions": []} # def detect_material(self, image: Image.Image) -> Dict: # """Detect material type using CLIP.""" # materials = [ # "cotton", "polyester", "denim", "leather", "silk", "wool", # "linen", "satin", "velvet", "fleece", "knit", "jersey", # "canvas", "nylon", "suede", "corduroy" # ] # return self.classify_with_clip(image, materials, "material") # def detect_style(self, image: Image.Image) -> Dict: # """Detect style/occasion using CLIP.""" # styles = [ # "casual", "formal", "sporty", "business", "vintage", "modern", # "bohemian", "streetwear", "elegant", "preppy", "athletic", # "loungewear", "party", "workwear", "outdoor" # ] # return self.classify_with_clip(image, styles, "style") # def detect_fit(self, image: Image.Image) -> Dict: # """Detect clothing fit using CLIP.""" # fits = [ # "slim fit", "regular fit", "loose fit", "oversized", # "tight", "relaxed", "tailored", "athletic fit" # ] # return self.classify_with_clip(image, fits, "fit") # def detect_neckline(self, image: Image.Image, product_type: str) -> Dict: # """Detect neckline type for tops using CLIP.""" # if product_type.lower() not in ['shirt', 't-shirt', 'top', 'blouse', 'dress', 'sweater']: # return {"attribute": "neckline", "predictions": []} # necklines = [ # "crew neck", "v-neck", "round neck", "collar", "turtleneck", # "scoop neck", "boat neck", "off-shoulder", "square neck", "halter" # ] # return self.classify_with_clip(image, necklines, "neckline") # def detect_sleeve_type(self, image: Image.Image, product_type: str) -> Dict: # """Detect sleeve type using CLIP.""" # if product_type.lower() not in ['shirt', 't-shirt', 'top', 'blouse', 'dress', 'sweater', 'jacket']: # return {"attribute": "sleeve_type", "predictions": []} # sleeves = [ # "short sleeve", "long sleeve", "sleeveless", "three-quarter sleeve", # "cap sleeve", "flutter sleeve", "bell sleeve", "raglan sleeve" # ] # return self.classify_with_clip(image, sleeves, "sleeve_type") # def detect_product_type(self, image: Image.Image) -> Dict: # """Detect product type using CLIP.""" # product_types = [ # "t-shirt", "shirt", "dress", "pants", "jeans", "shorts", # "skirt", "jacket", "coat", "sweater", "hoodie", "blazer", # "suit", "jumpsuit", "romper", "cardigan", "vest", "top", # "blouse", "tank top", "polo shirt", "sweatshirt" # ] # return self.classify_with_clip(image, product_types, "product_type") # def detect_closure_type(self, image: Image.Image) -> Dict: # """Detect closure type using CLIP.""" # closures = [ # "button", "zipper", "snap", "hook and eye", "velcro", # "lace-up", "pull-on", "elastic", "tie", "buckle" # ] # return self.classify_with_clip(image, closures, "closure_type") # def detect_length(self, image: Image.Image, product_type: str) -> Dict: # """Detect garment length using CLIP.""" # if product_type.lower() in ['pants', 'jeans', 'trousers']: # lengths = ["full length", "ankle length", "cropped", "capri", "shorts"] # elif product_type.lower() in ['skirt', 'dress']: # lengths = ["mini", "knee length", "midi", "maxi", "floor length"] # elif product_type.lower() in ['jacket', 'coat']: # lengths = ["waist length", "hip length", "thigh length", "knee length", "full length"] # else: # lengths = ["short", "regular", "long"] # return self.classify_with_clip(image, lengths, "length") # def process_image(self, image_url: str, product_type_hint: Optional[str] = None) -> Dict: # """ # Main method to process image and extract all visual attributes. # """ # try: # # Download image # image = self.download_image(image_url) # if image is None: # return { # "visual_attributes": {}, # "error": "Failed to download image" # } # # Extract all attributes # visual_attributes = {} # # 1. Product Type Detection # logger.info("Detecting product type...") # product_type_result = self.detect_product_type(image) # if product_type_result["predictions"]: # visual_attributes["product_type"] = product_type_result["predictions"][0]["value"] # detected_product_type = visual_attributes["product_type"] # else: # detected_product_type = product_type_hint or "unknown" # # 2. Color Detection # logger.info("Extracting colors...") # colors = self.extract_dominant_colors(image, n_colors=3) # if colors: # visual_attributes["primary_color"] = colors[0]["name"] # visual_attributes["color_palette"] = [c["name"] for c in colors] # visual_attributes["color_details"] = colors # # 3. Pattern Detection # logger.info("Detecting patterns...") # pattern_result = self.detect_patterns(image) # if pattern_result["predictions"]: # visual_attributes["pattern"] = pattern_result["predictions"][0]["value"] # # 4. Material Detection # logger.info("Detecting material...") # material_result = self.detect_material(image) # if material_result["predictions"]: # visual_attributes["material"] = material_result["predictions"][0]["value"] # # 5. Style Detection # logger.info("Detecting style...") # style_result = self.detect_style(image) # if style_result["predictions"]: # visual_attributes["style"] = style_result["predictions"][0]["value"] # # 6. Fit Detection # logger.info("Detecting fit...") # fit_result = self.detect_fit(image) # if fit_result["predictions"]: # visual_attributes["fit"] = fit_result["predictions"][0]["value"] # # 7. Neckline Detection (if applicable) # logger.info("Detecting neckline...") # neckline_result = self.detect_neckline(image, detected_product_type) # if neckline_result["predictions"]: # visual_attributes["neckline"] = neckline_result["predictions"][0]["value"] # # 8. Sleeve Type Detection (if applicable) # logger.info("Detecting sleeve type...") # sleeve_result = self.detect_sleeve_type(image, detected_product_type) # if sleeve_result["predictions"]: # visual_attributes["sleeve_type"] = sleeve_result["predictions"][0]["value"] # # 9. Closure Type Detection # logger.info("Detecting closure type...") # closure_result = self.detect_closure_type(image) # if closure_result["predictions"]: # visual_attributes["closure_type"] = closure_result["predictions"][0]["value"] # # 10. Length Detection # logger.info("Detecting length...") # length_result = self.detect_length(image, detected_product_type) # if length_result["predictions"]: # visual_attributes["length"] = length_result["predictions"][0]["value"] # # Format response # return { # "visual_attributes": visual_attributes, # "detailed_predictions": { # "product_type": product_type_result, # "pattern": pattern_result, # "material": material_result, # "style": style_result, # "fit": fit_result, # "neckline": neckline_result, # "sleeve_type": sleeve_result, # "closure_type": closure_result, # "length": length_result # } # } # except Exception as e: # logger.error(f"Error processing image: {str(e)}") # return { # "visual_attributes": {}, # "error": str(e) # } # # ==================== visual_processing_service_optimized.py ==================== # """ # Optimized version with: # - Result caching # - Batch processing support # - Memory management # - Error recovery # - Performance monitoring # """ # import torch # import cv2 # import numpy as np # import requests # from io import BytesIO # from PIL import Image # from typing import Dict, List, Optional, Tuple # import logging # import time # import hashlib # from functools import lru_cache # from transformers import CLIPProcessor, CLIPModel # from sklearn.cluster import KMeans # import webcolors # logger = logging.getLogger(__name__) # class VisualProcessingService: # """Optimized service for extracting visual attributes from product images.""" # # Class-level model caching (shared across instances) # _clip_model = None # _clip_processor = None # _model_device = None # def __init__(self, use_cache: bool = True, cache_ttl: int = 3600): # self.use_cache = use_cache # self.cache_ttl = cache_ttl # self._cache = {} # @classmethod # def _get_device(cls): # """Get optimal device (GPU if available, else CPU).""" # if cls._model_device is None: # cls._model_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # logger.info(f"Using device: {cls._model_device}") # return cls._model_device # @classmethod # def _get_clip_model(cls): # """Lazy load CLIP model with class-level caching.""" # if cls._clip_model is None: # logger.info("Loading CLIP model...") # start_time = time.time() # try: # cls._clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") # cls._clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") # device = cls._get_device() # cls._clip_model.to(device) # cls._clip_model.eval() # logger.info(f"CLIP model loaded in {time.time() - start_time:.2f}s") # except Exception as e: # logger.error(f"Failed to load CLIP model: {str(e)}") # raise # return cls._clip_model, cls._clip_processor # def _get_cache_key(self, image_url: str, operation: str) -> str: # """Generate cache key for results.""" # url_hash = hashlib.md5(image_url.encode()).hexdigest() # return f"visual_{operation}_{url_hash}" # def _get_cached(self, key: str) -> Optional[Dict]: # """Get cached result if available and not expired.""" # if not self.use_cache: # return None # if key in self._cache: # result, timestamp = self._cache[key] # if time.time() - timestamp < self.cache_ttl: # return result # else: # del self._cache[key] # return None # def _set_cached(self, key: str, value: Dict): # """Cache result with timestamp.""" # if self.use_cache: # self._cache[key] = (value, time.time()) # def download_image(self, image_url: str, max_size: Tuple[int, int] = (1024, 1024)) -> Optional[Image.Image]: # """Download and optionally resize image for faster processing.""" # try: # response = requests.get(image_url, timeout=10) # response.raise_for_status() # image = Image.open(BytesIO(response.content)).convert('RGB') # # Resize if image is too large # if image.size[0] > max_size[0] or image.size[1] > max_size[1]: # image.thumbnail(max_size, Image.Resampling.LANCZOS) # logger.info(f"Resized image from original size to {image.size}") # return image # except Exception as e: # logger.error(f"Error downloading image from {image_url}: {str(e)}") # return None # @lru_cache(maxsize=100) # def _get_color_name_cached(self, rgb: Tuple[int, int, int]) -> str: # """Cached version of color name lookup.""" # try: # return webcolors.rgb_to_name(rgb) # except ValueError: # min_distance = float('inf') # closest_name = 'unknown' # for name in webcolors.CSS3_NAMES_TO_HEX: # hex_color = webcolors.CSS3_NAMES_TO_HEX[name] # r, g, b = webcolors.hex_to_rgb(hex_color) # distance = sum((c1 - c2) ** 2 for c1, c2 in zip(rgb, (r, g, b))) # if distance < min_distance: # min_distance = distance # closest_name = name # return closest_name # def extract_dominant_colors(self, image: Image.Image, n_colors: int = 3) -> List[Dict]: # """Extract dominant colors with optimized K-means.""" # try: # # Resize for faster processing # img_small = image.resize((150, 150)) # img_array = np.array(img_small) # pixels = img_array.reshape(-1, 3) # # Sample pixels if too many # if len(pixels) > 10000: # indices = np.random.choice(len(pixels), 10000, replace=False) # pixels = pixels[indices] # # K-means with optimized parameters # kmeans = KMeans( # n_clusters=n_colors, # random_state=42, # n_init=5, # Reduced from 10 for speed # max_iter=100, # algorithm='elkan' # Faster for low dimensions # ) # kmeans.fit(pixels) # colors = [] # labels_counts = np.bincount(kmeans.labels_) # for i, center in enumerate(kmeans.cluster_centers_): # rgb = tuple(center.astype(int)) # color_name = self._get_color_name_cached(rgb) # percentage = float(labels_counts[i] / len(kmeans.labels_) * 100) # colors.append({ # "name": color_name, # "rgb": rgb, # "percentage": percentage # }) # colors.sort(key=lambda x: x['percentage'], reverse=True) # return colors # except Exception as e: # logger.error(f"Error extracting colors: {str(e)}") # return [] # def classify_with_clip( # self, # image: Image.Image, # candidates: List[str], # attribute_name: str, # confidence_threshold: float = 0.15 # ) -> Dict: # """Optimized CLIP classification with batching.""" # try: # model, processor = self._get_clip_model() # device = self._get_device() # # Prepare inputs # inputs = processor( # text=candidates, # images=image, # return_tensors="pt", # padding=True # ) # # Move to device # inputs = {k: v.to(device) for k, v in inputs.items()} # # Get predictions with no_grad for speed # with torch.no_grad(): # outputs = model(**inputs) # logits_per_image = outputs.logits_per_image # probs = logits_per_image.softmax(dim=1).cpu() # # Get top predictions # top_k = min(3, len(candidates)) # top_probs, top_indices = torch.topk(probs[0], k=top_k) # results = [] # for prob, idx in zip(top_probs, top_indices): # if prob.item() > confidence_threshold: # results.append({ # "value": candidates[idx.item()], # "confidence": float(prob.item()) # }) # return { # "attribute": attribute_name, # "predictions": results # } # except Exception as e: # logger.error(f"Error in CLIP classification for {attribute_name}: {str(e)}") # return {"attribute": attribute_name, "predictions": []} # def batch_classify( # self, # image: Image.Image, # attribute_configs: List[Dict[str, any]] # ) -> Dict[str, Dict]: # """ # Batch multiple CLIP classifications for efficiency. # attribute_configs: [{"name": "pattern", "candidates": [...], "threshold": 0.15}, ...] # """ # results = {} # for config in attribute_configs: # attr_name = config["name"] # candidates = config["candidates"] # threshold = config.get("threshold", 0.15) # result = self.classify_with_clip(image, candidates, attr_name, threshold) # results[attr_name] = result # return results # def detect_patterns(self, image: Image.Image) -> Dict: # """Detect patterns using CLIP.""" # pattern_types = [ # "solid color", "striped", "checkered", "polka dot", "floral", # "geometric", "abstract", "graphic print", "camouflage", "paisley" # ] # return self.classify_with_clip(image, pattern_types, "pattern") # def process_image( # self, # image_url: str, # product_type_hint: Optional[str] = None, # attributes_to_extract: Optional[List[str]] = None # ) -> Dict: # """ # Main method with caching and selective attribute extraction. # Args: # image_url: URL of the product image # product_type_hint: Optional hint about product type # attributes_to_extract: List of attributes to extract (None = all) # """ # # Check cache # cache_key = self._get_cache_key(image_url, "full") # cached_result = self._get_cached(cache_key) # if cached_result: # logger.info(f"Returning cached result for {image_url}") # return cached_result # start_time = time.time() # try: # # Download image # image = self.download_image(image_url) # if image is None: # return { # "visual_attributes": {}, # "error": "Failed to download image" # } # visual_attributes = {} # detailed_predictions = {} # # Default: extract all attributes # if attributes_to_extract is None: # attributes_to_extract = [ # "product_type", "color", "pattern", "material", # "style", "fit", "neckline", "sleeve_type", # "closure_type", "length" # ] # # 1. Product Type Detection # if "product_type" in attributes_to_extract: # logger.info("Detecting product type...") # product_types = [ # "t-shirt", "shirt", "dress", "pants", "jeans", "shorts", # "skirt", "jacket", "coat", "sweater", "hoodie", "blazer", # "top", "blouse" # ] # product_type_result = self.classify_with_clip(image, product_types, "product_type") # if product_type_result["predictions"]: # visual_attributes["product_type"] = product_type_result["predictions"][0]["value"] # detected_product_type = visual_attributes["product_type"] # else: # detected_product_type = product_type_hint or "unknown" # detailed_predictions["product_type"] = product_type_result # else: # detected_product_type = product_type_hint or "unknown" # # 2. Color Detection # if "color" in attributes_to_extract: # logger.info("Extracting colors...") # colors = self.extract_dominant_colors(image, n_colors=3) # if colors: # visual_attributes["primary_color"] = colors[0]["name"] # visual_attributes["color_palette"] = [c["name"] for c in colors] # visual_attributes["color_details"] = colors # # 3. Batch classify remaining attributes # batch_configs = [] # if "pattern" in attributes_to_extract: # batch_configs.append({ # "name": "pattern", # "candidates": [ # "solid color", "striped", "checkered", "polka dot", # "floral", "geometric", "abstract", "graphic print" # ] # }) # if "material" in attributes_to_extract: # batch_configs.append({ # "name": "material", # "candidates": [ # "cotton", "polyester", "denim", "leather", "silk", # "wool", "linen", "satin", "fleece", "knit" # ] # }) # if "style" in attributes_to_extract: # batch_configs.append({ # "name": "style", # "candidates": [ # "casual", "formal", "sporty", "business", "vintage", # "modern", "streetwear", "elegant", "athletic" # ] # }) # if "fit" in attributes_to_extract: # batch_configs.append({ # "name": "fit", # "candidates": [ # "slim fit", "regular fit", "loose fit", "oversized", # "relaxed", "tailored" # ] # }) # # Product-type specific attributes # if detected_product_type.lower() in ['shirt', 't-shirt', 'top', 'blouse', 'dress', 'sweater']: # if "neckline" in attributes_to_extract: # batch_configs.append({ # "name": "neckline", # "candidates": [ # "crew neck", "v-neck", "round neck", "collar", # "turtleneck", "scoop neck", "boat neck" # ] # }) # if "sleeve_type" in attributes_to_extract: # batch_configs.append({ # "name": "sleeve_type", # "candidates": [ # "short sleeve", "long sleeve", "sleeveless", # "three-quarter sleeve", "cap sleeve" # ] # }) # if "closure_type" in attributes_to_extract: # batch_configs.append({ # "name": "closure_type", # "candidates": [ # "button", "zipper", "snap", "pull-on", # "lace-up", "elastic", "buckle" # ] # }) # if "length" in attributes_to_extract: # if detected_product_type.lower() in ['pants', 'jeans', 'trousers']: # batch_configs.append({ # "name": "length", # "candidates": ["full length", "ankle length", "cropped", "capri", "shorts"] # }) # elif detected_product_type.lower() in ['skirt', 'dress']: # batch_configs.append({ # "name": "length", # "candidates": ["mini", "knee length", "midi", "maxi", "floor length"] # }) # # Execute batch classification # logger.info(f"Batch classifying {len(batch_configs)} attributes...") # batch_results = self.batch_classify(image, batch_configs) # # Process batch results # for attr_name, result in batch_results.items(): # detailed_predictions[attr_name] = result # if result["predictions"]: # visual_attributes[attr_name] = result["predictions"][0]["value"] # # Format response # result = { # "visual_attributes": visual_attributes, # "detailed_predictions": detailed_predictions, # "processing_time": round(time.time() - start_time, 2) # } # # Cache result # self._set_cached(cache_key, result) # logger.info(f"Visual processing completed in {result['processing_time']}s") # return result # except Exception as e: # logger.error(f"Error processing image: {str(e)}") # return { # "visual_attributes": {}, # "error": str(e), # "processing_time": round(time.time() - start_time, 2) # } # def clear_cache(self): # """Clear all cached results.""" # self._cache.clear() # logger.info("Cache cleared") # def get_cache_stats(self) -> Dict: # """Get cache statistics.""" # return { # "cache_size": len(self._cache), # "cache_enabled": self.use_cache, # "cache_ttl": self.cache_ttl # } # @classmethod # def cleanup_models(cls): # """Free up memory by unloading models.""" # if cls._clip_model is not None: # del cls._clip_model # del cls._clip_processor # cls._clip_model = None # cls._clip_processor = None # if torch.cuda.is_available(): # torch.cuda.empty_cache() # logger.info("Models unloaded and memory freed") # # ==================== Usage Example ==================== # def example_usage(): # """Example of how to use the optimized service.""" # # Initialize service with caching # service = VisualProcessingService(use_cache=True, cache_ttl=3600) # # Process single image with all attributes # result1 = service.process_image("https://example.com/product1.jpg") # print("All attributes:", result1["visual_attributes"]) # # Process with selective attributes (faster) # result2 = service.process_image( # "https://example.com/product2.jpg", # product_type_hint="t-shirt", # attributes_to_extract=["color", "pattern", "style"] # ) # print("Selected attributes:", result2["visual_attributes"]) # # Check cache stats # print("Cache stats:", service.get_cache_stats()) # # Clear cache when needed # service.clear_cache() # # Cleanup models (call when shutting down) # VisualProcessingService.cleanup_models() # if __name__ == "__main__": # example_usage() # ==================== visual_processing_service.py (FIXED) ==================== import torch import cv2 import numpy as np import requests from io import BytesIO from PIL import Image from typing import Dict, List, Optional, Tuple import logging from transformers import CLIPProcessor, CLIPModel from sklearn.cluster import KMeans logger = logging.getLogger(__name__) class VisualProcessingService: """Service for extracting visual attributes from product images using CLIP.""" # Class-level caching (shared across instances) _clip_model = None _clip_processor = None _device = None def __init__(self): pass @classmethod def _get_device(cls): """Get optimal device.""" if cls._device is None: cls._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") logger.info(f"Visual Processing using device: {cls._device}") return cls._device @classmethod def _get_clip_model(cls): """Lazy load CLIP model with class-level caching.""" if cls._clip_model is None: logger.info("Loading CLIP model (this may take a few minutes on first use)...") cls._clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") cls._clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") device = cls._get_device() cls._clip_model.to(device) cls._clip_model.eval() logger.info("✓ CLIP model loaded successfully") return cls._clip_model, cls._clip_processor def download_image(self, image_url: str) -> Optional[Image.Image]: """Download image from URL.""" try: response = requests.get(image_url, timeout=10) response.raise_for_status() image = Image.open(BytesIO(response.content)).convert('RGB') return image except Exception as e: logger.error(f"Error downloading image from {image_url}: {str(e)}") return None def extract_dominant_colors(self, image: Image.Image, n_colors: int = 3) -> List[Dict]: """Extract dominant colors using K-means (FIXED webcolors issue).""" try: # Resize for faster processing img_small = image.resize((150, 150)) img_array = np.array(img_small) pixels = img_array.reshape(-1, 3) # K-means clustering kmeans = KMeans(n_clusters=n_colors, random_state=42, n_init=5) kmeans.fit(pixels) colors = [] labels_counts = np.bincount(kmeans.labels_) for i, center in enumerate(kmeans.cluster_centers_): rgb = tuple(center.astype(int)) color_name = self._get_color_name_simple(rgb) percentage = float(labels_counts[i] / len(kmeans.labels_) * 100) colors.append({ "name": color_name, "rgb": rgb, "percentage": percentage }) colors.sort(key=lambda x: x['percentage'], reverse=True) return colors except Exception as e: logger.error(f"Error extracting colors: {str(e)}") return [] def _get_color_name_simple(self, rgb: Tuple[int, int, int]) -> str: """ Simple color name detection without webcolors dependency. Maps RGB to basic color names. """ r, g, b = rgb # Define basic color ranges colors = { 'black': (r < 50 and g < 50 and b < 50), 'white': (r > 200 and g > 200 and b > 200), 'gray': (abs(r - g) < 30 and abs(g - b) < 30 and abs(r - b) < 30 and 50 <= r <= 200), 'red': (r > 150 and g < 100 and b < 100), 'green': (g > 150 and r < 100 and b < 100), 'blue': (b > 150 and r < 100 and g < 100), 'yellow': (r > 200 and g > 200 and b < 100), 'orange': (r > 200 and 100 < g < 200 and b < 100), 'purple': (r > 100 and b > 100 and g < 100), 'pink': (r > 200 and 100 < g < 200 and 100 < b < 200), 'brown': (50 < r < 150 and 30 < g < 100 and b < 80), 'cyan': (r < 100 and g > 150 and b > 150), } for color_name, condition in colors.items(): if condition: return color_name # Default fallback if r > g and r > b: return 'red' elif g > r and g > b: return 'green' elif b > r and b > g: return 'blue' else: return 'gray' def classify_with_clip( self, image: Image.Image, candidates: List[str], attribute_name: str ) -> Dict: """Use CLIP to classify image against candidate labels.""" try: model, processor = self._get_clip_model() device = self._get_device() # Prepare inputs inputs = processor( text=candidates, images=image, return_tensors="pt", padding=True ) # Move to device inputs = {k: v.to(device) for k, v in inputs.items()} # Get predictions with torch.no_grad(): outputs = model(**inputs) logits_per_image = outputs.logits_per_image probs = logits_per_image.softmax(dim=1).cpu() # Get top predictions top_k = min(3, len(candidates)) top_probs, top_indices = torch.topk(probs[0], k=top_k) results = [] for prob, idx in zip(top_probs, top_indices): if prob.item() > 0.15: # Confidence threshold results.append({ "value": candidates[idx.item()], "confidence": float(prob.item()) }) return { "attribute": attribute_name, "predictions": results } except Exception as e: logger.error(f"Error in CLIP classification for {attribute_name}: {str(e)}") return {"attribute": attribute_name, "predictions": []} def process_image( self, image_url: str, product_type_hint: Optional[str] = None ) -> Dict: """ Main method to process image and extract visual attributes. """ import time start_time = time.time() try: # Download image image = self.download_image(image_url) if image is None: return { "visual_attributes": {}, "error": "Failed to download image" } visual_attributes = {} detailed_predictions = {} # 1. Product Type Detection product_types = [ "t-shirt", "shirt", "dress", "pants", "jeans", "shorts", "skirt", "jacket", "coat", "sweater", "hoodie", "top" ] product_type_result = self.classify_with_clip(image, product_types, "product_type") if product_type_result["predictions"]: visual_attributes["product_type"] = product_type_result["predictions"][0]["value"] detected_product_type = visual_attributes["product_type"] else: detected_product_type = product_type_hint or "unknown" detailed_predictions["product_type"] = product_type_result # 2. Color Detection colors = self.extract_dominant_colors(image, n_colors=3) if colors: visual_attributes["primary_color"] = colors[0]["name"] visual_attributes["color_palette"] = [c["name"] for c in colors] # 3. Pattern Detection patterns = ["solid color", "striped", "checkered", "graphic print", "floral", "geometric"] pattern_result = self.classify_with_clip(image, patterns, "pattern") if pattern_result["predictions"]: visual_attributes["pattern"] = pattern_result["predictions"][0]["value"] detailed_predictions["pattern"] = pattern_result # 4. Material Detection materials = ["cotton", "polyester", "denim", "leather", "silk", "wool", "linen"] material_result = self.classify_with_clip(image, materials, "material") if material_result["predictions"]: visual_attributes["material"] = material_result["predictions"][0]["value"] detailed_predictions["material"] = material_result # 5. Style Detection styles = ["casual", "formal", "sporty", "streetwear", "elegant", "vintage"] style_result = self.classify_with_clip(image, styles, "style") if style_result["predictions"]: visual_attributes["style"] = style_result["predictions"][0]["value"] detailed_predictions["style"] = style_result # 6. Fit Detection fits = ["slim fit", "regular fit", "loose fit", "oversized"] fit_result = self.classify_with_clip(image, fits, "fit") if fit_result["predictions"]: visual_attributes["fit"] = fit_result["predictions"][0]["value"] detailed_predictions["fit"] = fit_result # 7. Neckline (for tops) if detected_product_type.lower() in ['shirt', 't-shirt', 'top', 'dress']: necklines = ["crew neck", "v-neck", "round neck", "collar"] neckline_result = self.classify_with_clip(image, necklines, "neckline") if neckline_result["predictions"]: visual_attributes["neckline"] = neckline_result["predictions"][0]["value"] detailed_predictions["neckline"] = neckline_result # 8. Sleeve Type (for tops) if detected_product_type.lower() in ['shirt', 't-shirt', 'top']: sleeves = ["short sleeve", "long sleeve", "sleeveless"] sleeve_result = self.classify_with_clip(image, sleeves, "sleeve_type") if sleeve_result["predictions"]: visual_attributes["sleeve_type"] = sleeve_result["predictions"][0]["value"] detailed_predictions["sleeve_type"] = sleeve_result # 9. Closure Type closures = ["button", "zipper", "pull-on"] closure_result = self.classify_with_clip(image, closures, "closure_type") if closure_result["predictions"]: visual_attributes["closure_type"] = closure_result["predictions"][0]["value"] detailed_predictions["closure_type"] = closure_result processing_time = time.time() - start_time return { "visual_attributes": visual_attributes, "detailed_predictions": detailed_predictions, "processing_time": round(processing_time, 2) } except Exception as e: logger.error(f"Error processing image: {str(e)}") return { "visual_attributes": {}, "error": str(e), "processing_time": round(time.time() - start_time, 2) }