Harshit Pathak 3 months ago
parent
commit
d65979a6c6
2 changed files with 434 additions and 69 deletions
  1. attr_extraction/visual_processing_service.py (+434 -69)
  2. db.sqlite3 (BIN)

+434 -69
attr_extraction/visual_processing_service.py

@@ -904,7 +904,303 @@
 
 
 
-# ==================== visual_processing_service.py (FIXED) ====================
+# # ==================== visual_processing_service.py (FIXED) ====================
+# import torch
+# import cv2
+# import numpy as np
+# import requests
+# from io import BytesIO
+# from PIL import Image
+# from typing import Dict, List, Optional, Tuple
+# import logging
+# from transformers import CLIPProcessor, CLIPModel
+# from sklearn.cluster import KMeans
+
+# logger = logging.getLogger(__name__)
+
+
+# class VisualProcessingService:
+#     """Service for extracting visual attributes from product images using CLIP."""
+    
+#     # Class-level caching (shared across instances)
+#     _clip_model = None
+#     _clip_processor = None
+#     _device = None
+    
+#     def __init__(self):
+#         pass
+    
+#     @classmethod
+#     def _get_device(cls):
+#         """Get optimal device."""
+#         if cls._device is None:
+#             cls._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+#             logger.info(f"Visual Processing using device: {cls._device}")
+#         return cls._device
+    
+#     @classmethod
+#     def _get_clip_model(cls):
+#         """Lazy load CLIP model with class-level caching."""
+#         if cls._clip_model is None:
+#             logger.info("Loading CLIP model (this may take a few minutes on first use)...")
+#             cls._clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+#             cls._clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+            
+#             device = cls._get_device()
+#             cls._clip_model.to(device)
+#             cls._clip_model.eval()
+            
+#             logger.info("✓ CLIP model loaded successfully")
+#         return cls._clip_model, cls._clip_processor
+    
+#     def download_image(self, image_url: str) -> Optional[Image.Image]:
+#         """Download image from URL."""
+#         try:
+#             response = requests.get(image_url, timeout=10)
+#             response.raise_for_status()
+#             image = Image.open(BytesIO(response.content)).convert('RGB')
+#             return image
+#         except Exception as e:
+#             logger.error(f"Error downloading image from {image_url}: {str(e)}")
+#             return None
+    
+#     def extract_dominant_colors(self, image: Image.Image, n_colors: int = 3) -> List[Dict]:
+#         """Extract dominant colors using K-means (FIXED webcolors issue)."""
+#         try:
+#             # Resize for faster processing
+#             img_small = image.resize((150, 150))
+#             img_array = np.array(img_small)
+#             pixels = img_array.reshape(-1, 3)
+            
+#             # K-means clustering
+#             kmeans = KMeans(n_clusters=n_colors, random_state=42, n_init=5)
+#             kmeans.fit(pixels)
+            
+#             colors = []
+#             labels_counts = np.bincount(kmeans.labels_)
+            
+#             for i, center in enumerate(kmeans.cluster_centers_):
+#                 rgb = tuple(center.astype(int))
+#                 color_name = self._get_color_name_simple(rgb)
+#                 percentage = float(labels_counts[i] / len(kmeans.labels_) * 100)
+                
+#                 colors.append({
+#                     "name": color_name,
+#                     "rgb": rgb,
+#                     "percentage": percentage
+#                 })
+            
+#             colors.sort(key=lambda x: x['percentage'], reverse=True)
+#             return colors
+            
+#         except Exception as e:
+#             logger.error(f"Error extracting colors: {str(e)}")
+#             return []
+    
+#     def _get_color_name_simple(self, rgb: Tuple[int, int, int]) -> str:
+#         """
+#         Simple color name detection without webcolors dependency.
+#         Maps RGB to basic color names.
+#         """
+#         r, g, b = rgb
+        
+#         # Define basic color ranges
+#         colors = {
+#             'black': (r < 50 and g < 50 and b < 50),
+#             'white': (r > 200 and g > 200 and b > 200),
+#             'gray': (abs(r - g) < 30 and abs(g - b) < 30 and abs(r - b) < 30 and 50 <= r <= 200),
+#             'red': (r > 150 and g < 100 and b < 100),
+#             'green': (g > 150 and r < 100 and b < 100),
+#             'blue': (b > 150 and r < 100 and g < 100),
+#             'yellow': (r > 200 and g > 200 and b < 100),
+#             'orange': (r > 200 and 100 < g < 200 and b < 100),
+#             'purple': (r > 100 and b > 100 and g < 100),
+#             'pink': (r > 200 and 100 < g < 200 and 100 < b < 200),
+#             'brown': (50 < r < 150 and 30 < g < 100 and b < 80),
+#             'cyan': (r < 100 and g > 150 and b > 150),
+#         }
+        
+#         for color_name, condition in colors.items():
+#             if condition:
+#                 return color_name
+        
+#         # Default fallback
+#         if r > g and r > b:
+#             return 'red'
+#         elif g > r and g > b:
+#             return 'green'
+#         elif b > r and b > g:
+#             return 'blue'
+#         else:
+#             return 'gray'
+    
+#     def classify_with_clip(
+#         self,
+#         image: Image.Image,
+#         candidates: List[str],
+#         attribute_name: str
+#     ) -> Dict:
+#         """Use CLIP to classify image against candidate labels."""
+#         try:
+#             model, processor = self._get_clip_model()
+#             device = self._get_device()
+            
+#             # Prepare inputs
+#             inputs = processor(
+#                 text=candidates,
+#                 images=image,
+#                 return_tensors="pt",
+#                 padding=True
+#             )
+            
+#             # Move to device
+#             inputs = {k: v.to(device) for k, v in inputs.items()}
+            
+#             # Get predictions
+#             with torch.no_grad():
+#                 outputs = model(**inputs)
+#                 logits_per_image = outputs.logits_per_image
+#                 probs = logits_per_image.softmax(dim=1).cpu()
+            
+#             # Get top predictions
+#             top_k = min(3, len(candidates))
+#             top_probs, top_indices = torch.topk(probs[0], k=top_k)
+            
+#             results = []
+#             for prob, idx in zip(top_probs, top_indices):
+#                 if prob.item() > 0.15:  # Confidence threshold
+#                     results.append({
+#                         "value": candidates[idx.item()],
+#                         "confidence": float(prob.item())
+#                     })
+            
+#             return {
+#                 "attribute": attribute_name,
+#                 "predictions": results
+#             }
+            
+#         except Exception as e:
+#             logger.error(f"Error in CLIP classification for {attribute_name}: {str(e)}")
+#             return {"attribute": attribute_name, "predictions": []}
+    
+#     def process_image(
+#         self,
+#         image_url: str,
+#         product_type_hint: Optional[str] = None
+#     ) -> Dict:
+#         """
+#         Main method to process image and extract visual attributes.
+#         """
+#         import time
+#         start_time = time.time()
+        
+#         try:
+#             # Download image
+#             image = self.download_image(image_url)
+#             if image is None:
+#                 return {
+#                     "visual_attributes": {},
+#                     "error": "Failed to download image"
+#                 }
+            
+#             visual_attributes = {}
+#             detailed_predictions = {}
+            
+#             # 1. Product Type Detection
+#             product_types = [
+#                 "t-shirt", "shirt", "dress", "pants", "jeans", "shorts",
+#                 "skirt", "jacket", "coat", "sweater", "hoodie", "top"
+#             ]
+#             product_type_result = self.classify_with_clip(image, product_types, "product_type")
+#             if product_type_result["predictions"]:
+#                 visual_attributes["product_type"] = product_type_result["predictions"][0]["value"]
+#                 detected_product_type = visual_attributes["product_type"]
+#             else:
+#                 detected_product_type = product_type_hint or "unknown"
+#             detailed_predictions["product_type"] = product_type_result
+            
+#             # 2. Color Detection
+#             colors = self.extract_dominant_colors(image, n_colors=3)
+#             if colors:
+#                 visual_attributes["primary_color"] = colors[0]["name"]
+#                 visual_attributes["color_palette"] = [c["name"] for c in colors]
+            
+#             # 3. Pattern Detection
+#             patterns = ["solid color", "striped", "checkered", "graphic print", "floral", "geometric"]
+#             pattern_result = self.classify_with_clip(image, patterns, "pattern")
+#             if pattern_result["predictions"]:
+#                 visual_attributes["pattern"] = pattern_result["predictions"][0]["value"]
+#             detailed_predictions["pattern"] = pattern_result
+            
+#             # 4. Material Detection
+#             materials = ["cotton", "polyester", "denim", "leather", "silk", "wool", "linen"]
+#             material_result = self.classify_with_clip(image, materials, "material")
+#             if material_result["predictions"]:
+#                 visual_attributes["material"] = material_result["predictions"][0]["value"]
+#             detailed_predictions["material"] = material_result
+            
+#             # 5. Style Detection
+#             styles = ["casual", "formal", "sporty", "streetwear", "elegant", "vintage"]
+#             style_result = self.classify_with_clip(image, styles, "style")
+#             if style_result["predictions"]:
+#                 visual_attributes["style"] = style_result["predictions"][0]["value"]
+#             detailed_predictions["style"] = style_result
+            
+#             # 6. Fit Detection
+#             fits = ["slim fit", "regular fit", "loose fit", "oversized"]
+#             fit_result = self.classify_with_clip(image, fits, "fit")
+#             if fit_result["predictions"]:
+#                 visual_attributes["fit"] = fit_result["predictions"][0]["value"]
+#             detailed_predictions["fit"] = fit_result
+            
+#             # 7. Neckline (for tops)
+#             if detected_product_type.lower() in ['shirt', 't-shirt', 'top', 'dress']:
+#                 necklines = ["crew neck", "v-neck", "round neck", "collar"]
+#                 neckline_result = self.classify_with_clip(image, necklines, "neckline")
+#                 if neckline_result["predictions"]:
+#                     visual_attributes["neckline"] = neckline_result["predictions"][0]["value"]
+#                 detailed_predictions["neckline"] = neckline_result
+            
+#             # 8. Sleeve Type (for tops)
+#             if detected_product_type.lower() in ['shirt', 't-shirt', 'top']:
+#                 sleeves = ["short sleeve", "long sleeve", "sleeveless"]
+#                 sleeve_result = self.classify_with_clip(image, sleeves, "sleeve_type")
+#                 if sleeve_result["predictions"]:
+#                     visual_attributes["sleeve_type"] = sleeve_result["predictions"][0]["value"]
+#                 detailed_predictions["sleeve_type"] = sleeve_result
+            
+#             # 9. Closure Type
+#             closures = ["button", "zipper", "pull-on"]
+#             closure_result = self.classify_with_clip(image, closures, "closure_type")
+#             if closure_result["predictions"]:
+#                 visual_attributes["closure_type"] = closure_result["predictions"][0]["value"]
+#             detailed_predictions["closure_type"] = closure_result
+            
+#             processing_time = time.time() - start_time
+            
+#             return {
+#                 "visual_attributes": visual_attributes,
+#                 "detailed_predictions": detailed_predictions,
+#                 "processing_time": round(processing_time, 2)
+#             }
+            
+#         except Exception as e:
+#             logger.error(f"Error processing image: {str(e)}")
+#             return {
+#                 "visual_attributes": {},
+#                 "error": str(e),
+#                 "processing_time": round(time.time() - start_time, 2)
+#             }
+
+
+
+
+
+
+
+
+
+# ==================== visual_processing_service.py (FIXED - Dynamic Detection) ====================
 import torch
 import cv2
 import numpy as np
@@ -927,6 +1223,71 @@ class VisualProcessingService:
     _clip_processor = None
     _device = None
     
+    # Define category-specific attributes
+    CATEGORY_ATTRIBUTES = {
+        "clothing": {
+            "products": ["t-shirt", "shirt", "dress", "pants", "jeans", "shorts", 
+                        "skirt", "jacket", "coat", "sweater", "hoodie", "top", "blouse"],
+            "attributes": {
+                "pattern": ["solid color", "striped", "checkered", "graphic print", "floral", "geometric", "plain"],
+                "material": ["cotton", "polyester", "denim", "leather", "silk", "wool", "linen", "blend"],
+                "style": ["casual", "formal", "sporty", "streetwear", "elegant", "vintage", "bohemian"],
+                "fit": ["slim fit", "regular fit", "loose fit", "oversized", "tailored"],
+                "neckline": ["crew neck", "v-neck", "round neck", "collar", "scoop neck"],
+                "sleeve_type": ["short sleeve", "long sleeve", "sleeveless", "3/4 sleeve"],
+                "closure_type": ["button", "zipper", "pull-on", "snap", "tie"]
+            }
+        },
+        "tools": {
+            "products": ["screwdriver", "hammer", "wrench", "pliers", "drill", "saw", 
+                        "measuring tape", "level", "chisel", "file"],
+            "attributes": {
+                "material": ["steel", "aluminum", "plastic", "wood", "rubber", "chrome"],
+                "type": ["manual", "electric", "pneumatic", "cordless", "corded"],
+                "finish": ["chrome plated", "powder coated", "stainless steel", "painted"],
+                "handle_type": ["rubber grip", "plastic", "wooden", "cushioned", "ergonomic"]
+            }
+        },
+        "electronics": {
+            "products": ["phone", "laptop", "tablet", "headphones", "speaker", "camera", 
+                        "smartwatch", "charger", "mouse", "keyboard"],
+            "attributes": {
+                "material": ["plastic", "metal", "glass", "aluminum", "rubber"],
+                "style": ["modern", "minimalist", "sleek", "industrial", "vintage"],
+                "finish": ["matte", "glossy", "metallic", "textured"],
+                "connectivity": ["wireless", "wired", "bluetooth", "USB"]
+            }
+        },
+        "furniture": {
+            "products": ["chair", "table", "sofa", "bed", "desk", "shelf", "cabinet", 
+                        "dresser", "bench", "stool"],
+            "attributes": {
+                "material": ["wood", "metal", "glass", "plastic", "fabric", "leather"],
+                "style": ["modern", "traditional", "industrial", "rustic", "contemporary", "vintage"],
+                "finish": ["natural wood", "painted", "stained", "laminated", "upholstered"]
+            }
+        },
+        "home_decor": {
+            "products": ["painting", "canvas", "wall art", "frame", "vase", "lamp", 
+                        "mirror", "clock", "sculpture", "poster"],
+            "attributes": {
+                "style": ["modern", "abstract", "traditional", "contemporary", "vintage", "minimalist"],
+                "material": ["canvas", "wood", "metal", "glass", "ceramic", "paper"],
+                "finish": ["glossy", "matte", "textured", "framed", "gallery wrapped"],
+                "theme": ["nature", "geometric", "floral", "landscape", "portrait", "abstract"]
+            }
+        },
+        "kitchen": {
+            "products": ["pot", "pan", "knife", "utensil", "plate", "bowl", "cup", 
+                        "appliance", "cutting board", "container"],
+            "attributes": {
+                "material": ["stainless steel", "aluminum", "ceramic", "glass", "plastic", "wood"],
+                "finish": ["non-stick", "stainless", "enameled", "anodized"],
+                "type": ["manual", "electric", "dishwasher safe"]
+            }
+        }
+    }
+    
     def __init__(self):
         pass
     
@@ -965,7 +1326,7 @@ class VisualProcessingService:
             return None
     
     def extract_dominant_colors(self, image: Image.Image, n_colors: int = 3) -> List[Dict]:
-        """Extract dominant colors using K-means (FIXED webcolors issue)."""
+        """Extract dominant colors using K-means."""
         try:
             # Resize for faster processing
             img_small = image.resize((150, 150))
@@ -1038,7 +1399,8 @@ class VisualProcessingService:
         self,
         image: Image.Image,
         candidates: List[str],
-        attribute_name: str
+        attribute_name: str,
+        confidence_threshold: float = 0.15
     ) -> Dict:
         """Use CLIP to classify image against candidate labels."""
         try:
@@ -1068,7 +1430,7 @@ class VisualProcessingService:
             
             results = []
             for prob, idx in zip(top_probs, top_indices):
-                if prob.item() > 0.15:  # Confidence threshold
+                if prob.item() > confidence_threshold:
                     results.append({
                         "value": candidates[idx.item()],
                         "confidence": float(prob.item())
@@ -1083,6 +1445,34 @@ class VisualProcessingService:
             logger.error(f"Error in CLIP classification for {attribute_name}: {str(e)}")
             return {"attribute": attribute_name, "predictions": []}
     
+    def detect_product_category(self, image: Image.Image) -> Tuple[str, str, float]:
+        """
+        First detect which category the product belongs to.
+        Returns: (category_name, product_type, confidence)
+        """
+        # Build one CLIP prompt per known product type, mapping each prompt back to its category
+        product_prompts = []
+        category_map = {}
+        
+        for category, data in self.CATEGORY_ATTRIBUTES.items():
+            for product in data["products"]:
+                product_prompts.append(f"a photo of a {product}")
+                category_map[f"a photo of a {product}"] = category
+        
+        # Classify the image against every product prompt in one pass
+        result = self.classify_with_clip(image, product_prompts, "category_detection", confidence_threshold=0.10)
+        
+        if result["predictions"]:
+            best_match = result["predictions"][0]
+            detected_category = category_map[best_match["value"]]
+            product_type = best_match["value"].replace("a photo of a ", "")
+            confidence = best_match["confidence"]
+            
+            logger.info(f"Detected category: {detected_category}, product: {product_type}, confidence: {confidence:.3f}")
+            return detected_category, product_type, confidence
+        
+        return "unknown", "unknown", 0.0
+    
     def process_image(
         self,
         image_url: str,
@@ -1090,6 +1480,7 @@ class VisualProcessingService:
     ) -> Dict:
         """
         Main method to process image and extract visual attributes.
+        Now dynamically detects product category first.
         """
         import time
         start_time = time.time()
@@ -1106,81 +1497,54 @@ class VisualProcessingService:
             visual_attributes = {}
             detailed_predictions = {}
             
-            # 1. Product Type Detection
-            product_types = [
-                "t-shirt", "shirt", "dress", "pants", "jeans", "shorts",
-                "skirt", "jacket", "coat", "sweater", "hoodie", "top"
-            ]
-            product_type_result = self.classify_with_clip(image, product_types, "product_type")
-            if product_type_result["predictions"]:
-                visual_attributes["product_type"] = product_type_result["predictions"][0]["value"]
-                detected_product_type = visual_attributes["product_type"]
-            else:
-                detected_product_type = product_type_hint or "unknown"
-            detailed_predictions["product_type"] = product_type_result
-            
-            # 2. Color Detection
+            # Step 1: Detect product category
+            detected_category, detected_product_type, category_confidence = self.detect_product_category(image)
+            
+            # If confidence is too low, return minimal info
+            if category_confidence < 0.10:
+                logger.warning(f"Low confidence in category detection ({category_confidence:.3f}). Returning basic attributes only.")
+                colors = self.extract_dominant_colors(image, n_colors=3)
+                if colors:
+                    visual_attributes["primary_color"] = colors[0]["name"]
+                    visual_attributes["color_palette"] = [c["name"] for c in colors]
+                
+                return {
+                    "visual_attributes": visual_attributes,
+                    "category_confidence": category_confidence,
+                    "processing_time": round(time.time() - start_time, 2)
+                }
+            
+            # Add detected product type
+            visual_attributes["product_type"] = detected_product_type
+            visual_attributes["category"] = detected_category
+            
+            # Step 2: Extract color (universal attribute)
             colors = self.extract_dominant_colors(image, n_colors=3)
             if colors:
                 visual_attributes["primary_color"] = colors[0]["name"]
                 visual_attributes["color_palette"] = [c["name"] for c in colors]
             
-            # 3. Pattern Detection
-            patterns = ["solid color", "striped", "checkered", "graphic print", "floral", "geometric"]
-            pattern_result = self.classify_with_clip(image, patterns, "pattern")
-            if pattern_result["predictions"]:
-                visual_attributes["pattern"] = pattern_result["predictions"][0]["value"]
-            detailed_predictions["pattern"] = pattern_result
-            
-            # 4. Material Detection
-            materials = ["cotton", "polyester", "denim", "leather", "silk", "wool", "linen"]
-            material_result = self.classify_with_clip(image, materials, "material")
-            if material_result["predictions"]:
-                visual_attributes["material"] = material_result["predictions"][0]["value"]
-            detailed_predictions["material"] = material_result
-            
-            # 5. Style Detection
-            styles = ["casual", "formal", "sporty", "streetwear", "elegant", "vintage"]
-            style_result = self.classify_with_clip(image, styles, "style")
-            if style_result["predictions"]:
-                visual_attributes["style"] = style_result["predictions"][0]["value"]
-            detailed_predictions["style"] = style_result
-            
-            # 6. Fit Detection
-            fits = ["slim fit", "regular fit", "loose fit", "oversized"]
-            fit_result = self.classify_with_clip(image, fits, "fit")
-            if fit_result["predictions"]:
-                visual_attributes["fit"] = fit_result["predictions"][0]["value"]
-            detailed_predictions["fit"] = fit_result
-            
-            # 7. Neckline (for tops)
-            if detected_product_type.lower() in ['shirt', 't-shirt', 'top', 'dress']:
-                necklines = ["crew neck", "v-neck", "round neck", "collar"]
-                neckline_result = self.classify_with_clip(image, necklines, "neckline")
-                if neckline_result["predictions"]:
-                    visual_attributes["neckline"] = neckline_result["predictions"][0]["value"]
-                detailed_predictions["neckline"] = neckline_result
-            
-            # 8. Sleeve Type (for tops)
-            if detected_product_type.lower() in ['shirt', 't-shirt', 'top']:
-                sleeves = ["short sleeve", "long sleeve", "sleeveless"]
-                sleeve_result = self.classify_with_clip(image, sleeves, "sleeve_type")
-                if sleeve_result["predictions"]:
-                    visual_attributes["sleeve_type"] = sleeve_result["predictions"][0]["value"]
-                detailed_predictions["sleeve_type"] = sleeve_result
-            
-            # 9. Closure Type
-            closures = ["button", "zipper", "pull-on"]
-            closure_result = self.classify_with_clip(image, closures, "closure_type")
-            if closure_result["predictions"]:
-                visual_attributes["closure_type"] = closure_result["predictions"][0]["value"]
-            detailed_predictions["closure_type"] = closure_result
+            # Step 3: Extract category-specific attributes
+            if detected_category in self.CATEGORY_ATTRIBUTES:
+                category_config = self.CATEGORY_ATTRIBUTES[detected_category]
+                
+                for attr_name, attr_values in category_config["attributes"].items():
+                    # Use a higher confidence threshold for category-specific attributes
+                    result = self.classify_with_clip(image, attr_values, attr_name, confidence_threshold=0.20)
+                    
+                    if result["predictions"]:
+                        # classify_with_clip has already dropped predictions at or
+                        # below the threshold, so the top one is safe to use directly
+                        best_prediction = result["predictions"][0]
+                        visual_attributes[attr_name] = best_prediction["value"]
+                        detailed_predictions[attr_name] = result
             
             processing_time = time.time() - start_time
             
             return {
                 "visual_attributes": visual_attributes,
                 "detailed_predictions": detailed_predictions,
+                "category_confidence": category_confidence,
                 "processing_time": round(processing_time, 2)
             }
             
@@ -1190,4 +1554,5 @@ class VisualProcessingService:
                 "visual_attributes": {},
                 "error": str(e),
                 "processing_time": round(time.time() - start_time, 2)
-            }
+            }
+

BIN
db.sqlite3
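
Below is a minimal usage sketch of the new dynamic-detection flow. It is not part of the commit; the image URL is a placeholder, and the import path assumes the repository layout shown above.

    from attr_extraction.visual_processing_service import VisualProcessingService

    service = VisualProcessingService()

    # Placeholder URL; any reachable product image works
    result = service.process_image("https://example.com/images/product_123.jpg")

    if "error" in result:
        print(f"Processing failed: {result['error']}")
    else:
        # With a confident detection, the result also carries "category",
        # "product_type", and one entry per category-specific attribute
        print(f"category confidence: {result.get('category_confidence', 0.0):.2f}")
        for name, value in result["visual_attributes"].items():
            print(f"{name}: {value}")

The first call loads the CLIP weights into the class-level cache, so later instances in the same process reuse the model. And because the category and attribute vocabularies are plain data in CATEGORY_ATTRIBUTES, supporting a new product family is a data change rather than a code change; for example, with a hypothetical "footwear" category and illustrative values:

    # Hypothetical extension; detect_product_category picks up the new
    # prompts automatically on the next call
    VisualProcessingService.CATEGORY_ATTRIBUTES["footwear"] = {
        "products": ["sneaker", "boot", "sandal", "loafer"],
        "attributes": {
            "material": ["leather", "canvas", "suede", "rubber"],
            "style": ["casual", "formal", "sporty"],
        },
    }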