
caching config added

Harshit Pathak, 3 months ago
parent
commit
c1d3a2598b

+ 38 - 0
attr_extraction/cache_config.py

@@ -0,0 +1,38 @@
+# ==================== cache_config.py ====================
+"""
+Centralized cache configuration for the application.
+Set ENABLE_CACHING to True to enable all caches, False to disable.
+"""
+
+# ⚡ MASTER CACHE CONTROL - Change this single variable to enable/disable ALL caching
+ENABLE_CACHING = False  # Default: OFF
+
+# Individual cache controls (controlled by ENABLE_CACHING)
+ENABLE_ATTRIBUTE_EXTRACTION_CACHE = ENABLE_CACHING
+ENABLE_EMBEDDING_CACHE = ENABLE_CACHING
+ENABLE_CLIP_MODEL_CACHE = ENABLE_CACHING
+
+# Cache size limits (only used when caching is enabled)
+ATTRIBUTE_CACHE_MAX_SIZE = 1000
+EMBEDDING_CACHE_MAX_SIZE = 500
+
+def is_caching_enabled() -> bool:
+    """
+    Check if caching is enabled globally.
+    Returns: bool indicating if caching is enabled
+    """
+    return ENABLE_CACHING
+
+def get_cache_config() -> dict:
+    """
+    Get current cache configuration.
+    Returns: dict with cache settings
+    """
+    return {
+        "master_cache_enabled": ENABLE_CACHING,
+        "attribute_extraction_cache": ENABLE_ATTRIBUTE_EXTRACTION_CACHE,
+        "embedding_cache": ENABLE_EMBEDDING_CACHE,
+        "clip_model_cache": ENABLE_CLIP_MODEL_CACHE,
+        "attribute_cache_max_size": ATTRIBUTE_CACHE_MAX_SIZE,
+        "embedding_cache_max_size": EMBEDDING_CACHE_MAX_SIZE
+    }

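For reference, a minimal sketch of how another module could read this configuration (illustrative only; the helper below is not part of this commit):

    # Illustrative only: reading the cache configuration from another module.
    from attr_extraction.cache_config import is_caching_enabled, get_cache_config

    def report_cache_settings() -> None:
        # get_cache_config() returns a plain dict of the module-level flags.
        state = "enabled" if is_caching_enabled() else "disabled"
        print(f"Caching is {state}: {get_cache_config()}")
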
+ 0 - 209
attr_extraction/serializers.py

@@ -1,212 +1,3 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-# # ==================== serializers.py ====================
-# from rest_framework import serializers
-# from .models import Product, ProductType, ProductAttribute, AttributePossibleValue
-
-
-# class ProductInputSerializer(serializers.Serializer):
-#     """Serializer for individual product input."""
-#     product_id = serializers.CharField(required=False, allow_blank=True, allow_null=True)
-#     title = serializers.CharField(required=False, allow_blank=True, allow_null=True)
-#     short_desc = serializers.CharField(required=False, allow_blank=True, allow_null=True)
-#     long_desc = serializers.CharField(required=False, allow_blank=True, allow_null=True)
-#     image_url = serializers.URLField(required=False, allow_blank=True, allow_null=True)
-
-
-# class MandatoryAttrsField(serializers.DictField):
-#     """Custom DictField to validate mandatory_attrs structure."""
-#     child = serializers.ListField(child=serializers.CharField())
-
-
-# class ProductBatchInputSerializer(serializers.Serializer):
-#     """Serializer for an individual product input within the batch request."""
-#     item_id = serializers.CharField(required=True)
-#     mandatory_attrs = MandatoryAttrsField(
-#         required=True,
-#         help_text="A dictionary of attribute names and their possible values."
-#     )
-
-
-# class SingleProductRequestSerializer(serializers.Serializer):
-#     """Serializer for single product extraction request."""
-#     item_id = serializers.CharField(required=True)
-#     mandatory_attrs = serializers.DictField(
-#         child=serializers.ListField(child=serializers.CharField()),
-#         required=True
-#     )
-#     model = serializers.CharField(required=False, default="llama-3.1-8b-instant")
-#     extract_additional = serializers.BooleanField(required=False, default=True)
-#     process_image = serializers.BooleanField(required=False, default=True)
-#     multiple = serializers.ListField(
-#         child=serializers.CharField(),
-#         required=False,
-#         default=list,
-#         help_text="List of attribute names that can have multiple values"
-#     )
-#     threshold_abs = serializers.FloatField(default=0.65, required=False)
-#     margin = serializers.FloatField(default=0.15, required=False)
-#     use_dynamic_thresholds = serializers.BooleanField(default=True, required=False)
-#     use_adaptive_margin = serializers.BooleanField(default=True, required=False)
-#     use_semantic_clustering = serializers.BooleanField(default=True, required=False)
-
-#     def validate_model(self, value):
-#         from django.conf import settings
-#         if value not in settings.SUPPORTED_MODELS:
-#             raise serializers.ValidationError(
-#                 f"Model must be one of {settings.SUPPORTED_MODELS}"
-#             )
-#         return value
-
-
-# class BatchProductRequestSerializer(serializers.Serializer):
-#     """Serializer for batch product extraction request (with item-specific attributes)."""
-#     products = serializers.ListField(
-#         child=ProductBatchInputSerializer(),
-#         required=True,
-#         min_length=1
-#     )
-#     model = serializers.CharField(required=False, default="llama-3.1-8b-instant")
-#     extract_additional = serializers.BooleanField(required=False, default=True)
-#     process_image = serializers.BooleanField(required=False, default=True)
-#     multiple = serializers.ListField(
-#         child=serializers.CharField(),
-#         required=False,
-#         default=list,
-#         help_text="List of attribute names that can have multiple values"
-#     )
-#     threshold_abs = serializers.FloatField(default=0.65, required=False)
-#     margin = serializers.FloatField(default=0.15, required=False)
-#     use_dynamic_thresholds = serializers.BooleanField(default=True, required=False)
-#     use_adaptive_margin = serializers.BooleanField(default=True, required=False)
-#     use_semantic_clustering = serializers.BooleanField(default=True, required=False)
-    
-#     def validate_model(self, value):
-#         from django.conf import settings
-#         if value not in settings.SUPPORTED_MODELS:
-#             raise serializers.ValidationError(
-#                 f"Model must be one of {settings.SUPPORTED_MODELS}"
-#             )
-#         return value
-    
-#     def validate_products(self, value):
-#         from django.conf import settings
-#         max_size = getattr(settings, 'MAX_BATCH_SIZE', 100)
-#         if len(value) > max_size:
-#             raise serializers.ValidationError(
-#                 f"Batch size cannot exceed {max_size} products"
-#             )
-#         return value
-
-
-# class OCRResultSerializer(serializers.Serializer):
-#     """Serializer for OCR results."""
-#     detected_text = serializers.ListField(child=serializers.DictField())
-#     extracted_attributes = serializers.DictField()
-
-
-# class ProductAttributeResultSerializer(serializers.Serializer):
-#     """Serializer for individual product extraction result."""
-#     product_id = serializers.CharField(required=False)
-#     mandatory = serializers.DictField()
-#     additional = serializers.DictField(required=False)
-#     ocr_results = OCRResultSerializer(required=False)
-#     error = serializers.CharField(required=False)
-#     raw_output = serializers.CharField(required=False)
-
-
-# class BatchProductResponseSerializer(serializers.Serializer):
-#     """Serializer for batch extraction response."""
-#     results = serializers.ListField(child=ProductAttributeResultSerializer())
-#     total_products = serializers.IntegerField()
-#     successful = serializers.IntegerField()
-#     failed = serializers.IntegerField()
-
-
-# class ProductSerializer(serializers.ModelSerializer):
-#     """Serializer for Product model with product type details."""
-#     product_type_details = serializers.SerializerMethodField()
-    
-#     class Meta:
-#         model = Product
-#         fields = [
-#             'id',
-#             'item_id',
-#             'product_name',
-#             'product_long_description',
-#             'product_short_description',
-#             'product_type',
-#             'image_path',
-#             'image',
-#             'product_type_details',
-#         ]
-
-#     def get_product_type_details(self, obj):
-#         """Fetch ProductType object and its attributes for this product."""
-#         try:
-#             product_type = ProductType.objects.get(name=obj.product_type)
-#         except ProductType.DoesNotExist:
-#             return []
-
-#         # Serialize its attributes
-#         attributes = ProductAttribute.objects.filter(product_type=product_type)
-#         return [
-#             {
-#                 "attribute_name": attr.name,
-#                 "is_mandatory": "Yes" if attr.is_mandatory else "No",
-#                 "possible_values": [pv.value for pv in attr.possible_values.all()]
-#             }
-#             for attr in attributes
-#         ]
-
-
-# class AttributePossibleValueSerializer(serializers.ModelSerializer):
-#     """Serializer for AttributePossibleValue model."""
-#     class Meta:
-#         model = AttributePossibleValue
-#         fields = ['value']
-
-
-# class ProductAttributeSerializer(serializers.ModelSerializer):
-#     """Serializer for ProductAttribute model with possible values."""
-#     possible_values = AttributePossibleValueSerializer(many=True, read_only=True)
-    
-#     class Meta:
-#         model = ProductAttribute
-#         fields = ['name', 'is_mandatory', 'possible_values']
-
-
-# class ProductTypeSerializer(serializers.ModelSerializer):
-#     """Serializer for ProductType model with attributes."""
-#     attributes = ProductAttributeSerializer(many=True, read_only=True)
-    
-#     class Meta:
-#         model = ProductType
-#         fields = ['name', 'attributes']
-
-
-
-
-
-
-
-        
-
-
-
 # ==================== Updated serializers.py ====================
 from rest_framework import serializers
 from .models import Product, ProductType, ProductAttribute, AttributePossibleValue

+ 55 - 13
attr_extraction/services.py

@@ -1,5 +1,4 @@
-
-# ==================== services.py (PERFORMANCE OPTIMIZED) ====================
+# ==================== services.py (WITH CACHE CONTROL) ====================
 import requests
 import json
 import re
@@ -11,6 +10,15 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
 from sentence_transformers import SentenceTransformer, util
 import numpy as np
 
+# ⚡ IMPORT CACHE CONFIGURATION
+from .cache_config import (
+    is_caching_enabled,
+    ENABLE_ATTRIBUTE_EXTRACTION_CACHE,
+    ENABLE_EMBEDDING_CACHE,
+    ATTRIBUTE_CACHE_MAX_SIZE,
+    EMBEDDING_CACHE_MAX_SIZE
+)
+
 logger = logging.getLogger(__name__)
 
 # ⚡ CRITICAL FIX: Initialize embedding model ONCE at module level
@@ -27,14 +35,21 @@ print("✓ Model loaded successfully")
 class SimpleCache:
     """In-memory cache for attribute extraction results."""
     _cache = {}
-    _max_size = 1000
+    _max_size = ATTRIBUTE_CACHE_MAX_SIZE
     
     @classmethod
     def get(cls, key: str) -> Optional[Dict]:
+        """Get value from cache. Returns None if caching is disabled."""
+        if not ENABLE_ATTRIBUTE_EXTRACTION_CACHE:
+            return None
         return cls._cache.get(key)
     
     @classmethod
     def set(cls, key: str, value: Dict):
+        """Set value in cache. Does nothing if caching is disabled."""
+        if not ENABLE_ATTRIBUTE_EXTRACTION_CACHE:
+            return
+            
         if len(cls._cache) >= cls._max_size:
             items = list(cls._cache.items())
             cls._cache = dict(items[int(cls._max_size * 0.2):])
@@ -42,27 +57,39 @@ class SimpleCache:
     
     @classmethod
     def clear(cls):
+        """Clear the cache."""
         cls._cache.clear()
     
     @classmethod
     def get_stats(cls) -> Dict:
+        """Get cache statistics."""
         return {
+            "enabled": ENABLE_ATTRIBUTE_EXTRACTION_CACHE,
             "size": len(cls._cache),
             "max_size": cls._max_size,
-            "usage_percent": round(len(cls._cache) / cls._max_size * 100, 2)
+            "usage_percent": round(len(cls._cache) / cls._max_size * 100, 2) if cls._max_size > 0 else 0
         }
 
 
 class EmbeddingCache:
     """Cache for sentence transformer embeddings."""
     _cache = {}
-    _max_size = 500
+    _max_size = EMBEDDING_CACHE_MAX_SIZE
     _hit_count = 0
     _miss_count = 0
     
     @classmethod
     def get_embedding(cls, text: str, model):
-        """Get or compute embedding with caching"""
+        """Get or compute embedding with optional caching"""
+        # If caching is disabled, always compute fresh
+        if not ENABLE_EMBEDDING_CACHE:
+            import warnings
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                embedding = model.encode(text, convert_to_tensor=True, show_progress_bar=False)
+            return embedding
+        
+        # Caching is enabled, check cache first
         if text in cls._cache:
             cls._hit_count += 1
             return cls._cache[text]
@@ -73,7 +100,7 @@ class EmbeddingCache:
             items = list(cls._cache.items())
             cls._cache = dict(items[int(cls._max_size * 0.3):])
         
-        # ⚡ CRITICAL: Disable verbose output
+        # Compute embedding
         import warnings
         with warnings.catch_warnings():
             warnings.simplefilter("ignore")
@@ -84,15 +111,18 @@ class EmbeddingCache:
     
     @classmethod
     def clear(cls):
+        """Clear the cache and reset statistics."""
         cls._cache.clear()
         cls._hit_count = 0
         cls._miss_count = 0
     
     @classmethod
     def get_stats(cls) -> Dict:
+        """Get cache statistics."""
         total = cls._hit_count + cls._miss_count
         hit_rate = (cls._hit_count / total * 100) if total > 0 else 0
         return {
+            "enabled": ENABLE_EMBEDDING_CACHE,
             "size": len(cls._cache),
             "hits": cls._hit_count,
             "misses": cls._miss_count,
@@ -624,7 +654,7 @@ If an attribute is not present, do not include it in the response.
         use_dynamic_thresholds: bool = True,
         use_adaptive_margin: bool = True,
         use_semantic_clustering: bool = True,
-        use_cache: bool = True
+        use_cache: Optional[bool] = None  # ⚡ NEW: None = follow global setting; True/False overrides it
     ) -> dict:
         """Extract attributes from product text using Groq LLM."""
         
@@ -637,6 +667,14 @@ If an attribute is not present, do not include it in the response.
         if source_map is None:
             source_map = {}
 
+        # ⚡ CACHE CONTROL: use parameter if provided, otherwise use global setting
+        if use_cache is None:
+            use_cache = ENABLE_ATTRIBUTE_EXTRACTION_CACHE
+        
+        # If caching is globally disabled, force use_cache to False
+        if not is_caching_enabled():
+            use_cache = False
+
         if not product_text or product_text == "No product information available":
             return ProductAttributeService._create_error_response(
                 "No product information provided",
@@ -644,13 +682,15 @@ If an attribute is not present, do not include it in the response.
                 extract_additional
             )
 
-        # ⚡ CHECK CACHE FIRST
+        # ⚡ CHECK CACHE FIRST (only if enabled)
         if use_cache:
             cache_key = ProductAttributeService._generate_cache_key(product_text, mandatory_attrs)
             cached_result = SimpleCache.get(cache_key)
             if cached_result:
-                logger.info(f"✓ Cache hit")
+                logger.info(f"✓ Cache hit (caching enabled)")
                 return cached_result
+        else:
+            logger.info(f"⚠ Cache disabled - processing fresh")
 
         mandatory_attr_list = []
         for attr_name, allowed_values in mandatory_attrs.items():
@@ -791,9 +831,10 @@ Required Output Format:
                 parsed["mandatory"][attr] = result[attr]
                 extracted_so_far[attr] = result[attr]
 
-            # ⚡ CACHE THE RESULT
+            # ⚡ CACHE THE RESULT (only if enabled)
             if use_cache:
                 SimpleCache.set(cache_key, parsed)
+                logger.info(f"✓ Result cached")
 
             return parsed
 
@@ -907,15 +948,16 @@ Required Output Format:
 
     @staticmethod
     def get_cache_stats() -> Dict:
-        """Get statistics for both caches."""
+        """Get statistics for all caches including global status."""
         return {
+            "global_caching_enabled": is_caching_enabled(),
             "simple_cache": SimpleCache.get_stats(),
             "embedding_cache": EmbeddingCache.get_stats()
         }
 
     @staticmethod
     def clear_all_caches():
-        """Clear both caches."""
+        """Clear all caches."""
         SimpleCache.clear()
         EmbeddingCache.clear()
         logger.info("All caches cleared")

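As a rough illustration of the guards added to SimpleCache above (a sketch, not part of the commit):

    # Sketch: SimpleCache silently no-ops while the attribute-extraction cache is disabled.
    from attr_extraction.services import SimpleCache

    SimpleCache.set("demo-key", {"color": "red"})
    print(SimpleCache.get("demo-key"))  # None while ENABLE_CACHING = False (the default)
    print(SimpleCache.get_stats())      # {"enabled": False, "size": 0, ...}
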
+ 3 - 17
attr_extraction/urls.py

@@ -1,20 +1,3 @@
-# # ==================== urls.py ====================
-# from django.urls import path
-# from .views import ExtractProductAttributesView,ProductTypeListView, ProductTypeAttributesView, ProductAttributesUploadView, BatchExtractProductAttributesView, ProductListView, ProductUploadExcelView
-
-# urlpatterns = [
-#     path('extract/', ExtractProductAttributesView.as_view(), name='extract-attributes'),
-#     path('batch-extract/', BatchExtractProductAttributesView.as_view(), name='batch-extract-attributes'),
-#     path('products/', ProductListView.as_view(), name='batch-extract-attributes'),
-#     path('products/upload-excel/', ProductUploadExcelView.as_view(), name='product-upload-excel'),
-#     path('products/upload-attributes/', ProductAttributesUploadView.as_view(), name='product-upload-excel'),
-#     path('products/attributes/', ProductTypeAttributesView.as_view(), name='product-upload-excel'),
-#     path('product-types/', ProductTypeListView.as_view(), name='product-types-list'),
-# ]
-
-
-
-
 # urls.py
 from django.urls import path
 from .views import (
@@ -46,6 +29,9 @@ urlpatterns = [
     path('attribute-values/', ProductAttributeValueView.as_view(), name='attribute-values'),
     path('attribute-values/bulk/', BulkProductAttributeValueView.as_view(), name='attribute-values-bulk'),
     # path('attribute-values/upload-excel/', ProductAttributeValueUploadExcelView.as_view(), name='attribute-values-upload'),
+
+    path('attribute-values/', ProductAttributeValueView.as_view(), name='attribute-values'),
+    path('attribute-values/bulk/', BulkProductAttributeValueView.as_view(), name='attribute-values-bulk'),
 ]
 
 

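A possible wiring for the cache endpoints defined in views.py, should they be exposed here (path names are assumptions, not part of this commit):

    # Illustrative only - route names and paths are assumptions.
    from .views import CacheManagementView, CacheStatsView

    urlpatterns += [
        path('cache/', CacheManagementView.as_view(), name='cache-management'),
        path('cache/stats/', CacheStatsView.as_view(), name='cache-stats'),
    ]
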
+ 154 - 1
attr_extraction/views.py

@@ -22,6 +22,10 @@ from rest_framework.parsers import MultiPartParser, FormParser
 from openpyxl import Workbook
 from openpyxl.styles import Font, PatternFill, Alignment
 
+
+from rest_framework.views import APIView
+from . import cache_config
+
 # --- Local imports ---
 from .models import (
     Product,
@@ -1719,4 +1723,153 @@ class ProductListWithAttributesView(APIView):
         else:
             products = Product.objects.all()
             serializer = ProductWithAttributesSerializer(products, many=True)
-            return Response(serializer.data, status=status.HTTP_200_OK)
+            return Response(serializer.data, status=status.HTTP_200_OK)
+        
+
+
+
+class CacheManagementView(APIView):
+    """
+    API endpoint to manage caching system.
+    
+    GET: Get current cache statistics and configuration
+    POST: Enable/disable caching or clear caches
+    """
+    
+    def get(self, request):
+        """
+        Get current cache configuration and statistics.
+        """
+        config = cache_config.get_cache_config()
+        stats = ProductAttributeService.get_cache_stats()
+        
+        return Response({
+            "configuration": config,
+            "statistics": stats,
+            "message": "Cache status retrieved successfully"
+        }, status=status.HTTP_200_OK)
+    
+    def post(self, request):
+        """
+        Manage cache settings.
+        
+        Expected payload examples:
+        
+        1. Enable/disable caching:
+        {
+            "action": "toggle",
+            "enable": true  // or false
+        }
+        
+        2. Clear all caches:
+        {
+            "action": "clear"
+        }
+        
+        3. Clear specific cache:
+        {
+            "action": "clear",
+            "cache_type": "embedding"  // or "attribute" or "clip"
+        }
+        
+        4. Get statistics:
+        {
+            "action": "stats"
+        }
+        """
+        action = request.data.get('action')
+        
+        if not action:
+            return Response({
+                "error": "action is required",
+                "valid_actions": ["toggle", "clear", "stats"]
+            }, status=status.HTTP_400_BAD_REQUEST)
+        
+        # Toggle caching on/off
+        if action == "toggle":
+            enable = request.data.get('enable')
+            
+            if enable is None:
+                return Response({
+                    "error": "enable parameter is required (true/false)"
+                }, status=status.HTTP_400_BAD_REQUEST)
+            
+            # Update the cache configuration
+            cache_config.ENABLE_CACHING = bool(enable)
+            cache_config.ENABLE_ATTRIBUTE_EXTRACTION_CACHE = bool(enable)
+            cache_config.ENABLE_EMBEDDING_CACHE = bool(enable)
+            cache_config.ENABLE_CLIP_MODEL_CACHE = bool(enable)
+            
+            status_msg = "enabled" if enable else "disabled"
+            
+            return Response({
+                "message": f"Caching has been {status_msg}",
+                "configuration": cache_config.get_cache_config()
+            }, status=status.HTTP_200_OK)
+        
+        # Clear caches
+        elif action == "clear":
+            cache_type = request.data.get('cache_type', 'all')
+            
+            if cache_type == 'all':
+                ProductAttributeService.clear_all_caches()
+                VisualProcessingService.clear_clip_cache()
+                message = "All caches cleared successfully"
+            
+            elif cache_type == 'embedding':
+                from .services import EmbeddingCache
+                EmbeddingCache.clear()
+                message = "Embedding cache cleared successfully"
+            
+            elif cache_type == 'attribute':
+                from .services import SimpleCache
+                SimpleCache.clear()
+                message = "Attribute extraction cache cleared successfully"
+            
+            elif cache_type == 'clip':
+                VisualProcessingService.clear_clip_cache()
+                message = "CLIP model cache cleared successfully"
+            
+            else:
+                return Response({
+                    "error": f"Invalid cache_type: {cache_type}",
+                    "valid_types": ["all", "embedding", "attribute", "clip"]
+                }, status=status.HTTP_400_BAD_REQUEST)
+            
+            return Response({
+                "message": message,
+                "statistics": ProductAttributeService.get_cache_stats()
+            }, status=status.HTTP_200_OK)
+        
+        # Get statistics
+        elif action == "stats":
+            stats = ProductAttributeService.get_cache_stats()
+            config = cache_config.get_cache_config()
+            
+            return Response({
+                "configuration": config,
+                "statistics": stats
+            }, status=status.HTTP_200_OK)
+        
+        else:
+            return Response({
+                "error": f"Invalid action: {action}",
+                "valid_actions": ["toggle", "clear", "stats"]
+            }, status=status.HTTP_400_BAD_REQUEST)
+
+
+class CacheStatsView(APIView):
+    """
+    Simple GET endpoint to retrieve cache statistics.
+    """
+    
+    def get(self, request):
+        """Get current cache statistics."""
+        stats = ProductAttributeService.get_cache_stats()
+        config = cache_config.get_cache_config()
+        
+        return Response({
+            "cache_enabled": config["master_cache_enabled"],
+            "statistics": stats,
+            "timestamp": datetime.now().isoformat()
+        }, status=status.HTTP_200_OK)

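Assuming the management view is routed at something like /api/cache/ (a hypothetical path, not defined in this commit), toggling caching from a client could look like this:

    # Hypothetical client call - the '/api/cache/' route is an assumption.
    import requests

    resp = requests.post(
        "http://localhost:8000/api/cache/",
        json={"action": "toggle", "enable": True},
        timeout=10,
    )
    print(resp.json())  # e.g. {"message": "Caching has been enabled", "configuration": {...}}
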
+ 45 - 42
attr_extraction/visual_processing_service.py

@@ -1,5 +1,4 @@
-
-# ==================== visual_processing_service.py (FIXED - Smart Subcategory Detection) ====================
+# ==================== visual_processing_service.py (WITH CACHE CONTROL) ====================
 import torch
 import numpy as np
 import requests
@@ -10,18 +9,21 @@ import logging
 from transformers import CLIPProcessor, CLIPModel
 from sklearn.cluster import KMeans
 
+# ⚡ IMPORT CACHE CONFIGURATION
+from .cache_config import ENABLE_CLIP_MODEL_CACHE
+
 logger = logging.getLogger(__name__)
 
 import os
-os.environ['TOKENIZERS_PARALLELISM'] = 'false'  # Disable tokenizer warnings
+os.environ['TOKENIZERS_PARALLELISM'] = 'false'
 import warnings
-warnings.filterwarnings('ignore')  # Suppress all warnings
+warnings.filterwarnings('ignore')
 
 
 class VisualProcessingService:
     """Service for extracting visual attributes from product images using CLIP with smart subcategory detection."""
     
-    # Class-level caching (shared across instances)
+    # ⚡ Class-level caching (controlled by cache_config)
     _clip_model = None
     _clip_processor = None
     _device = None
@@ -129,7 +131,24 @@ class VisualProcessingService:
     
     @classmethod
     def _get_clip_model(cls):
-        """Lazy load CLIP model with class-level caching."""
+        """
+        Lazy load CLIP model with optional class-level caching.
+        ⚡ If caching is disabled, model is still loaded but not persisted at class level.
+        """
+        # ⚡ CACHE CONTROL: If caching is disabled, always reload (no persistence)
+        if not ENABLE_CLIP_MODEL_CACHE:
+            logger.info("⚠ CLIP model caching is DISABLED - loading fresh instance")
+            model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+            processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+            
+            device = cls._get_device()
+            model.to(device)
+            model.eval()
+            
+            logger.info("✓ CLIP model loaded (no caching)")
+            return model, processor
+        
+        # Caching is enabled - use class-level cache
         if cls._clip_model is None:
             logger.info("Loading CLIP model (this may take a few minutes on first use)...")
             cls._clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
@@ -139,9 +158,24 @@ class VisualProcessingService:
             cls._clip_model.to(device)
             cls._clip_model.eval()
             
-            logger.info("✓ CLIP model loaded successfully")
+            logger.info("✓ CLIP model loaded and cached successfully")
+        else:
+            logger.info("✓ Using cached CLIP model")
+            
         return cls._clip_model, cls._clip_processor
     
+    @classmethod
+    def clear_clip_cache(cls):
+        """Clear the cached CLIP model to free memory."""
+        if cls._clip_model is not None:
+            del cls._clip_model
+            del cls._clip_processor
+            cls._clip_model = None
+            cls._clip_processor = None
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            logger.info("✓ CLIP model cache cleared")
+    
     def download_image(self, image_url: str) -> Optional[Image.Image]:
         """Download image from URL."""
         try:
@@ -156,12 +190,10 @@ class VisualProcessingService:
     def extract_dominant_colors(self, image: Image.Image, n_colors: int = 3) -> List[Dict]:
         """Extract dominant colors using K-means clustering."""
         try:
-            # Resize for faster processing
             img_small = image.resize((150, 150))
             img_array = np.array(img_small)
             pixels = img_array.reshape(-1, 3)
             
-            # K-means clustering
             kmeans = KMeans(n_clusters=n_colors, random_state=42, n_init=5)
             kmeans.fit(pixels)
             
@@ -179,7 +211,6 @@ class VisualProcessingService:
                     "percentage": round(percentage, 2)
                 })
             
-            # Sort by percentage (most dominant first)
             colors.sort(key=lambda x: x['percentage'], reverse=True)
             return colors
             
@@ -191,7 +222,6 @@ class VisualProcessingService:
         """Map RGB values to basic color names."""
         r, g, b = rgb
         
-        # Define color ranges with priorities
         colors = {
             'black': (r < 50 and g < 50 and b < 50),
             'white': (r > 200 and g > 200 and b > 200),
@@ -212,7 +242,6 @@ class VisualProcessingService:
             if condition:
                 return color_name
         
-        # Fallback to dominant channel
         if r > g and r > b:
             return 'red'
         elif g > r and g > b:
@@ -234,14 +263,12 @@ class VisualProcessingService:
             model, processor = self._get_clip_model()
             device = self._get_device()
             
-            # ⚡ OPTIMIZATION: Process in smaller batches to avoid memory issues
-            batch_size = 16  # Process 16 candidates at a time
+            batch_size = 16
             all_results = []
             
             for i in range(0, len(candidates), batch_size):
                 batch_candidates = candidates[i:i + batch_size]
                 
-                # Prepare inputs WITHOUT progress bars
                 inputs = processor(
                     text=batch_candidates,
                     images=image,
@@ -249,16 +276,13 @@ class VisualProcessingService:
                     padding=True
                 )
                 
-                # Move to device
                 inputs = {k: v.to(device) for k, v in inputs.items()}
                 
-                # Get predictions
                 with torch.no_grad():
                     outputs = model(**inputs)
                     logits_per_image = outputs.logits_per_image
                     probs = logits_per_image.softmax(dim=1).cpu()
                 
-                # Collect results from this batch
                 for j, prob in enumerate(probs[0]):
                     if prob.item() > confidence_threshold:
                         all_results.append({
@@ -266,7 +290,6 @@ class VisualProcessingService:
                             "confidence": round(float(prob.item()), 3)
                         })
             
-            # Sort by confidence and return top 3
             all_results.sort(key=lambda x: x['confidence'], reverse=True)
             
             return {
@@ -278,16 +301,11 @@ class VisualProcessingService:
             logger.error(f"Error in CLIP classification for {attribute_name}: {str(e)}")
             return {"attribute": attribute_name, "predictions": []}
 
-
-
-
-
     def detect_category_and_subcategory(self, image: Image.Image) -> Tuple[str, str, str, float]:
         """
         Hierarchically detect category, subcategory, and specific product.
         Returns: (category, subcategory, product_type, confidence)
         """
-        # Step 1: Detect if it's clothing or something else
         main_categories = list(self.CATEGORY_ATTRIBUTES.keys())
         category_prompts = [f"a photo of {cat}" for cat in main_categories]
         
@@ -301,11 +319,9 @@ class VisualProcessingService:
         
         logger.info(f"Step 1 - Main category detected: {detected_category} (confidence: {category_confidence:.3f})")
         
-        # Step 2: For clothing, detect subcategory (tops/bottoms/dresses/outerwear)
         if detected_category == "clothing":
             subcategories = self.CATEGORY_ATTRIBUTES["clothing"]["subcategories"]
             
-            # Collect all products grouped by subcategory
             all_products = []
             product_to_subcategory = {}
             
@@ -315,7 +331,6 @@ class VisualProcessingService:
                     all_products.append(prompt)
                     product_to_subcategory[prompt] = subcat
             
-            # Step 3: Detect specific product type
             product_result = self.classify_with_clip(
                 image, 
                 all_products, 
@@ -336,11 +351,9 @@ class VisualProcessingService:
                 logger.warning("Could not detect specific product type for clothing")
                 return detected_category, "unknown", "unknown", category_confidence
         
-        # Step 3: For non-clothing categories, just detect product type
         else:
             category_data = self.CATEGORY_ATTRIBUTES[detected_category]
             
-            # Check if this category has subcategories or direct products
             if "products" in category_data:
                 products = category_data["products"]
                 product_prompts = [f"a photo of {p}" for p in products]
@@ -374,7 +387,6 @@ class VisualProcessingService:
         start_time = time.time()
         
         try:
-            # Download image
             image = self.download_image(image_url)
             if image is None:
                 return {
@@ -385,10 +397,8 @@ class VisualProcessingService:
             visual_attributes = {}
             detailed_predictions = {}
             
-            # Step 1: Detect category, subcategory, and product type
             category, subcategory, product_type, confidence = self.detect_category_and_subcategory(image)
             
-            # Low confidence check
             if confidence < 0.10:
                 logger.warning(f"Low confidence in detection ({confidence:.3f}). Returning basic attributes only.")
                 colors = self.extract_dominant_colors(image, n_colors=3)
@@ -403,13 +413,11 @@ class VisualProcessingService:
                     "processing_time": round(time.time() - start_time, 2)
                 }
             
-            # Add detected metadata
             visual_attributes["product_type"] = product_type
             visual_attributes["category"] = category
             if subcategory != "none" and subcategory != "unknown":
                 visual_attributes["subcategory"] = subcategory
             
-            # Step 2: Extract color information (universal)
             colors = self.extract_dominant_colors(image, n_colors=3)
             if colors:
                 visual_attributes["primary_color"] = colors[0]["name"]
@@ -419,7 +427,6 @@ class VisualProcessingService:
                     for c in colors
                 ]
             
-            # Step 3: Get the right attribute configuration based on subcategory
             attributes_config = None
             
             if category == "clothing":
@@ -434,7 +441,6 @@ class VisualProcessingService:
                     attributes_config = self.CATEGORY_ATTRIBUTES[category]["attributes"]
                     logger.info(f"Using attributes for category: {category}")
             
-            # Step 4: Extract category-specific attributes
             if attributes_config:
                 for attr_name, attr_values in attributes_config.items():
                     result = self.classify_with_clip(
@@ -446,11 +452,9 @@ class VisualProcessingService:
                     
                     if result["predictions"]:
                         best_prediction = result["predictions"][0]
-                        # Only add attributes with reasonable confidence
                         if best_prediction["confidence"] > 0.20:
                             visual_attributes[attr_name] = best_prediction["value"]
                         
-                        # Store detailed predictions for debugging
                         detailed_predictions[attr_name] = result
             
             processing_time = time.time() - start_time
@@ -461,7 +465,8 @@ class VisualProcessingService:
                 "visual_attributes": visual_attributes,
                 "detailed_predictions": detailed_predictions,
                 "detection_confidence": confidence,
-                "processing_time": round(processing_time, 2)
+                "processing_time": round(processing_time, 2),
+                "cache_status": "enabled" if ENABLE_CLIP_MODEL_CACHE else "disabled"
             }
             
         except Exception as e:
@@ -470,6 +475,4 @@ class VisualProcessingService:
                 "visual_attributes": {},
                 "error": str(e),
                 "processing_time": round(time.time() - start_time, 2)
-            }
-
-
+            }

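A minimal usage sketch for the new clear_clip_cache() classmethod (illustrative; assumes the import path shown):

    # Sketch: release the class-level CLIP model and free GPU memory if it was cached.
    from attr_extraction.visual_processing_service import VisualProcessingService

    VisualProcessingService.clear_clip_cache()
    # The next call to _get_clip_model() reloads CLIP from the hub and re-caches it
    # only when ENABLE_CLIP_MODEL_CACHE is True.
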
+ 10 - 1
content_quality_tool/settings.py

@@ -6,6 +6,13 @@ https://docs.djangoproject.com/en/5.2/topics/settings/
 For the full list of settings and their values, see
 https://docs.djangoproject.com/en/5.2/ref/settings/
 """
+
+import sys
+import io
+sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
+sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')
+
+
 from pathlib import Path
 import os
 from django.contrib.messages import constants as messages
@@ -122,4 +129,6 @@ MESSAGE_TAGS = {
 GROQ_API_KEY = "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"
 GROQ_API_URL = "https://api.groq.com/openai/v1/chat/completions"
 SUPPORTED_MODELS = ["llama-3.1-8b-instant", "llama-3.3-70b-versatile", "mixtral-8x7b-32768"]
-MAX_BATCH_SIZE = 100  # Maximum products per batch request
+MAX_BATCH_SIZE = 100  # Maximum products per batch request
+
+
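Regarding the stdout/stderr re-wrapping added near the top of settings.py: on Python 3.7+ an equivalent, less intrusive alternative (a sketch, not what this commit does) is to reconfigure the existing streams in place:

    # Alternative sketch (Python 3.7+): reconfigure the streams instead of re-wrapping them.
    import sys

    sys.stdout.reconfigure(encoding="utf-8")
    sys.stderr.reconfigure(encoding="utf-8")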