# ==================== visual_processing_service.py (WITH CACHE CONTROL) ====================
import logging
import os
import time
import warnings
from io import BytesIO
from typing import Dict, List, Optional, Tuple

import numpy as np
import requests
import torch
from PIL import Image
from sklearn.cluster import KMeans

# Quiet tokenizer-fork chatter and generic library warnings before the heavy imports.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
warnings.filterwarnings('ignore')

from transformers import CLIPProcessor, CLIPModel

# ⚡ IMPORT CACHE CONFIGURATION
from .cache_config import ENABLE_CLIP_MODEL_CACHE

logger = logging.getLogger(__name__)
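
# The sibling cache_config module is not shown in this file. A minimal sketch of what it
# is assumed to contain (an assumption, not the actual project file):
#
#     # cache_config.py
#     ENABLE_CLIP_MODEL_CACHE = True   # toggle class-level caching of the CLIP model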

class VisualProcessingService:
    """Service for extracting visual attributes from product images using CLIP with smart subcategory detection."""

    # ⚡ Class-level caching (controlled by cache_config)
    _clip_model = None
    _clip_processor = None
    _device = None

    # Define hierarchical category structure with subcategories
    CATEGORY_ATTRIBUTES = {
        "clothing": {
            "subcategories": {
                "tops": {
                    "products": ["t-shirt", "shirt", "blouse", "top", "sweater", "hoodie", "tank top", "polo shirt"],
                    "attributes": {
                        "pattern": ["solid color", "striped", "checkered", "graphic print", "floral", "geometric", "plain", "logo print"],
                        "material": ["cotton", "polyester", "silk", "wool", "linen", "blend", "knit"],
                        "style": ["casual", "formal", "sporty", "streetwear", "elegant", "vintage", "minimalist"],
                        "fit": ["slim fit", "regular fit", "loose fit", "oversized", "fitted"],
                        "neckline": ["crew neck", "v-neck", "round neck", "collar", "scoop neck", "henley"],
                        "sleeve_type": ["short sleeve", "long sleeve", "sleeveless", "3/4 sleeve", "cap sleeve"],
                        "closure_type": ["button-up", "zipper", "pull-on", "snap button"]
                    }
                },
                "bottoms": {
                    "products": ["jeans", "pants", "trousers", "shorts", "chinos", "cargo pants", "leggings"],
                    "attributes": {
                        "pattern": ["solid color", "distressed", "faded", "plain", "washed", "dark wash", "light wash"],
                        "material": ["denim", "cotton", "polyester", "wool", "blend", "twill", "corduroy"],
                        "style": ["casual", "formal", "sporty", "vintage", "modern", "workwear"],
                        "fit": ["slim fit", "regular fit", "loose fit", "skinny", "bootcut", "straight leg", "relaxed fit"],
                        "rise": ["high rise", "mid rise", "low rise"],
                        "closure_type": ["button fly", "zipper fly", "elastic waist", "drawstring"],
                        "length": ["full length", "cropped", "ankle length", "capri"]
                    }
                },
                "dresses_skirts": {
                    "products": ["dress", "skirt", "gown", "sundress", "maxi dress", "mini skirt"],
                    "attributes": {
                        "pattern": ["solid color", "floral", "striped", "geometric", "plain", "printed", "polka dot"],
                        "material": ["cotton", "silk", "polyester", "linen", "blend", "chiffon", "satin"],
                        "style": ["casual", "formal", "cocktail", "bohemian", "vintage", "elegant", "party"],
                        "fit": ["fitted", "loose", "a-line", "bodycon", "flowy", "wrap"],
                        "neckline": ["crew neck", "v-neck", "scoop neck", "halter", "off-shoulder", "sweetheart"],
                        "sleeve_type": ["short sleeve", "long sleeve", "sleeveless", "3/4 sleeve", "flutter sleeve"],
                        "length": ["mini", "midi", "maxi", "knee-length", "floor-length"]
                    }
                },
                "outerwear": {
                    "products": ["jacket", "coat", "blazer", "windbreaker", "parka", "bomber jacket", "denim jacket"],
                    "attributes": {
                        "pattern": ["solid color", "plain", "quilted", "textured"],
                        "material": ["leather", "denim", "wool", "polyester", "cotton", "nylon", "fleece"],
                        "style": ["casual", "formal", "sporty", "vintage", "military", "biker"],
                        "fit": ["slim fit", "regular fit", "oversized", "cropped"],
                        "closure_type": ["zipper", "button", "snap button", "toggle"],
                        "length": ["cropped", "hip length", "thigh length", "knee length"]
                    }
                }
            }
        },
        "footwear": {
            "products": ["sneakers", "boots", "sandals", "heels", "loafers", "flats", "slippers"],
            "attributes": {
                "material": ["leather", "canvas", "suede", "synthetic", "rubber", "mesh"],
                "style": ["casual", "formal", "athletic", "vintage", "modern"],
                "closure_type": ["lace-up", "slip-on", "velcro", "buckle", "zipper"],
                "toe_style": ["round toe", "pointed toe", "square toe", "open toe", "closed toe"]
            }
        },
        "tools": {
            "products": ["screwdriver", "hammer", "wrench", "pliers", "drill", "saw", "measuring tape"],
            "attributes": {
                "material": ["steel", "aluminum", "plastic", "rubber", "chrome", "iron"],
                "type": ["manual", "electric", "pneumatic", "cordless", "corded"],
                "finish": ["chrome plated", "powder coated", "stainless steel", "painted"],
                "handle_type": ["rubber grip", "plastic", "wooden", "ergonomic", "cushioned"]
            }
        },
        "electronics": {
            "products": ["phone", "laptop", "tablet", "headphones", "speaker", "camera", "smartwatch", "earbuds"],
            "attributes": {
                "material": ["plastic", "metal", "glass", "aluminum", "rubber", "silicone"],
                "style": ["modern", "minimalist", "sleek", "industrial", "vintage"],
                "finish": ["matte", "glossy", "metallic", "textured", "transparent"],
                "connectivity": ["wireless", "wired", "bluetooth", "USB-C", "USB"]
            }
        },
        "furniture": {
            "products": ["chair", "table", "sofa", "bed", "desk", "shelf", "cabinet", "bench"],
            "attributes": {
                "material": ["wood", "metal", "glass", "plastic", "fabric", "leather", "rattan"],
                "style": ["modern", "traditional", "industrial", "rustic", "contemporary", "vintage", "scandinavian"],
                "finish": ["natural wood", "painted", "stained", "laminated", "upholstered", "polished"]
            }
        }
    }
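
    # For reference, attribute lookups walk this structure top-down, for example:
    #     CATEGORY_ATTRIBUTES["clothing"]["subcategories"]["tops"]["attributes"]["fit"]
    #         -> ["slim fit", "regular fit", "loose fit", "oversized", "fitted"]
    #     CATEGORY_ATTRIBUTES["footwear"]["attributes"]["closure_type"]
    #         -> ["lace-up", "slip-on", "velcro", "buckle", "zipper"]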

    def __init__(self):
        pass

    @classmethod
    def _get_device(cls):
        """Get optimal device."""
        if cls._device is None:
            cls._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info(f"Visual Processing using device: {cls._device}")
        return cls._device

    @classmethod
    def _get_clip_model(cls):
        """
        Lazy load CLIP model with optional class-level caching.
        ⚡ If caching is disabled, the model is still loaded but not persisted at class level.
        """
        # ⚡ CACHE CONTROL: if caching is disabled, always reload (no persistence)
        if not ENABLE_CLIP_MODEL_CACHE:
            logger.info("⚠ CLIP model caching is DISABLED - loading fresh instance")
            model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
            processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

            device = cls._get_device()
            model.to(device)
            model.eval()

            logger.info("✓ CLIP model loaded (no caching)")
            return model, processor

        # Caching is enabled - use the class-level cache
        if cls._clip_model is None:
            logger.info("Loading CLIP model (this may take a few minutes on first use)...")
            cls._clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
            cls._clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

            device = cls._get_device()
            cls._clip_model.to(device)
            cls._clip_model.eval()

            logger.info("✓ CLIP model loaded and cached successfully")
        else:
            logger.info("✓ Using cached CLIP model")

        return cls._clip_model, cls._clip_processor
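
    # Note: because the cache lives on the class, all VisualProcessingService instances
    # share one CLIP model/processor when ENABLE_CLIP_MODEL_CACHE is True; when it is
    # False, every _get_clip_model() call reloads the weights, trading memory for
    # per-call load time. Illustrative call:
    #     model, processor = VisualProcessingService._get_clip_model()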

    @classmethod
    def clear_clip_cache(cls):
        """Clear the cached CLIP model to free memory."""
        if cls._clip_model is not None:
            del cls._clip_model
            del cls._clip_processor
            cls._clip_model = None
            cls._clip_processor = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("✓ CLIP model cache cleared")

    def download_image(self, image_url: str) -> Optional[Image.Image]:
        """Download image from URL."""
        try:
            response = requests.get(image_url, timeout=10)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert('RGB')
            return image
        except Exception as e:
            logger.error(f"Error downloading image from {image_url}: {str(e)}")
            return None

    def extract_dominant_colors(self, image: Image.Image, n_colors: int = 3) -> List[Dict]:
        """Extract dominant colors using K-means clustering."""
        try:
            img_small = image.resize((150, 150))
            img_array = np.array(img_small)
            pixels = img_array.reshape(-1, 3)

            kmeans = KMeans(n_clusters=n_colors, random_state=42, n_init=5)
            kmeans.fit(pixels)

            colors = []
            labels_counts = np.bincount(kmeans.labels_)

            for i, center in enumerate(kmeans.cluster_centers_):
                # Cast to plain Python ints so the result stays JSON-serializable.
                rgb = tuple(int(c) for c in center)
                color_name = self._get_color_name_simple(rgb)
                percentage = float(labels_counts[i] / len(kmeans.labels_) * 100)

                colors.append({
                    "name": color_name,
                    "rgb": rgb,
                    "percentage": round(percentage, 2)
                })

            colors.sort(key=lambda x: x['percentage'], reverse=True)
            return colors

        except Exception as e:
            logger.error(f"Error extracting colors: {str(e)}")
            return []
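
    # Shape of the returned list (illustrative values only):
    #     [{"name": "blue", "rgb": (38, 72, 154), "percentage": 61.42},
    #      {"name": "white", "rgb": (231, 228, 224), "percentage": 24.1},
    #      {"name": "gray", "rgb": (120, 119, 118), "percentage": 14.48}]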

    def _get_color_name_simple(self, rgb: Tuple[int, int, int]) -> str:
        """Map RGB values to basic color names."""
        r, g, b = rgb

        colors = {
            'black': (r < 50 and g < 50 and b < 50),
            'white': (r > 200 and g > 200 and b > 200),
            'gray': (abs(r - g) < 30 and abs(g - b) < 30 and abs(r - b) < 30 and 50 <= r <= 200),
            'red': (r > 150 and g < 100 and b < 100),
            'green': (g > 150 and r < 100 and b < 100),
            'blue': (b > 150 and r < 100 and g < 100),
            'yellow': (r > 200 and g > 200 and b < 100),
            'orange': (r > 200 and 100 < g < 200 and b < 100),
            'purple': (r > 100 and b > 100 and g < 100),
            'pink': (r > 200 and 100 < g < 200 and 100 < b < 200),
            'brown': (50 < r < 150 and 30 < g < 100 and b < 80),
            'cyan': (r < 100 and g > 150 and b > 150),
            'beige': (180 < r < 240 and 160 < g < 220 and 120 < b < 180),
        }

        for color_name, condition in colors.items():
            if condition:
                return color_name

        if r > g and r > b:
            return 'red'
        elif g > r and g > b:
            return 'green'
        elif b > r and b > g:
            return 'blue'
        else:
            return 'gray'
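
    # A couple of illustrative mappings under the rules above:
    #     (220, 40, 35)   -> 'red'    (r > 150, g < 100, b < 100)
    #     (210, 200, 150) -> 'beige'  (180 < r < 240, 160 < g < 220, 120 < b < 180)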

    def classify_with_clip(
        self,
        image: Image.Image,
        candidates: List[str],
        attribute_name: str,
        confidence_threshold: float = 0.15
    ) -> Dict:
        """Use CLIP to classify image against candidate labels."""
        try:
            model, processor = self._get_clip_model()
            device = self._get_device()

            batch_size = 16
            all_logits = []

            # Score candidates in batches, but defer the softmax until all logits are
            # collected so confidences stay comparable when len(candidates) > batch_size.
            for i in range(0, len(candidates), batch_size):
                batch_candidates = candidates[i:i + batch_size]

                inputs = processor(
                    text=batch_candidates,
                    images=image,
                    return_tensors="pt",
                    padding=True
                )

                inputs = {k: v.to(device) for k, v in inputs.items()}

                with torch.no_grad():
                    outputs = model(**inputs)
                    all_logits.append(outputs.logits_per_image[0].cpu())

            probs = torch.cat(all_logits).softmax(dim=0)

            all_results = []
            for j, prob in enumerate(probs):
                if prob.item() > confidence_threshold:
                    all_results.append({
                        "value": candidates[j],
                        "confidence": round(float(prob.item()), 3)
                    })

            all_results.sort(key=lambda x: x['confidence'], reverse=True)

            return {
                "attribute": attribute_name,
                "predictions": all_results[:3]
            }

        except Exception as e:
            logger.error(f"Error in CLIP classification for {attribute_name}: {str(e)}")
            return {"attribute": attribute_name, "predictions": []}

    def detect_category_and_subcategory(self, image: Image.Image) -> Tuple[str, str, str, float]:
        """
        Hierarchically detect category, subcategory, and specific product.
        Returns: (category, subcategory, product_type, confidence)
        """
        main_categories = list(self.CATEGORY_ATTRIBUTES.keys())
        category_prompts = [f"a photo of {cat}" for cat in main_categories]

        result = self.classify_with_clip(image, category_prompts, "main_category", confidence_threshold=0.10)

        if not result["predictions"]:
            return "unknown", "unknown", "unknown", 0.0

        detected_category = result["predictions"][0]["value"].replace("a photo of ", "")
        category_confidence = result["predictions"][0]["confidence"]

        logger.info(f"Step 1 - Main category detected: {detected_category} (confidence: {category_confidence:.3f})")

        if detected_category == "clothing":
            subcategories = self.CATEGORY_ATTRIBUTES["clothing"]["subcategories"]

            all_products = []
            product_to_subcategory = {}

            for subcat, subcat_data in subcategories.items():
                for product in subcat_data["products"]:
                    prompt = f"a photo of {product}"
                    all_products.append(prompt)
                    product_to_subcategory[prompt] = subcat

            product_result = self.classify_with_clip(
                image,
                all_products,
                "product_type",
                confidence_threshold=0.12
            )

            if product_result["predictions"]:
                best_match = product_result["predictions"][0]
                product_prompt = best_match["value"]
                product_type = product_prompt.replace("a photo of ", "")
                subcategory = product_to_subcategory[product_prompt]
                product_confidence = best_match["confidence"]

                logger.info(f"Step 2 - Detected: {subcategory} > {product_type} (confidence: {product_confidence:.3f})")
                return detected_category, subcategory, product_type, product_confidence
            else:
                logger.warning("Could not detect specific product type for clothing")
                return detected_category, "unknown", "unknown", category_confidence

        else:
            category_data = self.CATEGORY_ATTRIBUTES[detected_category]

            if "products" in category_data:
                products = category_data["products"]
                product_prompts = [f"a photo of {p}" for p in products]

                product_result = self.classify_with_clip(
                    image,
                    product_prompts,
                    "product_type",
                    confidence_threshold=0.12
                )

                if product_result["predictions"]:
                    best_match = product_result["predictions"][0]
                    product_type = best_match["value"].replace("a photo of ", "")

                    logger.info(f"Step 2 - Detected: {detected_category} > {product_type}")
                    return detected_category, "none", product_type, best_match["confidence"]

            return detected_category, "unknown", "unknown", category_confidence
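
    # Example outcomes of the hierarchy (illustrative):
    #     clothing image     -> ("clothing", "tops", "t-shirt", 0.31)
    #     furniture image    -> ("furniture", "none", "chair", 0.44)
    #     no confident match -> ("unknown", "unknown", "unknown", 0.0)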

    def process_image(
        self,
        image_url: str,
        product_type_hint: Optional[str] = None
    ) -> Dict:
        """
        Main method to process an image and extract visual attributes.
        Uses hierarchical detection to extract only the relevant attributes.
        Note: product_type_hint is accepted for API compatibility but is not used yet.
        """
        start_time = time.time()

        try:
            image = self.download_image(image_url)
            if image is None:
                return {
                    "visual_attributes": {},
                    "error": "Failed to download image"
                }

            visual_attributes = {}
            detailed_predictions = {}

            category, subcategory, product_type, confidence = self.detect_category_and_subcategory(image)

            if confidence < 0.10:
                logger.warning(f"Low confidence in detection ({confidence:.3f}). Returning basic attributes only.")
                colors = self.extract_dominant_colors(image, n_colors=3)
                if colors:
                    visual_attributes["primary_color"] = colors[0]["name"]
                    visual_attributes["color_palette"] = [c["name"] for c in colors]

                return {
                    "visual_attributes": visual_attributes,
                    "detection_confidence": confidence,
                    "warning": "Low confidence detection",
                    "processing_time": round(time.time() - start_time, 2)
                }

            visual_attributes["product_type"] = product_type
            visual_attributes["category"] = category
            if subcategory != "none" and subcategory != "unknown":
                visual_attributes["subcategory"] = subcategory

            colors = self.extract_dominant_colors(image, n_colors=3)
            if colors:
                visual_attributes["primary_color"] = colors[0]["name"]
                visual_attributes["color_palette"] = [c["name"] for c in colors[:3]]
                visual_attributes["color_distribution"] = [
                    {"color": c["name"], "percentage": c["percentage"]}
                    for c in colors
                ]

            attributes_config = None

            if category == "clothing":
                if subcategory in self.CATEGORY_ATTRIBUTES["clothing"]["subcategories"]:
                    attributes_config = self.CATEGORY_ATTRIBUTES["clothing"]["subcategories"][subcategory]["attributes"]
                    logger.info(f"Using attributes for subcategory: {subcategory}")
                else:
                    logger.warning(f"Unknown subcategory: {subcategory}. Skipping attribute extraction.")

            elif category in self.CATEGORY_ATTRIBUTES:
                if "attributes" in self.CATEGORY_ATTRIBUTES[category]:
                    attributes_config = self.CATEGORY_ATTRIBUTES[category]["attributes"]
                    logger.info(f"Using attributes for category: {category}")

            if attributes_config:
                for attr_name, attr_values in attributes_config.items():
                    result = self.classify_with_clip(
                        image,
                        attr_values,
                        attr_name,
                        confidence_threshold=0.20
                    )

                    if result["predictions"]:
                        best_prediction = result["predictions"][0]
                        if best_prediction["confidence"] > 0.20:
                            visual_attributes[attr_name] = best_prediction["value"]

                    detailed_predictions[attr_name] = result

            processing_time = time.time() - start_time

            logger.info(f"✓ Processing complete in {processing_time:.2f}s. Extracted {len(visual_attributes)} attributes.")

            return {
                "visual_attributes": visual_attributes,
                "detailed_predictions": detailed_predictions,
                "detection_confidence": confidence,
                "processing_time": round(processing_time, 2),
                "cache_status": "enabled" if ENABLE_CLIP_MODEL_CACHE else "disabled"
            }

        except Exception as e:
            logger.error(f"Error processing image: {str(e)}")
            return {
                "visual_attributes": {},
                "error": str(e),
                "processing_time": round(time.time() - start_time, 2)
            }
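

# Minimal usage sketch. This is illustrative only: it assumes the module is run as part
# of its package (e.g. `python -m <your_package>.visual_processing_service`, so the
# relative cache_config import resolves) and that the URL below is replaced with a real,
# reachable product image.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    service = VisualProcessingService()
    result = service.process_image("https://example.com/path/to/product.jpg")  # hypothetical URL

    print("Detected attributes:", result.get("visual_attributes"))
    print("Detection confidence:", result.get("detection_confidence"))
    print("Processing time (s):", result.get("processing_time"))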