Browse Source

added visual processing

Harshit Pathak 3 months ago
parent
commit
5d0bf8a733

+ 1 - 0
attr_extraction/ocr_service.py

@@ -111,6 +111,7 @@ class OCRService:
             # Download image
             image = self.download_image(image_url)
             if image is None:
+                print("Unable to download the image...")
                 return {
                     "detected_text": [],
                     "extracted_attributes": {},

+ 69 - 58
attr_extraction/serializers.py

@@ -3,8 +3,20 @@
 
 
 
+
+
+
+
+
+
+
+
+
+
 # # ==================== serializers.py ====================
 # from rest_framework import serializers
+# from .models import Product, ProductType, ProductAttribute, AttributePossibleValue
+
 
 # class ProductInputSerializer(serializers.Serializer):
 #     """Serializer for individual product input."""
@@ -14,10 +26,12 @@
 #     long_desc = serializers.CharField(required=False, allow_blank=True, allow_null=True)
 #     image_url = serializers.URLField(required=False, allow_blank=True, allow_null=True)
 
+
 # class MandatoryAttrsField(serializers.DictField):
 #     """Custom DictField to validate mandatory_attrs structure."""
 #     child = serializers.ListField(child=serializers.CharField())
 
+
 # class ProductBatchInputSerializer(serializers.Serializer):
 #     """Serializer for an individual product input within the batch request."""
 #     item_id = serializers.CharField(required=True)
@@ -25,15 +39,10 @@
 #         required=True,
 #         help_text="A dictionary of attribute names and their possible values."
 #     )
-#     # You can also allow per-product model/flags if needed, but keeping it batch-level for simplicity here.
 
 
 # class SingleProductRequestSerializer(serializers.Serializer):
 #     """Serializer for single product extraction request."""
-#     # title = serializers.CharField(required=False, allow_blank=True, allow_null=True)
-#     # short_desc = serializers.CharField(required=False, allow_blank=True, allow_null=True)
-#     # long_desc = serializers.CharField(required=False, allow_blank=True, allow_null=True)
-#     # image_url = serializers.URLField(required=False, allow_blank=True, allow_null=True)
 #     item_id = serializers.CharField(required=True)
 #     mandatory_attrs = serializers.DictField(
 #         child=serializers.ListField(child=serializers.CharField()),
@@ -42,6 +51,17 @@
 #     model = serializers.CharField(required=False, default="llama-3.1-8b-instant")
 #     extract_additional = serializers.BooleanField(required=False, default=True)
 #     process_image = serializers.BooleanField(required=False, default=True)
+#     multiple = serializers.ListField(
+#         child=serializers.CharField(),
+#         required=False,
+#         default=list,
+#         help_text="List of attribute names that can have multiple values"
+#     )
+#     threshold_abs = serializers.FloatField(default=0.65, required=False)
+#     margin = serializers.FloatField(default=0.15, required=False)
+#     use_dynamic_thresholds = serializers.BooleanField(default=True, required=False)
+#     use_adaptive_margin = serializers.BooleanField(default=True, required=False)
+#     use_semantic_clustering = serializers.BooleanField(default=True, required=False)
 
 #     def validate_model(self, value):
 #         from django.conf import settings
@@ -52,19 +72,28 @@
 #         return value
 
 
-
 # class BatchProductRequestSerializer(serializers.Serializer):
 #     """Serializer for batch product extraction request (with item-specific attributes)."""
 #     products = serializers.ListField(
-#         child=ProductBatchInputSerializer(), # <--- Changed
+#         child=ProductBatchInputSerializer(),
 #         required=True,
 #         min_length=1
 #     )
 #     model = serializers.CharField(required=False, default="llama-3.1-8b-instant")
 #     extract_additional = serializers.BooleanField(required=False, default=True)
 #     process_image = serializers.BooleanField(required=False, default=True)
+#     multiple = serializers.ListField(
+#         child=serializers.CharField(),
+#         required=False,
+#         default=list,
+#         help_text="List of attribute names that can have multiple values"
+#     )
+#     threshold_abs = serializers.FloatField(default=0.65, required=False)
+#     margin = serializers.FloatField(default=0.15, required=False)
+#     use_dynamic_thresholds = serializers.BooleanField(default=True, required=False)
+#     use_adaptive_margin = serializers.BooleanField(default=True, required=False)
+#     use_semantic_clustering = serializers.BooleanField(default=True, required=False)
     
-#     # ... validate_model method ...
 #     def validate_model(self, value):
 #         from django.conf import settings
 #         if value not in settings.SUPPORTED_MODELS:
@@ -73,7 +102,6 @@
 #             )
 #         return value
     
-#     # ... validate_products method (updated to use products instead of item_ids) ...
 #     def validate_products(self, value):
 #         from django.conf import settings
 #         max_size = getattr(settings, 'MAX_BATCH_SIZE', 100)
@@ -83,6 +111,7 @@
 #             )
 #         return value
 
+
 # class OCRResultSerializer(serializers.Serializer):
 #     """Serializer for OCR results."""
 #     detected_text = serializers.ListField(child=serializers.DictField())
@@ -107,12 +136,8 @@
 #     failed = serializers.IntegerField()
 
 
-
-
-# from rest_framework import serializers
-# from .models import Product
-
 # class ProductSerializer(serializers.ModelSerializer):
+#     """Serializer for Product model with product type details."""
 #     product_type_details = serializers.SerializerMethodField()
     
 #     class Meta:
@@ -126,11 +151,11 @@
 #             'product_type',
 #             'image_path',
 #             'image',
-#             'product_type_details',  # new field
+#             'product_type_details',
 #         ]
 
 #     def get_product_type_details(self, obj):
-#         # Fetch ProductType object for this product
+#         """Fetch ProductType object and its attributes for this product."""
 #         try:
 #             product_type = ProductType.objects.get(name=obj.product_type)
 #         except ProductType.DoesNotExist:
@@ -148,23 +173,24 @@
 #         ]
 
 
-
-# from rest_framework import serializers
-# from .models import Product, ProductType, ProductAttribute, AttributePossibleValue
-
 # class AttributePossibleValueSerializer(serializers.ModelSerializer):
+#     """Serializer for AttributePossibleValue model."""
 #     class Meta:
 #         model = AttributePossibleValue
 #         fields = ['value']
 
+
 # class ProductAttributeSerializer(serializers.ModelSerializer):
+#     """Serializer for ProductAttribute model with possible values."""
 #     possible_values = AttributePossibleValueSerializer(many=True, read_only=True)
     
 #     class Meta:
 #         model = ProductAttribute
 #         fields = ['name', 'is_mandatory', 'possible_values']
 
+
 # class ProductTypeSerializer(serializers.ModelSerializer):
+#     """Serializer for ProductType model with attributes."""
 #     attributes = ProductAttributeSerializer(many=True, read_only=True)
     
 #     class Meta:
@@ -177,38 +203,11 @@
 
 
 
+        
 
 
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-# ==================== serializers.py ====================
+# ==================== Updated serializers.py ====================
 from rest_framework import serializers
 from .models import Product, ProductType, ProductAttribute, AttributePossibleValue
 
@@ -313,12 +312,32 @@ class OCRResultSerializer(serializers.Serializer):
     extracted_attributes = serializers.DictField()
 
 
+class VisualAttributeDetailSerializer(serializers.Serializer):
+    """Serializer for detailed visual attribute predictions."""
+    attribute = serializers.CharField()
+    predictions = serializers.ListField(child=serializers.DictField())
+
+
+class VisualResultSerializer(serializers.Serializer):
+    """Serializer for visual processing results."""
+    visual_attributes = serializers.DictField(
+        help_text="Extracted visual attributes like color, pattern, style, etc."
+    )
+    detailed_predictions = serializers.DictField(
+        child=VisualAttributeDetailSerializer(),
+        required=False,
+        help_text="Detailed predictions with confidence scores for each attribute"
+    )
+    error = serializers.CharField(required=False)
+
+
 class ProductAttributeResultSerializer(serializers.Serializer):
     """Serializer for individual product extraction result."""
     product_id = serializers.CharField(required=False)
     mandatory = serializers.DictField()
     additional = serializers.DictField(required=False)
     ocr_results = OCRResultSerializer(required=False)
+    visual_results = VisualResultSerializer(required=False)
     error = serializers.CharField(required=False)
     raw_output = serializers.CharField(required=False)
 
@@ -390,12 +409,4 @@ class ProductTypeSerializer(serializers.ModelSerializer):
     
     class Meta:
         model = ProductType
-        fields = ['name', 'attributes']
-
-
-
-
-
-
-
-        
+        fields = ['name', 'attributes']
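For reference, a payload that the new VisualResultSerializer accepts (assumes a configured Django/DRF environment; the values are illustrative):

sample = {
    "visual_attributes": {"primary_color": "navy", "pattern": "striped"},
    "detailed_predictions": {
        "pattern": {
            "attribute": "pattern",
            "predictions": [{"value": "striped", "confidence": 0.72}],
        }
    },
}
serializer = VisualResultSerializer(data=sample)
assert serializer.is_valid(), serializer.errors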

+ 294 - 4
attr_extraction/views.py

@@ -19,11 +19,30 @@ from .services import ProductAttributeService
 from .ocr_service import OCRService
 
 
+
+# Sample test images (publicly available)
+SAMPLE_IMAGES = {
+    "tshirt": "https://images.unsplash.com/photo-1521572163474-6864f9cf17ab",
+    "dress": "https://images.unsplash.com/photo-1595777457583-95e059d581b8",
+    "jeans": "https://images.unsplash.com/photo-1542272604-787c3835535d"
+}
+
+# ==================== Updated views.py ====================
+# (the imports below duplicate those already at the top of the module; harmless but redundant)
+from rest_framework.views import APIView
+from rest_framework.response import Response
+from rest_framework import status
+from .models import Product
+from .services import ProductAttributeService
+from .ocr_service import OCRService
+from .visual_processing_service import VisualProcessingService
+
+
 class ExtractProductAttributesView(APIView):
     """
     API endpoint to extract product attributes for a single product by item_id.
     Fetches product details from database with source tracking.
     Returns attributes in array format: [{"value": "...", "source": "..."}]
+    Includes OCR and Visual Processing results.
     """
 
     def post(self, request):
@@ -52,8 +71,10 @@ class ExtractProductAttributesView(APIView):
         # Process image for OCR if required
         ocr_results = None
         ocr_text = None
+        visual_results = None
 
         if validated_data.get("process_image", True) and image_url:
+            # OCR Processing
             ocr_service = OCRService()
             ocr_results = ocr_service.process_image(image_url)
 
@@ -67,6 +88,11 @@ class ExtractProductAttributesView(APIView):
                     f"{item['text']} (confidence: {item['confidence']:.2f})"
                     for item in ocr_results["detected_text"]
                 ])
+            
+            # Visual Processing
+            visual_service = VisualProcessingService()
+            product_type_hint = product.product_type if hasattr(product, 'product_type') else None
+            visual_results = visual_service.process_image(image_url, product_type_hint)
 
         # Combine all product text with source tracking
         product_text, source_map = ProductAttributeService.combine_product_text(
@@ -94,6 +120,10 @@ class ExtractProductAttributesView(APIView):
         # Attach OCR results if available
         if ocr_results:
             result["ocr_results"] = ocr_results
+        
+        # Attach Visual Processing results if available
+        if visual_results:
+            result["visual_results"] = visual_results
 
         response_serializer = ProductAttributeResultSerializer(data=result)
         if response_serializer.is_valid():
@@ -107,6 +137,7 @@ class BatchExtractProductAttributesView(APIView):
     API endpoint to extract product attributes for multiple products in batch.
     Uses item-specific mandatory_attrs with source tracking.
     Returns attributes in array format: [{"value": "...", "source": "..."}]
+    Includes OCR and Visual Processing results.
     """
 
     def post(self, request):
@@ -170,20 +201,23 @@ class BatchExtractProductAttributesView(APIView):
 
             product = product_map[item_id]
             
-            try:
+            try: 
                 title = product.product_name
                 short_desc = product.product_short_description
                 long_desc = product.product_long_description
                 image_url = product.image_path
-
+                # image_url = "https://images.unsplash.com/photo-1595777457583-95e059d581b8"
                 ocr_results = None
                 ocr_text = None
+                visual_results = None
 
                 # Image Processing Logic
                 if process_image and image_url:
+                    # OCR Processing
                     ocr_service = OCRService()
                     ocr_results = ocr_service.process_image(image_url)
-
+                    print(f"OCR results for {item_id}: {ocr_results}")
+                    
                     if ocr_results and ocr_results.get("detected_text"):
                         ocr_attrs = ProductAttributeService.extract_attributes_from_ocr(
                             ocr_results, model
@@ -193,6 +227,12 @@ class BatchExtractProductAttributesView(APIView):
                             f"{item['text']} (confidence: {item['confidence']:.2f})"
                             for item in ocr_results["detected_text"]
                         ])
+                    
+                    # Visual Processing
+                    visual_service = VisualProcessingService()
+                    product_type_hint = product.product_type if hasattr(product, 'product_type') else None
+                    visual_results = visual_service.process_image(image_url, product_type_hint)
+                    print(f"Visual results for {item_id}: {visual_results.get('visual_attributes', {})}")
 
                 # Combine product text with source tracking
                 product_text, source_map = ProductAttributeService.combine_product_text(
@@ -213,7 +253,7 @@ class BatchExtractProductAttributesView(APIView):
                     source_map=source_map,
                     model=model,
                     extract_additional=extract_additional,
-                    multiple=multiple,  # Make sure this is passed!
+                    multiple=multiple,
                     threshold_abs=threshold_abs,
                     margin=margin,
                     use_dynamic_thresholds=use_dynamic_thresholds,
@@ -227,8 +267,13 @@ class BatchExtractProductAttributesView(APIView):
                     "additional": extracted.get("additional", {}),
                 }
 
+                # Attach OCR results if available
                 if ocr_results:
                     result["ocr_results"] = ocr_results
+                
+                # Attach Visual Processing results if available
+                if visual_results:
+                    result["visual_results"] = visual_results
 
                 results.append(result)
                 successful += 1
@@ -254,6 +299,251 @@ class BatchExtractProductAttributesView(APIView):
         return Response(batch_result, status=status.HTTP_200_OK)
 
 
+
+
+
+# class ExtractProductAttributesView(APIView):
+#     """
+#     API endpoint to extract product attributes for a single product by item_id.
+#     Fetches product details from database with source tracking.
+#     Returns attributes in array format: [{"value": "...", "source": "..."}]
+#     """
+
+#     def post(self, request):
+#         serializer = SingleProductRequestSerializer(data=request.data)
+#         if not serializer.is_valid():
+#             return Response({"error": serializer.errors}, status=status.HTTP_400_BAD_REQUEST)
+
+#         validated_data = serializer.validated_data
+#         item_id = validated_data.get("item_id")
+
+#         # Fetch product from DB
+#         try:
+#             product = Product.objects.get(item_id=item_id)
+#         except Product.DoesNotExist:
+#             return Response(
+#                 {"error": f"Product with item_id '{item_id}' not found."},
+#                 status=status.HTTP_404_NOT_FOUND
+#             )
+
+#         # Extract product details
+#         title = product.product_name
+#         short_desc = product.product_short_description
+#         long_desc = product.product_long_description
+#         image_url = product.image_path
+
+#         # Process image for OCR if required
+#         ocr_results = None
+#         ocr_text = None
+
+#         if validated_data.get("process_image", True) and image_url:
+#             ocr_service = OCRService()
+#             ocr_results = ocr_service.process_image(image_url)
+
+#             if ocr_results and ocr_results.get("detected_text"):
+#                 ocr_attrs = ProductAttributeService.extract_attributes_from_ocr(
+#                     ocr_results, validated_data.get("model")
+#                 )
+#                 ocr_results["extracted_attributes"] = ocr_attrs
+
+#                 ocr_text = "\n".join([
+#                     f"{item['text']} (confidence: {item['confidence']:.2f})"
+#                     for item in ocr_results["detected_text"]
+#                 ])
+
+#         # Combine all product text with source tracking
+#         product_text, source_map = ProductAttributeService.combine_product_text(
+#             title=title,
+#             short_desc=short_desc,
+#             long_desc=long_desc,
+#             ocr_text=ocr_text
+#         )
+
+#         # Extract attributes with enhanced features and source tracking
+#         result = ProductAttributeService.extract_attributes(
+#             product_text=product_text,
+#             mandatory_attrs=validated_data["mandatory_attrs"],
+#             source_map=source_map,
+#             model=validated_data.get("model"),
+#             extract_additional=validated_data.get("extract_additional", True),
+#             multiple=validated_data.get("multiple", []),
+#             threshold_abs=validated_data.get("threshold_abs", 0.65),
+#             margin=validated_data.get("margin", 0.15),
+#             use_dynamic_thresholds=validated_data.get("use_dynamic_thresholds", True),
+#             use_adaptive_margin=validated_data.get("use_adaptive_margin", True),
+#             use_semantic_clustering=validated_data.get("use_semantic_clustering", True)
+#         )
+
+#         # Attach OCR results if available
+#         if ocr_results:
+#             result["ocr_results"] = ocr_results
+
+#         response_serializer = ProductAttributeResultSerializer(data=result)
+#         if response_serializer.is_valid():
+#             return Response(response_serializer.data, status=status.HTTP_200_OK)
+
+#         return Response(result, status=status.HTTP_200_OK)
+
+
+# class BatchExtractProductAttributesView(APIView):
+#     """
+#     API endpoint to extract product attributes for multiple products in batch.
+#     Uses item-specific mandatory_attrs with source tracking.
+#     Returns attributes in array format: [{"value": "...", "source": "..."}]
+#     """
+
+#     def post(self, request):
+#         serializer = BatchProductRequestSerializer(data=request.data)
+#         if not serializer.is_valid():
+#             return Response({"error": serializer.errors}, status=status.HTTP_400_BAD_REQUEST)
+
+#         validated_data = serializer.validated_data
+        
+#         # DEBUG: Print what we received
+#         print("\n" + "="*80)
+#         print("BATCH REQUEST - RECEIVED DATA")
+#         print("="*80)
+#         print(f"Raw request data keys: {request.data.keys()}")
+#         print(f"Multiple field in request: {request.data.get('multiple')}")
+#         print(f"Validated multiple field: {validated_data.get('multiple')}")
+#         print("="*80 + "\n")
+        
+#         # Get batch-level settings
+#         product_list = validated_data.get("products", [])
+#         model = validated_data.get("model")
+#         extract_additional = validated_data.get("extract_additional", True)
+#         process_image = validated_data.get("process_image", True)
+#         multiple = validated_data.get("multiple", [])
+#         threshold_abs = validated_data.get("threshold_abs", 0.65)
+#         margin = validated_data.get("margin", 0.15)
+#         use_dynamic_thresholds = validated_data.get("use_dynamic_thresholds", True)
+#         use_adaptive_margin = validated_data.get("use_adaptive_margin", True)
+#         use_semantic_clustering = validated_data.get("use_semantic_clustering", True)
+        
+#         # DEBUG: Print extracted settings
+#         print(f"Extracted multiple parameter: {multiple}")
+#         print(f"Type: {type(multiple)}")
+        
+#         # Extract all item_ids to query the database efficiently
+#         item_ids = [p['item_id'] for p in product_list] 
+        
+#         # Fetch all products in one query
+#         products_queryset = Product.objects.filter(item_id__in=item_ids)
+        
+#         # Create a dictionary for easy lookup: item_id -> Product object
+#         product_map = {product.item_id: product for product in products_queryset}
+#         found_ids = set(product_map.keys())
+        
+#         results = []
+#         successful = 0
+#         failed = 0
+
+#         for product_entry in product_list:
+#             item_id = product_entry['item_id']
+#             # Get item-specific mandatory attributes
+#             mandatory_attrs = product_entry['mandatory_attrs'] 
+
+#             if item_id not in found_ids:
+#                 failed += 1
+#                 results.append({
+#                     "product_id": item_id,
+#                     "error": "Product not found in database"
+#                 })
+#                 continue
+
+#             product = product_map[item_id]
+            
+#             try: 
+#                 title = product.product_name
+#                 short_desc = product.product_short_description
+#                 long_desc = product.product_long_description
+#                 # image_url = product.image_path
+#                 image_url = "http://localhost:8000/media/products/levi_test_ocr2.jpg"
+#                 ocr_results = None
+#                 ocr_text = None
+
+#                 # Image Processing Logic
+#                 if process_image and image_url:
+#                     ocr_service = OCRService()
+#                     ocr_results = ocr_service.process_image(image_url)
+#                     print(f"ocr results are: {ocr_results}")
+
+#                     if ocr_results and ocr_results.get("detected_text"):
+#                         ocr_attrs = ProductAttributeService.extract_attributes_from_ocr(
+#                             ocr_results, model
+#                         )
+#                         ocr_results["extracted_attributes"] = ocr_attrs
+#                         ocr_text = "\n".join([
+#                             f"{item['text']} (confidence: {item['confidence']:.2f})"
+#                             for item in ocr_results["detected_text"]
+#                         ])
+
+#                 # Combine product text with source tracking
+#                 product_text, source_map = ProductAttributeService.combine_product_text(
+#                     title=title,
+#                     short_desc=short_desc,
+#                     long_desc=long_desc,
+#                     ocr_text=ocr_text
+#                 )
+
+#                 # DEBUG: Print before extraction
+#                 print(f"\n>>> Extracting for product {item_id}")
+#                 print(f"    Passing multiple: {multiple}")
+
+#                 # Attribute Extraction with source tracking (returns array format)
+#                 extracted = ProductAttributeService.extract_attributes(
+#                     product_text=product_text,
+#                     mandatory_attrs=mandatory_attrs,
+#                     source_map=source_map,
+#                     model=model,
+#                     extract_additional=extract_additional,
+#                     multiple=multiple,  # Make sure this is passed!
+#                     threshold_abs=threshold_abs,
+#                     margin=margin,
+#                     use_dynamic_thresholds=use_dynamic_thresholds,
+#                     use_adaptive_margin=use_adaptive_margin,
+#                     use_semantic_clustering=use_semantic_clustering
+#                 )
+
+#                 result = {
+#                     "product_id": product.item_id,
+#                     "mandatory": extracted.get("mandatory", {}),
+#                     "additional": extracted.get("additional", {}),
+#                 }
+
+#                 if ocr_results:
+#                     result["ocr_results"] = ocr_results
+
+#                 results.append(result)
+#                 successful += 1
+
+#             except Exception as e:
+#                 failed += 1
+#                 results.append({
+#                     "product_id": item_id,
+#                     "error": str(e)
+#                 })
+
+#         batch_result = {
+#             "results": results,
+#             "total_products": len(product_list),
+#             "successful": successful,
+#             "failed": failed
+#         }
+
+#         response_serializer = BatchProductResponseSerializer(data=batch_result)
+#         if response_serializer.is_valid():
+#             return Response(response_serializer.data, status=status.HTTP_200_OK)
+
+#         return Response(batch_result, status=status.HTTP_200_OK)
+
+
+
+
+
+
+
+
 class ProductListView(APIView):
     """
     GET API to list all products with details

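Putting the views together, a request that exercises both OCR and visual processing for one product could look like this (the URL route and item_id are assumptions; adjust them to this project's urls.py):

import requests

payload = {
    "item_id": "SKU123",  # hypothetical item present in the Product table
    "mandatory_attrs": {"color": ["red", "blue", "navy"]},
    "process_image": True,
}
resp = requests.post("http://localhost:8000/api/extract/", json=payload, timeout=120)
body = resp.json()
print(body.get("visual_results", {}).get("visual_attributes", {}))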
+ 1193 - 0
attr_extraction/visual_processing_service.py

@@ -0,0 +1,1193 @@
+# # ==================== visual_processing_service.py ====================
+# import torch
+# import cv2
+# import numpy as np
+# import requests
+# from io import BytesIO
+# from PIL import Image
+# from typing import Dict, List, Optional, Tuple
+# import logging
+# from transformers import CLIPProcessor, CLIPModel, AutoImageProcessor, AutoModelForImageClassification
+# from sklearn.cluster import KMeans
+# import webcolors
+
+# logger = logging.getLogger(__name__)
+
+
+# class VisualProcessingService:
+#     """Service for extracting visual attributes from product images using CLIP and computer vision."""
+    
+#     def __init__(self):
+#         self.clip_model = None
+#         self.clip_processor = None
+#         self.classification_model = None
+#         self.classification_processor = None
+        
+#     def _get_clip_model(self):
+#         """Lazy load CLIP model."""
+#         if self.clip_model is None:
+#             logger.info("Loading CLIP model...")
+#             self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+#             self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+#             self.clip_model.eval()
+#         return self.clip_model, self.clip_processor
+    
+#     def _get_classification_model(self):
+#         """Lazy load image classification model for product categories."""
+#         if self.classification_model is None:
+#             logger.info("Loading classification model...")
+#             # Using Google's ViT model fine-tuned on fashion/products
+#             self.classification_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
+#             self.classification_model = AutoModelForImageClassification.from_pretrained(
+#                 "google/vit-base-patch16-224"
+#             )
+#             self.classification_model.eval()
+#         return self.classification_model, self.classification_processor
+    
+#     def download_image(self, image_url: str) -> Optional[Image.Image]:
+#         """Download image from URL."""
+#         try:
+#             response = requests.get(image_url, timeout=10)
+#             response.raise_for_status()
+#             image = Image.open(BytesIO(response.content)).convert('RGB')
+#             return image
+#         except Exception as e:
+#             logger.error(f"Error downloading image from {image_url}: {str(e)}")
+#             return None
+    
+#     def extract_dominant_colors(self, image: Image.Image, n_colors: int = 3) -> List[Dict]:
+#         """Extract dominant colors from image using K-means clustering."""
+#         try:
+#             # Resize image for faster processing
+#             img_small = image.resize((150, 150))
+#             img_array = np.array(img_small)
+            
+#             # Reshape to pixels
+#             pixels = img_array.reshape(-1, 3)
+            
+#             # Apply K-means
+#             kmeans = KMeans(n_clusters=n_colors, random_state=42, n_init=10)
+#             kmeans.fit(pixels)
+            
+#             colors = []
+#             for center in kmeans.cluster_centers_:
+#                 rgb = tuple(center.astype(int))
+#                 color_name = self._get_color_name(rgb)
+#                 colors.append({
+#                     "name": color_name,
+#                     "rgb": rgb,
+#                     "percentage": float(np.sum(kmeans.labels_ == len(colors)) / len(kmeans.labels_) * 100)
+#                 })
+            
+#             # Sort by percentage
+#             colors.sort(key=lambda x: x['percentage'], reverse=True)
+#             return colors
+            
+#         except Exception as e:
+#             logger.error(f"Error extracting colors: {str(e)}")
+#             return []
+    
+#     def _get_color_name(self, rgb: Tuple[int, int, int]) -> str:
+#         """Convert RGB to closest color name."""
+#         try:
+#             # Try to get exact match
+#             color_name = webcolors.rgb_to_name(rgb)
+#             return color_name
+#         except ValueError:
+#             # Find closest color
+#             min_distance = float('inf')
+#             closest_name = 'unknown'
+            
+#             for name in webcolors.CSS3_NAMES_TO_HEX:
+#                 hex_color = webcolors.CSS3_NAMES_TO_HEX[name]
+#                 r, g, b = webcolors.hex_to_rgb(hex_color)
+#                 distance = sum((c1 - c2) ** 2 for c1, c2 in zip(rgb, (r, g, b)))
+                
+#                 if distance < min_distance:
+#                     min_distance = distance
+#                     closest_name = name
+            
+#             return closest_name
+    
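One portability caveat for the lookup above: webcolors 1.13 and later removed the CSS3_NAMES_TO_HEX mapping. An equivalent nearest-name search against the newer API (an assumption for webcolors >= 1.13) looks like:

import webcolors

def closest_css3_name(rgb):
    """Nearest CSS3 color name by squared RGB distance."""
    best_name, best_dist = "unknown", float("inf")
    for name in webcolors.names("css3"):
        r, g, b = webcolors.name_to_rgb(name)
        dist = (rgb[0] - r) ** 2 + (rgb[1] - g) ** 2 + (rgb[2] - b) ** 2
        if dist < best_dist:
            best_name, best_dist = name, dist
    return best_name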
+#     def classify_with_clip(self, image: Image.Image, candidates: List[str], attribute_name: str) -> Dict:
+#         """Use CLIP to classify image against candidate labels."""
+#         try:
+#             model, processor = self._get_clip_model()
+            
+#             # Prepare inputs
+#             inputs = processor(
+#                 text=candidates,
+#                 images=image,
+#                 return_tensors="pt",
+#                 padding=True
+#             )
+            
+#             # Get predictions
+#             with torch.no_grad():
+#                 outputs = model(**inputs)
+#                 logits_per_image = outputs.logits_per_image
+#                 probs = logits_per_image.softmax(dim=1)
+            
+#             # Get top predictions
+#             top_probs, top_indices = torch.topk(probs[0], k=min(3, len(candidates)))
+            
+#             results = []
+#             for prob, idx in zip(top_probs, top_indices):
+#                 if prob.item() > 0.15:  # Confidence threshold
+#                     results.append({
+#                         "value": candidates[idx.item()],
+#                         "confidence": float(prob.item())
+#                     })
+            
+#             return {
+#                 "attribute": attribute_name,
+#                 "predictions": results
+#             }
+            
+#         except Exception as e:
+#             logger.error(f"Error in CLIP classification: {str(e)}")
+#             return {"attribute": attribute_name, "predictions": []}
+    
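classify_with_clip is zero-shot, so the candidate list fully defines the label space. A direct call using the "tshirt" URL from SAMPLE_IMAGES in views.py (a sketch that assumes this class is uncommented; the first call downloads the CLIP weights):

svc = VisualProcessingService()
img = svc.download_image("https://images.unsplash.com/photo-1521572163474-6864f9cf17ab")
if img is not None:
    out = svc.classify_with_clip(img, ["red", "white", "black"], "color")
    # out == {"attribute": "color", "predictions": [{"value": ..., "confidence": ...}]}
    print(out)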
+#     def detect_patterns(self, image: Image.Image) -> Dict:
+#         """Detect patterns in the image using edge detection and texture analysis."""
+#         try:
+#             # Convert to OpenCV format
+#             img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
+#             gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
+            
+#             # Calculate edge density
+#             edges = cv2.Canny(gray, 50, 150)
+#             edge_density = np.sum(edges > 0) / (edges.shape[0] * edges.shape[1])
+            
+#             # Calculate texture variance
+#             laplacian = cv2.Laplacian(gray, cv2.CV_64F)
+#             texture_variance = laplacian.var()
+            
+#             # Determine pattern type based on metrics
+#             pattern_candidates = []
+            
+#             if edge_density > 0.15:
+#                 pattern_candidates.append("geometric")
+#             if texture_variance > 500:
+#                 pattern_candidates.append("textured")
+#             if edge_density < 0.05 and texture_variance < 200:
+#                 pattern_candidates.append("solid")
+            
+#             # Use CLIP for more detailed pattern detection (note: the heuristic
+#             # pattern_candidates above are computed but not merged into the result)
+#             pattern_types = [
+#                 "solid color", "striped", "checkered", "polka dot", "floral",
+#                 "geometric", "abstract", "graphic print", "camouflage", "paisley"
+#             ]
+            
+#             clip_result = self.classify_with_clip(image, pattern_types, "pattern")
+            
+#             return clip_result
+            
+#         except Exception as e:
+#             logger.error(f"Error detecting patterns: {str(e)}")
+#             return {"attribute": "pattern", "predictions": []}
+    
+#     def detect_material(self, image: Image.Image) -> Dict:
+#         """Detect material type using CLIP."""
+#         materials = [
+#             "cotton", "polyester", "denim", "leather", "silk", "wool",
+#             "linen", "satin", "velvet", "fleece", "knit", "jersey",
+#             "canvas", "nylon", "suede", "corduroy"
+#         ]
+        
+#         return self.classify_with_clip(image, materials, "material")
+    
+#     def detect_style(self, image: Image.Image) -> Dict:
+#         """Detect style/occasion using CLIP."""
+#         styles = [
+#             "casual", "formal", "sporty", "business", "vintage", "modern",
+#             "bohemian", "streetwear", "elegant", "preppy", "athletic",
+#             "loungewear", "party", "workwear", "outdoor"
+#         ]
+        
+#         return self.classify_with_clip(image, styles, "style")
+    
+#     def detect_fit(self, image: Image.Image) -> Dict:
+#         """Detect clothing fit using CLIP."""
+#         fits = [
+#             "slim fit", "regular fit", "loose fit", "oversized",
+#             "tight", "relaxed", "tailored", "athletic fit"
+#         ]
+        
+#         return self.classify_with_clip(image, fits, "fit")
+    
+#     def detect_neckline(self, image: Image.Image, product_type: str) -> Dict:
+#         """Detect neckline type for tops using CLIP."""
+#         if product_type.lower() not in ['shirt', 't-shirt', 'top', 'blouse', 'dress', 'sweater']:
+#             return {"attribute": "neckline", "predictions": []}
+        
+#         necklines = [
+#             "crew neck", "v-neck", "round neck", "collar", "turtleneck",
+#             "scoop neck", "boat neck", "off-shoulder", "square neck", "halter"
+#         ]
+        
+#         return self.classify_with_clip(image, necklines, "neckline")
+    
+#     def detect_sleeve_type(self, image: Image.Image, product_type: str) -> Dict:
+#         """Detect sleeve type using CLIP."""
+#         if product_type.lower() not in ['shirt', 't-shirt', 'top', 'blouse', 'dress', 'sweater', 'jacket']:
+#             return {"attribute": "sleeve_type", "predictions": []}
+        
+#         sleeves = [
+#             "short sleeve", "long sleeve", "sleeveless", "three-quarter sleeve",
+#             "cap sleeve", "flutter sleeve", "bell sleeve", "raglan sleeve"
+#         ]
+        
+#         return self.classify_with_clip(image, sleeves, "sleeve_type")
+    
+#     def detect_product_type(self, image: Image.Image) -> Dict:
+#         """Detect product type using CLIP."""
+#         product_types = [
+#             "t-shirt", "shirt", "dress", "pants", "jeans", "shorts",
+#             "skirt", "jacket", "coat", "sweater", "hoodie", "blazer",
+#             "suit", "jumpsuit", "romper", "cardigan", "vest", "top",
+#             "blouse", "tank top", "polo shirt", "sweatshirt"
+#         ]
+        
+#         return self.classify_with_clip(image, product_types, "product_type")
+    
+#     def detect_closure_type(self, image: Image.Image) -> Dict:
+#         """Detect closure type using CLIP."""
+#         closures = [
+#             "button", "zipper", "snap", "hook and eye", "velcro",
+#             "lace-up", "pull-on", "elastic", "tie", "buckle"
+#         ]
+        
+#         return self.classify_with_clip(image, closures, "closure_type")
+    
+#     def detect_length(self, image: Image.Image, product_type: str) -> Dict:
+#         """Detect garment length using CLIP."""
+#         if product_type.lower() in ['pants', 'jeans', 'trousers']:
+#             lengths = ["full length", "ankle length", "cropped", "capri", "shorts"]
+#         elif product_type.lower() in ['skirt', 'dress']:
+#             lengths = ["mini", "knee length", "midi", "maxi", "floor length"]
+#         elif product_type.lower() in ['jacket', 'coat']:
+#             lengths = ["waist length", "hip length", "thigh length", "knee length", "full length"]
+#         else:
+#             lengths = ["short", "regular", "long"]
+        
+#         return self.classify_with_clip(image, lengths, "length")
+    
+#     def process_image(self, image_url: str, product_type_hint: Optional[str] = None) -> Dict:
+#         """
+#         Main method to process image and extract all visual attributes.
+#         """
+#         try:
+#             # Download image
+#             image = self.download_image(image_url)
+#             if image is None:
+#                 return {
+#                     "visual_attributes": {},
+#                     "error": "Failed to download image"
+#                 }
+            
+#             # Extract all attributes
+#             visual_attributes = {}
+            
+#             # 1. Product Type Detection
+#             logger.info("Detecting product type...")
+#             product_type_result = self.detect_product_type(image)
+#             if product_type_result["predictions"]:
+#                 visual_attributes["product_type"] = product_type_result["predictions"][0]["value"]
+#                 detected_product_type = visual_attributes["product_type"]
+#             else:
+#                 detected_product_type = product_type_hint or "unknown"
+            
+#             # 2. Color Detection
+#             logger.info("Extracting colors...")
+#             colors = self.extract_dominant_colors(image, n_colors=3)
+#             if colors:
+#                 visual_attributes["primary_color"] = colors[0]["name"]
+#                 visual_attributes["color_palette"] = [c["name"] for c in colors]
+#                 visual_attributes["color_details"] = colors
+            
+#             # 3. Pattern Detection
+#             logger.info("Detecting patterns...")
+#             pattern_result = self.detect_patterns(image)
+#             if pattern_result["predictions"]:
+#                 visual_attributes["pattern"] = pattern_result["predictions"][0]["value"]
+            
+#             # 4. Material Detection
+#             logger.info("Detecting material...")
+#             material_result = self.detect_material(image)
+#             if material_result["predictions"]:
+#                 visual_attributes["material"] = material_result["predictions"][0]["value"]
+            
+#             # 5. Style Detection
+#             logger.info("Detecting style...")
+#             style_result = self.detect_style(image)
+#             if style_result["predictions"]:
+#                 visual_attributes["style"] = style_result["predictions"][0]["value"]
+            
+#             # 6. Fit Detection
+#             logger.info("Detecting fit...")
+#             fit_result = self.detect_fit(image)
+#             if fit_result["predictions"]:
+#                 visual_attributes["fit"] = fit_result["predictions"][0]["value"]
+            
+#             # 7. Neckline Detection (if applicable)
+#             logger.info("Detecting neckline...")
+#             neckline_result = self.detect_neckline(image, detected_product_type)
+#             if neckline_result["predictions"]:
+#                 visual_attributes["neckline"] = neckline_result["predictions"][0]["value"]
+            
+#             # 8. Sleeve Type Detection (if applicable)
+#             logger.info("Detecting sleeve type...")
+#             sleeve_result = self.detect_sleeve_type(image, detected_product_type)
+#             if sleeve_result["predictions"]:
+#                 visual_attributes["sleeve_type"] = sleeve_result["predictions"][0]["value"]
+            
+#             # 9. Closure Type Detection
+#             logger.info("Detecting closure type...")
+#             closure_result = self.detect_closure_type(image)
+#             if closure_result["predictions"]:
+#                 visual_attributes["closure_type"] = closure_result["predictions"][0]["value"]
+            
+#             # 10. Length Detection
+#             logger.info("Detecting length...")
+#             length_result = self.detect_length(image, detected_product_type)
+#             if length_result["predictions"]:
+#                 visual_attributes["length"] = length_result["predictions"][0]["value"]
+            
+#             # Format response
+#             return {
+#                 "visual_attributes": visual_attributes,
+#                 "detailed_predictions": {
+#                     "product_type": product_type_result,
+#                     "pattern": pattern_result,
+#                     "material": material_result,
+#                     "style": style_result,
+#                     "fit": fit_result,
+#                     "neckline": neckline_result,
+#                     "sleeve_type": sleeve_result,
+#                     "closure_type": closure_result,
+#                     "length": length_result
+#                 }
+#             }
+            
+#         except Exception as e:
+#             logger.error(f"Error processing image: {str(e)}")
+#             return {
+#                 "visual_attributes": {},
+#                 "error": str(e)
+#             }
+
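End to end, the service is invoked as in the views above; a standalone sketch against the "jeans" sample URL (again assuming the class is active; the first call is slow while models download):

service = VisualProcessingService()
result = service.process_image(
    "https://images.unsplash.com/photo-1542272604-787c3835535d",  # "jeans" in SAMPLE_IMAGES
    product_type_hint="jeans",
)
attrs = result.get("visual_attributes", {})
print(attrs.get("primary_color"), attrs.get("pattern"), attrs.get("length"))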
+
+
+
+
+
+# # ==================== visual_processing_service_optimized.py ====================
+# """
+# Optimized version with:
+# - Result caching
+# - Batch processing support
+# - Memory management
+# - Error recovery
+# - Performance monitoring
+# """
+
+# import torch
+# import cv2
+# import numpy as np
+# import requests
+# from io import BytesIO
+# from PIL import Image
+# from typing import Dict, List, Optional, Tuple
+# import logging
+# import time
+# import hashlib
+# from functools import lru_cache
+# from transformers import CLIPProcessor, CLIPModel
+# from sklearn.cluster import KMeans
+# import webcolors
+
+# logger = logging.getLogger(__name__)
+
+
+# class VisualProcessingService:
+#     """Optimized service for extracting visual attributes from product images."""
+    
+#     # Class-level model caching (shared across instances)
+#     _clip_model = None
+#     _clip_processor = None
+#     _model_device = None
+    
+#     def __init__(self, use_cache: bool = True, cache_ttl: int = 3600):
+#         self.use_cache = use_cache
+#         self.cache_ttl = cache_ttl
+#         self._cache = {}  # per-instance: a new service per request gets no cross-request hits
+        
+#     @classmethod
+#     def _get_device(cls):
+#         """Get optimal device (GPU if available, else CPU)."""
+#         if cls._model_device is None:
+#             cls._model_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+#             logger.info(f"Using device: {cls._model_device}")
+#         return cls._model_device
+    
+#     @classmethod
+#     def _get_clip_model(cls):
+#         """Lazy load CLIP model with class-level caching."""
+#         if cls._clip_model is None:
+#             logger.info("Loading CLIP model...")
+#             start_time = time.time()
+            
+#             try:
+#                 cls._clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+#                 cls._clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+                
+#                 device = cls._get_device()
+#                 cls._clip_model.to(device)
+#                 cls._clip_model.eval()
+                
+#                 logger.info(f"CLIP model loaded in {time.time() - start_time:.2f}s")
+#             except Exception as e:
+#                 logger.error(f"Failed to load CLIP model: {str(e)}")
+#                 raise
+                
+#         return cls._clip_model, cls._clip_processor
+    
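Because the model handles are class attributes, every instance shares one loaded model, so constructing a fresh service per request (as the views do) stays cheap after the first load:

a = VisualProcessingService()
b = VisualProcessingService()
a._get_clip_model()                    # loads the weights once
assert a._clip_model is b._clip_model  # same object, cached at class level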
+#     def _get_cache_key(self, image_url: str, operation: str) -> str:
+#         """Generate cache key for results."""
+#         url_hash = hashlib.md5(image_url.encode()).hexdigest()
+#         return f"visual_{operation}_{url_hash}"
+    
+#     def _get_cached(self, key: str) -> Optional[Dict]:
+#         """Get cached result if available and not expired."""
+#         if not self.use_cache:
+#             return None
+            
+#         if key in self._cache:
+#             result, timestamp = self._cache[key]
+#             if time.time() - timestamp < self.cache_ttl:
+#                 return result
+#             else:
+#                 del self._cache[key]
+#         return None
+    
+#     def _set_cached(self, key: str, value: Dict):
+#         """Cache result with timestamp."""
+#         if self.use_cache:
+#             self._cache[key] = (value, time.time())
+    
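The result cache is a plain dict keyed by an MD5 of the URL plus an operation tag. A round-trip sketch; note the cache lives on the instance, so it only pays off for callers that reuse one service object:

svc = VisualProcessingService(use_cache=True, cache_ttl=3600)
key = svc._get_cache_key("https://example.com/shirt.jpg", "full")
svc._set_cached(key, {"visual_attributes": {"pattern": "striped"}})
print(svc._get_cached(key))  # returns the dict until 3600 s elapse, then None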
+#     def download_image(self, image_url: str, max_size: Tuple[int, int] = (1024, 1024)) -> Optional[Image.Image]:
+#         """Download and optionally resize image for faster processing."""
+#         try:
+#             response = requests.get(image_url, timeout=10)
+#             response.raise_for_status()
+            
+#             image = Image.open(BytesIO(response.content)).convert('RGB')
+            
+#             # Resize if image is too large
+#             if image.size[0] > max_size[0] or image.size[1] > max_size[1]:
+#                 image.thumbnail(max_size, Image.Resampling.LANCZOS)
+#                 logger.info(f"Resized image from original size to {image.size}")
+            
+#             return image
+            
+#         except Exception as e:
+#             logger.error(f"Error downloading image from {image_url}: {str(e)}")
+#             return None
+    
+#     @lru_cache(maxsize=100)  # note: caching a bound method keeps a reference to self alive
+#     def _get_color_name_cached(self, rgb: Tuple[int, int, int]) -> str:
+#         """Cached version of color name lookup."""
+#         try:
+#             return webcolors.rgb_to_name(rgb)
+#         except ValueError:
+#             min_distance = float('inf')
+#             closest_name = 'unknown'
+            
+#             for name in webcolors.CSS3_NAMES_TO_HEX:
+#                 hex_color = webcolors.CSS3_NAMES_TO_HEX[name]
+#                 r, g, b = webcolors.hex_to_rgb(hex_color)
+#                 distance = sum((c1 - c2) ** 2 for c1, c2 in zip(rgb, (r, g, b)))
+                
+#                 if distance < min_distance:
+#                     min_distance = distance
+#                     closest_name = name
+            
+#             return closest_name
+    
+#     def extract_dominant_colors(self, image: Image.Image, n_colors: int = 3) -> List[Dict]:
+#         """Extract dominant colors with optimized K-means."""
+#         try:
+#             # Resize for faster processing
+#             img_small = image.resize((150, 150))
+#             img_array = np.array(img_small)
+#             pixels = img_array.reshape(-1, 3)
+            
+#             # Sample pixels if too many
+#             if len(pixels) > 10000:
+#                 indices = np.random.choice(len(pixels), 10000, replace=False)
+#                 pixels = pixels[indices]
+            
+#             # K-means with optimized parameters
+#             kmeans = KMeans(
+#                 n_clusters=n_colors,
+#                 random_state=42,
+#                 n_init=5,  # Reduced from 10 for speed
+#                 max_iter=100,
+#                 algorithm='elkan'  # Faster for low dimensions
+#             )
+#             kmeans.fit(pixels)
+            
+#             colors = []
+#             labels_counts = np.bincount(kmeans.labels_)
+            
+#             for i, center in enumerate(kmeans.cluster_centers_):
+#                 rgb = tuple(center.astype(int))
+#                 color_name = self._get_color_name_cached(rgb)
+#                 percentage = float(labels_counts[i] / len(kmeans.labels_) * 100)
+                
+#                 colors.append({
+#                     "name": color_name,
+#                     "rgb": rgb,
+#                     "percentage": percentage
+#                 })
+            
+#             colors.sort(key=lambda x: x['percentage'], reverse=True)
+#             return colors
+            
+#         except Exception as e:
+#             logger.error(f"Error extracting colors: {str(e)}")
+#             return []
+    
+#     def classify_with_clip(
+#         self,
+#         image: Image.Image,
+#         candidates: List[str],
+#         attribute_name: str,
+#         confidence_threshold: float = 0.15
+#     ) -> Dict:
+#         """Optimized CLIP classification with batching."""
+#         try:
+#             model, processor = self._get_clip_model()
+#             device = self._get_device()
+            
+#             # Prepare inputs
+#             inputs = processor(
+#                 text=candidates,
+#                 images=image,
+#                 return_tensors="pt",
+#                 padding=True
+#             )
+            
+#             # Move to device
+#             inputs = {k: v.to(device) for k, v in inputs.items()}
+            
+#             # Get predictions with no_grad for speed
+#             with torch.no_grad():
+#                 outputs = model(**inputs)
+#                 logits_per_image = outputs.logits_per_image
+#                 probs = logits_per_image.softmax(dim=1).cpu()
+            
+#             # Get top predictions
+#             top_k = min(3, len(candidates))
+#             top_probs, top_indices = torch.topk(probs[0], k=top_k)
+            
+#             results = []
+#             for prob, idx in zip(top_probs, top_indices):
+#                 if prob.item() > confidence_threshold:
+#                     results.append({
+#                         "value": candidates[idx.item()],
+#                         "confidence": float(prob.item())
+#                     })
+            
+#             return {
+#                 "attribute": attribute_name,
+#                 "predictions": results
+#             }
+            
+#         except Exception as e:
+#             logger.error(f"Error in CLIP classification for {attribute_name}: {str(e)}")
+#             return {"attribute": attribute_name, "predictions": []}
+    
+#     def batch_classify(
+#         self,
+#         image: Image.Image,
+#         attribute_configs: List[Dict[str, any]]
+#     ) -> Dict[str, Dict]:
+#         """
+#         Batch multiple CLIP classifications for efficiency.
+#         attribute_configs: [{"name": "pattern", "candidates": [...], "threshold": 0.15}, ...]
+#         """
+#         results = {}
+        
+#         for config in attribute_configs:
+#             attr_name = config["name"]
+#             candidates = config["candidates"]
+#             threshold = config.get("threshold", 0.15)
+            
+#             result = self.classify_with_clip(image, candidates, attr_name, threshold)
+#             results[attr_name] = result
+        
+#         return results
+    
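Despite its name, batch_classify loops and runs one CLIP forward pass per attribute; the configs simply bundle the calls. Example input (threshold is optional and defaults to 0.15):

configs = [
    {"name": "pattern", "candidates": ["solid color", "striped", "floral"]},
    {"name": "fit", "candidates": ["slim fit", "regular fit", "oversized"], "threshold": 0.20},
]
# results = service.batch_classify(image, configs)
# results["fit"]["predictions"] -> [{"value": "...", "confidence": ...}, ...]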
+#     def detect_patterns(self, image: Image.Image) -> Dict:
+#         """Detect patterns using CLIP."""
+#         pattern_types = [
+#             "solid color", "striped", "checkered", "polka dot", "floral",
+#             "geometric", "abstract", "graphic print", "camouflage", "paisley"
+#         ]
+#         return self.classify_with_clip(image, pattern_types, "pattern")
+    
+#     def process_image(
+#         self,
+#         image_url: str,
+#         product_type_hint: Optional[str] = None,
+#         attributes_to_extract: Optional[List[str]] = None
+#     ) -> Dict:
+#         """
+#         Main method with caching and selective attribute extraction.
+        
+#         Args:
+#             image_url: URL of the product image
+#             product_type_hint: Optional hint about product type
+#             attributes_to_extract: List of attributes to extract (None = all)
+#         """
+#         # Check cache
+#         cache_key = self._get_cache_key(image_url, "full")
+#         cached_result = self._get_cached(cache_key)
+#         if cached_result:
+#             logger.info(f"Returning cached result for {image_url}")
+#             return cached_result
+        
+#         start_time = time.time()
+        
+#         try:
+#             # Download image
+#             image = self.download_image(image_url)
+#             if image is None:
+#                 return {
+#                     "visual_attributes": {},
+#                     "error": "Failed to download image"
+#                 }
+            
+#             visual_attributes = {}
+#             detailed_predictions = {}
+            
+#             # Default: extract all attributes
+#             if attributes_to_extract is None:
+#                 attributes_to_extract = [
+#                     "product_type", "color", "pattern", "material",
+#                     "style", "fit", "neckline", "sleeve_type",
+#                     "closure_type", "length"
+#                 ]
+            
+#             # 1. Product Type Detection
+#             if "product_type" in attributes_to_extract:
+#                 logger.info("Detecting product type...")
+#                 product_types = [
+#                     "t-shirt", "shirt", "dress", "pants", "jeans", "shorts",
+#                     "skirt", "jacket", "coat", "sweater", "hoodie", "blazer",
+#                     "top", "blouse"
+#                 ]
+#                 product_type_result = self.classify_with_clip(image, product_types, "product_type")
+#                 if product_type_result["predictions"]:
+#                     visual_attributes["product_type"] = product_type_result["predictions"][0]["value"]
+#                     detected_product_type = visual_attributes["product_type"]
+#                 else:
+#                     detected_product_type = product_type_hint or "unknown"
+#                 detailed_predictions["product_type"] = product_type_result
+#             else:
+#                 detected_product_type = product_type_hint or "unknown"
+            
+#             # 2. Color Detection
+#             if "color" in attributes_to_extract:
+#                 logger.info("Extracting colors...")
+#                 colors = self.extract_dominant_colors(image, n_colors=3)
+#                 if colors:
+#                     visual_attributes["primary_color"] = colors[0]["name"]
+#                     visual_attributes["color_palette"] = [c["name"] for c in colors]
+#                     visual_attributes["color_details"] = colors
+            
+#             # 3. Batch classify remaining attributes
+#             batch_configs = []
+            
+#             if "pattern" in attributes_to_extract:
+#                 batch_configs.append({
+#                     "name": "pattern",
+#                     "candidates": [
+#                         "solid color", "striped", "checkered", "polka dot",
+#                         "floral", "geometric", "abstract", "graphic print"
+#                     ]
+#                 })
+            
+#             if "material" in attributes_to_extract:
+#                 batch_configs.append({
+#                     "name": "material",
+#                     "candidates": [
+#                         "cotton", "polyester", "denim", "leather", "silk",
+#                         "wool", "linen", "satin", "fleece", "knit"
+#                     ]
+#                 })
+            
+#             if "style" in attributes_to_extract:
+#                 batch_configs.append({
+#                     "name": "style",
+#                     "candidates": [
+#                         "casual", "formal", "sporty", "business", "vintage",
+#                         "modern", "streetwear", "elegant", "athletic"
+#                     ]
+#                 })
+            
+#             if "fit" in attributes_to_extract:
+#                 batch_configs.append({
+#                     "name": "fit",
+#                     "candidates": [
+#                         "slim fit", "regular fit", "loose fit", "oversized",
+#                         "relaxed", "tailored"
+#                     ]
+#                 })
+            
+#             # Product-type specific attributes
+#             if detected_product_type.lower() in ['shirt', 't-shirt', 'top', 'blouse', 'dress', 'sweater']:
+#                 if "neckline" in attributes_to_extract:
+#                     batch_configs.append({
+#                         "name": "neckline",
+#                         "candidates": [
+#                             "crew neck", "v-neck", "round neck", "collar",
+#                             "turtleneck", "scoop neck", "boat neck"
+#                         ]
+#                     })
+                
+#                 if "sleeve_type" in attributes_to_extract:
+#                     batch_configs.append({
+#                         "name": "sleeve_type",
+#                         "candidates": [
+#                             "short sleeve", "long sleeve", "sleeveless",
+#                             "three-quarter sleeve", "cap sleeve"
+#                         ]
+#                     })
+            
+#             if "closure_type" in attributes_to_extract:
+#                 batch_configs.append({
+#                     "name": "closure_type",
+#                     "candidates": [
+#                         "button", "zipper", "snap", "pull-on",
+#                         "lace-up", "elastic", "buckle"
+#                     ]
+#                 })
+            
+#             if "length" in attributes_to_extract:
+#                 if detected_product_type.lower() in ['pants', 'jeans', 'trousers']:
+#                     batch_configs.append({
+#                         "name": "length",
+#                         "candidates": ["full length", "ankle length", "cropped", "capri", "shorts"]
+#                     })
+#                 elif detected_product_type.lower() in ['skirt', 'dress']:
+#                     batch_configs.append({
+#                         "name": "length",
+#                         "candidates": ["mini", "knee length", "midi", "maxi", "floor length"]
+#                     })
+            
+#             # Execute batch classification
+#             logger.info(f"Batch classifying {len(batch_configs)} attributes...")
+#             batch_results = self.batch_classify(image, batch_configs)
+            
+#             # Process batch results
+#             for attr_name, result in batch_results.items():
+#                 detailed_predictions[attr_name] = result
+#                 if result["predictions"]:
+#                     visual_attributes[attr_name] = result["predictions"][0]["value"]
+            
+#             # Format response
+#             result = {
+#                 "visual_attributes": visual_attributes,
+#                 "detailed_predictions": detailed_predictions,
+#                 "processing_time": round(time.time() - start_time, 2)
+#             }
+            
+#             # Cache result
+#             self._set_cached(cache_key, result)
+            
+#             logger.info(f"Visual processing completed in {result['processing_time']}s")
+#             return result
+            
+#         except Exception as e:
+#             logger.error(f"Error processing image: {str(e)}")
+#             return {
+#                 "visual_attributes": {},
+#                 "error": str(e),
+#                 "processing_time": round(time.time() - start_time, 2)
+#             }
+    
+#     def clear_cache(self):
+#         """Clear all cached results."""
+#         self._cache.clear()
+#         logger.info("Cache cleared")
+    
+#     def get_cache_stats(self) -> Dict:
+#         """Get cache statistics."""
+#         return {
+#             "cache_size": len(self._cache),
+#             "cache_enabled": self.use_cache,
+#             "cache_ttl": self.cache_ttl
+#         }
+    
+#     @classmethod
+#     def cleanup_models(cls):
+#         """Free up memory by unloading models."""
+#         if cls._clip_model is not None:
+#             del cls._clip_model
+#             del cls._clip_processor
+#             cls._clip_model = None
+#             cls._clip_processor = None
+            
+#             if torch.cuda.is_available():
+#                 torch.cuda.empty_cache()
+            
+#             logger.info("Models unloaded and memory freed")
+
+
+# # ==================== Usage Example ====================
+
+# def example_usage():
+#     """Example of how to use the optimized service."""
+    
+#     # Initialize service with caching
+#     service = VisualProcessingService(use_cache=True, cache_ttl=3600)
+    
+#     # Process single image with all attributes
+#     result1 = service.process_image("https://example.com/product1.jpg")
+#     print("All attributes:", result1["visual_attributes"])
+    
+#     # Process with selective attributes (faster)
+#     result2 = service.process_image(
+#         "https://example.com/product2.jpg",
+#         product_type_hint="t-shirt",
+#         attributes_to_extract=["color", "pattern", "style"]
+#     )
+#     print("Selected attributes:", result2["visual_attributes"])
+    
+#     # Check cache stats
+#     print("Cache stats:", service.get_cache_stats())
+    
+#     # Clear cache when needed
+#     service.clear_cache()
+    
+#     # Cleanup models (call when shutting down)
+#     VisualProcessingService.cleanup_models()
+
+
+# if __name__ == "__main__":
+#     example_usage()
+
+
+# ==================== visual_processing_service.py (FIXED) ====================
+import time
+import logging
+from io import BytesIO
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import numpy as np
+import requests
+from PIL import Image
+from transformers import CLIPProcessor, CLIPModel
+from sklearn.cluster import KMeans
+
+logger = logging.getLogger(__name__)
+
+
+class VisualProcessingService:
+    """Service for extracting visual attributes from product images using CLIP."""
+    
+    # Class-level caching (shared across instances)
+    _clip_model = None
+    _clip_processor = None
+    _device = None
+    
+    def __init__(self):
+        pass
+    
+    @classmethod
+    def _get_device(cls):
+        """Get optimal device."""
+        if cls._device is None:
+            cls._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            logger.info(f"Visual Processing using device: {cls._device}")
+        return cls._device
+    
+    @classmethod
+    def _get_clip_model(cls):
+        """Lazy load CLIP model with class-level caching."""
+        if cls._clip_model is None:
+            logger.info("Loading CLIP model (this may take a few minutes on first use)...")
+            cls._clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+            cls._clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+            
+            device = cls._get_device()
+            cls._clip_model.to(device)
+            cls._clip_model.eval()
+            
+            logger.info("✓ CLIP model loaded successfully")
+        return cls._clip_model, cls._clip_processor
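+
+    # NOTE: the model, processor, and device are stored on the class, so every
+    # VisualProcessingService instance in this process shares a single copy of
+    # the CLIP weights; this is also why __init__ is intentionally empty.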
+    
+    def download_image(self, image_url: str) -> Optional[Image.Image]:
+        """Download image from URL."""
+        try:
+            response = requests.get(image_url, timeout=10)
+            response.raise_for_status()
+            image = Image.open(BytesIO(response.content)).convert('RGB')
+            return image
+        except Exception as e:
+            logger.error(f"Error downloading image from {image_url}: {str(e)}")
+            return None
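+
+    # download_image deliberately swallows network and decode errors: failures
+    # are logged and surfaced as None, and process_image turns None into a
+    # "Failed to download image" result instead of raising.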
+    
+    def extract_dominant_colors(self, image: Image.Image, n_colors: int = 3) -> List[Dict]:
+        """Extract dominant colors using K-means (FIXED webcolors issue)."""
+        try:
+            # Resize for faster processing
+            img_small = image.resize((150, 150))
+            img_array = np.array(img_small)
+            pixels = img_array.reshape(-1, 3)
+            
+            # K-means clustering
+            kmeans = KMeans(n_clusters=n_colors, random_state=42, n_init=5)
+            kmeans.fit(pixels)
+            
+            colors = []
+            labels_counts = np.bincount(kmeans.labels_)
+            
+            for i, center in enumerate(kmeans.cluster_centers_):
+                # Cast to native ints so the result stays JSON-serializable
+                rgb = tuple(int(c) for c in center)
+                color_name = self._get_color_name_simple(rgb)
+                percentage = float(labels_counts[i] / len(kmeans.labels_) * 100)
+                
+                colors.append({
+                    "name": color_name,
+                    "rgb": rgb,
+                    "percentage": percentage
+                })
+            
+            colors.sort(key=lambda x: x['percentage'], reverse=True)
+            return colors
+            
+        except Exception as e:
+            logger.error(f"Error extracting colors: {str(e)}")
+            return []
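+
+    # Example sketch (hypothetical values): for a photo dominated by red
+    # fabric, extract_dominant_colors(img, n_colors=3) might return
+    # [{"name": "red", "rgb": (210, 40, 35), "percentage": 72.4}, ...],
+    # sorted by coverage.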
+    
+    def _get_color_name_simple(self, rgb: Tuple[int, int, int]) -> str:
+        """
+        Simple color name detection without webcolors dependency.
+        Maps RGB to basic color names.
+        """
+        r, g, b = rgb
+        
+        # Define basic color ranges
+        colors = {
+            'black': (r < 50 and g < 50 and b < 50),
+            'white': (r > 200 and g > 200 and b > 200),
+            'gray': (abs(r - g) < 30 and abs(g - b) < 30 and abs(r - b) < 30 and 50 <= r <= 200),
+            'red': (r > 150 and g < 100 and b < 100),
+            'green': (g > 150 and r < 100 and b < 100),
+            'blue': (b > 150 and r < 100 and g < 100),
+            'yellow': (r > 200 and g > 200 and b < 100),
+            'orange': (r > 200 and 100 < g < 200 and b < 100),
+            'purple': (r > 100 and b > 100 and g < 100),
+            'pink': (r > 200 and 100 < g < 200 and 100 < b < 200),
+            'brown': (50 < r < 150 and 30 < g < 100 and b < 80),
+            'cyan': (r < 100 and g > 150 and b > 150),
+        }
+        
+        for color_name, condition in colors.items():
+            if condition:
+                return color_name
+        
+        # Default fallback
+        if r > g and r > b:
+            return 'red'
+        elif g > r and g > b:
+            return 'green'
+        elif b > r and b > g:
+            return 'blue'
+        else:
+            return 'gray'
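+
+    # Example mappings under the ranges above: (230, 40, 35) -> 'red',
+    # (120, 120, 125) -> 'gray', (240, 240, 245) -> 'white'. Conditions are
+    # checked in dict insertion order, so black/white/gray win over hues.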
+    
+    def classify_with_clip(
+        self,
+        image: Image.Image,
+        candidates: List[str],
+        attribute_name: str
+    ) -> Dict:
+        """Use CLIP to classify image against candidate labels."""
+        try:
+            model, processor = self._get_clip_model()
+            device = self._get_device()
+            
+            # Prepare inputs
+            inputs = processor(
+                text=candidates,
+                images=image,
+                return_tensors="pt",
+                padding=True
+            )
+            
+            # Move to device
+            inputs = {k: v.to(device) for k, v in inputs.items()}
+            
+            # Get predictions
+            with torch.no_grad():
+                outputs = model(**inputs)
+                logits_per_image = outputs.logits_per_image
+                probs = logits_per_image.softmax(dim=1).cpu()
+            
+            # Get top predictions
+            top_k = min(3, len(candidates))
+            top_probs, top_indices = torch.topk(probs[0], k=top_k)
+            
+            results = []
+            for prob, idx in zip(top_probs, top_indices):
+                if prob.item() > 0.15:  # Confidence threshold
+                    results.append({
+                        "value": candidates[idx.item()],
+                        "confidence": float(prob.item())
+                    })
+            
+            return {
+                "attribute": attribute_name,
+                "predictions": results
+            }
+            
+        except Exception as e:
+            logger.error(f"Error in CLIP classification for {attribute_name}: {str(e)}")
+            return {"attribute": attribute_name, "predictions": []}
+    
+    def process_image(
+        self,
+        image_url: str,
+        product_type_hint: Optional[str] = None
+    ) -> Dict:
+        """
+        Main method to process image and extract visual attributes.
+        """
+        import time
+        start_time = time.time()
+        
+        try:
+            # Download image
+            image = self.download_image(image_url)
+            if image is None:
+                return {
+                    "visual_attributes": {},
+                    "error": "Failed to download image"
+                }
+            
+            visual_attributes = {}
+            detailed_predictions = {}
+            
+            # 1. Product Type Detection
+            product_types = [
+                "t-shirt", "shirt", "dress", "pants", "jeans", "shorts",
+                "skirt", "jacket", "coat", "sweater", "hoodie", "top"
+            ]
+            product_type_result = self.classify_with_clip(image, product_types, "product_type")
+            if product_type_result["predictions"]:
+                visual_attributes["product_type"] = product_type_result["predictions"][0]["value"]
+                detected_product_type = visual_attributes["product_type"]
+            else:
+                detected_product_type = product_type_hint or "unknown"
+            detailed_predictions["product_type"] = product_type_result
+            
+            # 2. Color Detection
+            colors = self.extract_dominant_colors(image, n_colors=3)
+            if colors:
+                visual_attributes["primary_color"] = colors[0]["name"]
+                visual_attributes["color_palette"] = [c["name"] for c in colors]
+            
+            # 3. Pattern Detection
+            patterns = ["solid color", "striped", "checkered", "graphic print", "floral", "geometric"]
+            pattern_result = self.classify_with_clip(image, patterns, "pattern")
+            if pattern_result["predictions"]:
+                visual_attributes["pattern"] = pattern_result["predictions"][0]["value"]
+            detailed_predictions["pattern"] = pattern_result
+            
+            # 4. Material Detection
+            materials = ["cotton", "polyester", "denim", "leather", "silk", "wool", "linen"]
+            material_result = self.classify_with_clip(image, materials, "material")
+            if material_result["predictions"]:
+                visual_attributes["material"] = material_result["predictions"][0]["value"]
+            detailed_predictions["material"] = material_result
+            
+            # 5. Style Detection
+            styles = ["casual", "formal", "sporty", "streetwear", "elegant", "vintage"]
+            style_result = self.classify_with_clip(image, styles, "style")
+            if style_result["predictions"]:
+                visual_attributes["style"] = style_result["predictions"][0]["value"]
+            detailed_predictions["style"] = style_result
+            
+            # 6. Fit Detection
+            fits = ["slim fit", "regular fit", "loose fit", "oversized"]
+            fit_result = self.classify_with_clip(image, fits, "fit")
+            if fit_result["predictions"]:
+                visual_attributes["fit"] = fit_result["predictions"][0]["value"]
+            detailed_predictions["fit"] = fit_result
+            
+            # 7. Neckline (for tops)
+            if detected_product_type.lower() in ['shirt', 't-shirt', 'top', 'dress']:
+                necklines = ["crew neck", "v-neck", "round neck", "collar"]
+                neckline_result = self.classify_with_clip(image, necklines, "neckline")
+                if neckline_result["predictions"]:
+                    visual_attributes["neckline"] = neckline_result["predictions"][0]["value"]
+                detailed_predictions["neckline"] = neckline_result
+            
+            # 8. Sleeve Type (for tops)
+            if detected_product_type.lower() in ['shirt', 't-shirt', 'top']:
+                sleeves = ["short sleeve", "long sleeve", "sleeveless"]
+                sleeve_result = self.classify_with_clip(image, sleeves, "sleeve_type")
+                if sleeve_result["predictions"]:
+                    visual_attributes["sleeve_type"] = sleeve_result["predictions"][0]["value"]
+                detailed_predictions["sleeve_type"] = sleeve_result
+            
+            # 9. Closure Type
+            closures = ["button", "zipper", "pull-on"]
+            closure_result = self.classify_with_clip(image, closures, "closure_type")
+            if closure_result["predictions"]:
+                visual_attributes["closure_type"] = closure_result["predictions"][0]["value"]
+            detailed_predictions["closure_type"] = closure_result
+            
+            processing_time = time.time() - start_time
+            
+            return {
+                "visual_attributes": visual_attributes,
+                "detailed_predictions": detailed_predictions,
+                "processing_time": round(processing_time, 2)
+            }
+            
+        except Exception as e:
+            logger.error(f"Error processing image: {str(e)}")
+            return {
+                "visual_attributes": {},
+                "error": str(e),
+                "processing_time": round(time.time() - start_time, 2)
+            }
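+
+
+if __name__ == "__main__":
+    # Minimal smoke test, assuming the packages from requirements_visual.txt
+    # are installed; the image URL is a hypothetical placeholder, not a real
+    # product photo.
+    logging.basicConfig(level=logging.INFO)
+    service = VisualProcessingService()
+    output = service.process_image("https://example.com/product.jpg")
+    if "error" in output:
+        print("Error:", output["error"])
+    else:
+        print("Visual attributes:", output["visual_attributes"])
+    print("Processing time:", output.get("processing_time"), "s")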

BIN
db.sqlite3


BIN
media/products/image.jpg


BIN
media/products/levi_test_ocr2.jpg


+ 11 - 0
requirements_visual.txt

@@ -0,0 +1,11 @@
+torch>=2.0.0
+torchvision>=0.15.0
+transformers>=4.30.0
+opencv-python>=4.8.0
+Pillow>=10.0.0
+webcolors>=1.13
+numpy>=1.24.0
+scikit-learn>=1.3.0
+easyocr>=1.7.0
+sentence-transformers>=2.2.0
+requests>=2.31.0