Browse Source

updated code: optimised model loading, caching, etc.

Harshit Pathak 3 months ago
parent
commit
049eb5ed74

+ 2 - 0
attr_extraction/__init__.py

@@ -0,0 +1,2 @@
+# ==================== __init__.py (in your app directory) ====================
+default_app_config = 'attr_extraction.apps.AttrExtractionConfig'
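+# NOTE: default_app_config is deprecated since Django 3.2 and ignored from 4.1;
+# modern Django discovers the AppConfig in apps.py automatically.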

+ 369 - 0
attr_extraction/apps.py

@@ -1,6 +1,375 @@
+# # ==================== apps.py ====================
+# from django.apps import AppConfig
+# import logging
+
+# logger = logging.getLogger(__name__)
+
+# class ProductAttributesConfig(AppConfig):
+#     default_auto_field = 'django.db.models.BigAutoField'
+#     name = 'attr_extraction'  # Replace with your actual app name
+    
+#     def ready(self):
+#         """
+#         🔥 CRITICAL: Pre-load all heavy models during Django startup
+#         This runs ONCE when the server starts, not on every request
+#         """
+#         import time
+#         from django.conf import settings
+        
+#         # Only load models if not in migration/management command
+#         import sys
+#         if 'migrate' in sys.argv or 'makemigrations' in sys.argv:
+#             return
+        
+#         logger.info("=" * 60)
+#         logger.info("🔥 WARMING UP ML MODELS (one-time startup delay)")
+#         logger.info("=" * 60)
+        
+#         startup_time = time.time()
+        
+#         # 1. Pre-load Sentence Transformer (already done at module level in services.py)
+#         logger.info("✓ Sentence Transformer: already loaded at module level")
+        
+#         # 2. Pre-load CLIP model
+#         try:
+#             clip_start = time.time()
+#             from .visual_processing_service import VisualProcessingService
+#             VisualProcessingService._get_clip_model()
+#             clip_time = time.time() - clip_start
+#             logger.info(f"✓ CLIP model loaded in {clip_time:.1f}s")
+#         except Exception as e:
+#             logger.warning(f"⚠️  CLIP model loading failed: {e}")
+        
+#         # 3. Pre-load OCR model
+#         try:
+#             ocr_start = time.time()
+#             from .ocr_service import OCRService
+#             ocr_service = OCRService()
+#             ocr_service._get_reader()
+#             ocr_time = time.time() - ocr_start
+#             logger.info(f"✓ OCR model loaded in {ocr_time:.1f}s")
+#         except Exception as e:
+#             logger.warning(f"⚠️  OCR model loading failed: {e}")
+        
+#         total_time = time.time() - startup_time
+        
+#         logger.info("=" * 60)
+#         logger.info(f"🎉 ALL MODELS READY in {total_time:.1f}s")
+#         logger.info("⚡ First API request will now be FAST (2-5 seconds)")
+#         logger.info("=" * 60)
+
+
+
+
+# # ==================== attr_extraction/apps.py ====================
+# from django.apps import AppConfig
+# import logging
+# import sys
+# import threading
+
+# logger = logging.getLogger(__name__)
+
+
+# class AttrExtractionConfig(AppConfig):  # ✅ This is the correct name
+#     default_auto_field = 'django.db.models.BigAutoField'
+#     name = 'attr_extraction'
+    
+#     # Flag to prevent double loading
+#     models_loaded = False
+    
+#     def ready(self):
+#         """
+#         🔥 Pre-load all heavy ML models during Django startup.
+#         Uses background thread to not block server startup.
+#         """
+#         # Skip during migrations/management commands
+#         if any(cmd in sys.argv for cmd in ['migrate', 'makemigrations', 'test', 'collectstatic', 'shell']):
+#             return
+        
+#         # Prevent double loading
+#         if AttrExtractionConfig.models_loaded:
+#             logger.info("⏭️  Models already loaded, skipping...")
+#             return
+        
+#         AttrExtractionConfig.models_loaded = True
+        
+#         # 🔥 Load models in background thread (non-blocking)
+#         thread = threading.Thread(target=self._load_models, daemon=True)
+#         thread.start()
+        
+#         logger.info("🔄 Model loading started in background...")
+    
+#     def _load_models(self):
+#         """Background thread to load heavy models."""
+#         import time
+        
+#         logger.info("=" * 70)
+#         logger.info("🔥 WARMING UP ML MODELS (background process)")
+#         logger.info("=" * 70)
+        
+#         startup_time = time.time()
+#         total_loaded = 0
+        
+#         # 1. Sentence Transformer
+#         try:
+#             logger.info("📥 Loading Sentence Transformer...")
+#             st_start = time.time()
+#             from .services import model_embedder
+#             st_time = time.time() - st_start
+#             logger.info(f"✓ Sentence Transformer ready ({st_time:.1f}s)")
+#             total_loaded += 1
+#         except Exception as e:
+#             logger.error(f"❌ Sentence Transformer failed: {e}")
+        
+#         # 2. Pre-load CLIP model
+#         try:
+#             logger.info("📥 Loading CLIP model (20-30s)...")
+#             clip_start = time.time()
+#             from .visual_processing_service import VisualProcessingService
+#             VisualProcessingService._get_clip_model()
+#             clip_time = time.time() - clip_start
+#             logger.info(f"✓ CLIP model cached ({clip_time:.1f}s)")
+#             total_loaded += 1
+#         except Exception as e:
+#             logger.error(f"❌ CLIP model failed: {e}")
+        
+#         # 3. Pre-load OCR model
+#         try:
+#             logger.info("📥 Loading EasyOCR model...")
+#             ocr_start = time.time()
+#             from .ocr_service import OCRService
+#             ocr_service = OCRService()
+#             ocr_service._get_reader()
+#             ocr_time = time.time() - ocr_start
+#             logger.info(f"✓ OCR model cached ({ocr_time:.1f}s)")
+#             total_loaded += 1
+#         except Exception as e:
+#             logger.error(f"❌ OCR model failed: {e}")
+        
+#         total_time = time.time() - startup_time
+        
+#         logger.info("=" * 70)
+#         logger.info(f"🎉 {total_loaded}/3 MODELS LOADED in {total_time:.1f}s")
+#         logger.info("⚡ API requests are now FAST (2-5 seconds)")
+#         logger.info("=" * 70)
+
+
+
+
+
+
+# # ==================== attr_extraction/apps.py ====================
+# from django.apps import AppConfig
+# import logging
+# import sys
+# import os
+# import threading
+# from django.core.cache import cache  # ✅ Import Django cache
+
+
+# logger = logging.getLogger(__name__)
+
+
+# class AttrExtractionConfig(AppConfig):
+#     default_auto_field = 'django.db.models.BigAutoField'
+#     name = 'attr_extraction'
+    
+#     # Flag to prevent double loading
+#     models_loaded = False
+    
+#     def ready(self):
+#         """
+#         🔥 Pre-load all heavy ML models during Django startup.
+#         """
+#         # Skip during migrations/management commands
+#         if any(cmd in sys.argv for cmd in ['migrate', 'makemigrations', 'test', 'collectstatic', 'shell']):
+#             return
+        
+#         # 🔥 CRITICAL: Skip in Django autoreloader parent process
+#         # Only run in the actual worker process
+#         if os.environ.get('RUN_MAIN') != 'true':
+#             logger.info("⏭️  Skipping model loading in autoreloader parent process")
+#             return
+        
+#         # Prevent double loading
+#         if AttrExtractionConfig.models_loaded:
+#             logger.info("⏭️  Models already loaded, skipping...")
+#             return
+        
+#         AttrExtractionConfig.models_loaded = True
+        
+#         # 🔥 Load models in background thread (non-blocking)
+#         thread = threading.Thread(target=self._load_models, daemon=True)
+#         thread.start()
+        
+#         logger.info("🔄 Model loading started in background...")
+    
+#     def _load_models(self):
+#         """Background thread to load heavy models."""
+#         import time
+        
+#         logger.info("=" * 70)
+#         logger.info("🔥 WARMING UP ML MODELS (background process)")
+#         logger.info("=" * 70)
+        
+#         startup_time = time.time()
+#         total_loaded = 0
+        
+#         # 1. Sentence Transformer
+#         try:
+#             logger.info("📥 Loading Sentence Transformer...")
+#             st_start = time.time()
+#             from .services import model_embedder
+#             st_time = time.time() - st_start
+#             logger.info(f"✓ Sentence Transformer ready ({st_time:.1f}s)")
+#             total_loaded += 1
+#         except Exception as e:
+#             logger.error(f"❌ Sentence Transformer failed: {e}")
+        
+#         # 2. Pre-load CLIP model
+#         try:
+#             logger.info("📥 Loading CLIP model (20-30s)...")
+#             clip_start = time.time()
+#             from .visual_processing_service import VisualProcessingService
+#             VisualProcessingService._get_clip_model()
+#             clip_time = time.time() - clip_start
+#             logger.info(f"✓ CLIP model cached ({clip_time:.1f}s)")
+#             total_loaded += 1
+#         except Exception as e:
+#             logger.error(f"❌ CLIP model failed: {e}")
+        
+#         # 3. Pre-load OCR model
+#         try:
+#             logger.info("📥 Loading EasyOCR model...")
+#             ocr_start = time.time()
+#             from .ocr_service import OCRService
+#             ocr_service = OCRService()
+#             ocr_service._get_reader()
+#             ocr_time = time.time() - ocr_start
+#             logger.info(f"✓ OCR model cached ({ocr_time:.1f}s)")
+#             total_loaded += 1
+#         except Exception as e:
+#             logger.error(f"❌ OCR model failed: {e}")
+        
+#         total_time = time.time() - startup_time
+        
+#         logger.info("=" * 70)
+#         logger.info(f"🎉 {total_loaded}/3 MODELS LOADED in {total_time:.1f}s")
+#         logger.info("⚡ API requests are now FAST (2-5 seconds)")
+#         logger.info("=" * 70)
+
+
+
+
+
+
+
+# ==================== attr_extraction/apps.py ====================
 from django.apps import AppConfig
+import logging
+import sys
+import os
+import threading
+
+from django.core.cache import cache  # ✅ Import Django cache
+
+logger = logging.getLogger(__name__)
 
 
 class AttrExtractionConfig(AppConfig):
     default_auto_field = 'django.db.models.BigAutoField'
     name = 'attr_extraction'
+    
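+    # Process-wide flag so ready() only schedules the warm-up once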
+    models_loaded = False
+    
+    def ready(self):
+        """
+        🔥 Pre-load all heavy ML models during Django startup.
+        Also clears Django cache once when the server starts.
+        """
+        # Skip during migrations/management commands
+        if any(cmd in sys.argv for cmd in ['migrate', 'makemigrations', 'test', 'collectstatic', 'shell']):
+            return
+        
+        # Skip in Django autoreloader parent process
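+        # (RUN_MAIN is set only by runserver's autoreloader; under gunicorn or
+        # uWSGI it is absent, so this guard skips preloading there unless adapted.)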
+        if os.environ.get('RUN_MAIN') != 'true':
+            logger.info("⏭️  Skipping model loading in autoreloader parent process")
+            return
+        
+        # ✅ Clear cache once per startup
+        try:
+            cache.clear()
+            logger.info("🧹 Django cache cleared successfully on startup.")
+        except Exception as e:
+            logger.warning(f"⚠️  Failed to clear cache: {e}")
+        
+        # Prevent double loading
+        if AttrExtractionConfig.models_loaded:
+            logger.info("⏭️  Models already loaded, skipping...")
+            return
+        
+        AttrExtractionConfig.models_loaded = True
+        
+        # Load models in background thread (non-blocking)
+        thread = threading.Thread(target=self._load_models, daemon=True)
+        thread.start()
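+        # Requests arriving before the warm-up thread finishes fall back to the
+        # same un-locked lazy loaders, so they may duplicate a load.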
+        
+        logger.info("🔄 Model loading started in background...")
+    
+    def _load_models(self):
+        """Background thread to load heavy models."""
+        import time
+        
+        logger.info("=" * 70)
+        logger.info("🔥 WARMING UP ML MODELS (background process)")
+        logger.info("=" * 70)
+        
+        startup_time = time.time()
+        total_loaded = 0
+        
+        # 1. Sentence Transformer
+        try:
+            logger.info("📥 Loading Sentence Transformer...")
+            st_start = time.time()
+            from .services import model_embedder
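+            # Importing services is the warm-up itself: the model loads at
+            # module import time, so model_embedder needs no further use here.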
+            st_time = time.time() - st_start
+            logger.info(f"✓ Sentence Transformer ready ({st_time:.1f}s)")
+            total_loaded += 1
+        except Exception as e:
+            logger.error(f"❌ Sentence Transformer failed: {e}")
+        
+        # 2. Pre-load CLIP model
+        try:
+            logger.info("📥 Loading CLIP model (20-30s)...")
+            clip_start = time.time()
+            from .visual_processing_service import VisualProcessingService
+            VisualProcessingService._get_clip_model()
+            clip_time = time.time() - clip_start
+            logger.info(f"✓ CLIP model cached ({clip_time:.1f}s)")
+            total_loaded += 1
+        except Exception as e:
+            logger.error(f"❌ CLIP model failed: {e}")
+        
+        # 3. Pre-load OCR model
+        try:
+            logger.info("📥 Loading EasyOCR model...")
+            ocr_start = time.time()
+            from .ocr_service import OCRService
+            ocr_service = OCRService()
+            ocr_service._get_reader()
+            ocr_time = time.time() - ocr_start
+            logger.info(f"✓ OCR model cached ({ocr_time:.1f}s)")
+            total_loaded += 1
+        except Exception as e:
+            logger.error(f"❌ OCR model failed: {e}")
+        
+        total_time = time.time() - startup_time
+        
+        logger.info("=" * 70)
+        logger.info(f"🎉 {total_loaded}/3 MODELS LOADED in {total_time:.1f}s")
+        logger.info("⚡ API requests are now FAST (2-5 seconds)")
+        logger.info("=" * 70)

+ 1 - 1
attr_extraction/cache_config.py

@@ -14,7 +14,7 @@ ENABLE_CLIP_MODEL_CACHE = ENABLE_CACHING
 
 # Cache size limits (only used when caching is enabled)
 ATTRIBUTE_CACHE_MAX_SIZE = 1000
-EMBEDDING_CACHE_MAX_SIZE = 500
+EMBEDDING_CACHE_MAX_SIZE = 5000
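+# (all-MiniLM-L6-v2 embeddings are 384-dim float32 tensors, ~1.5 KB each, so
+# 5000 cached entries stay under roughly 10 MB.)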
 
 def is_caching_enabled() -> bool:
     """

+ 19 - 5
attr_extraction/ocr_service.py

@@ -15,14 +15,28 @@ logger = logging.getLogger(__name__)
 class OCRService:
     """Service for extracting text from product images using OCR."""
     
+    # 🔥 Class-level cache (persists across requests)
+    _shared_reader = None
+    
     def __init__(self):
-        self.reader = None
+        # Don't initialize here - use lazy loading with class cache
+        pass
     
     def _get_reader(self):
-        """Lazy load EasyOCR reader."""
-        if self.reader is None:
-            self.reader = easyocr.Reader(['en'], gpu=False)
-        return self.reader
+        """🔥 Lazy load EasyOCR reader with class-level caching."""
+        if OCRService._shared_reader is None:
+            import time
+            start = time.time()
+            logger.info("📥 Loading EasyOCR model...")
+            
+            OCRService._shared_reader = easyocr.Reader(['en'], gpu=False)
+            
+            load_time = time.time() - start
+            logger.info(f"✓ EasyOCR loaded in {load_time:.1f}s and cached in memory")
+        else:
+            logger.debug("✓ Using cached EasyOCR reader")
+        
+        return OCRService._shared_reader
     
     def download_image(self, image_url: str) -> Optional[np.ndarray]:
         """Download image from URL and convert to OpenCV format."""

+ 242 - 829
attr_extraction/services.py

@@ -1,645 +1,216 @@
-# ==================== services.py (WITH CACHE CONTROL) ====================
-import requests
+# ==================== services.py (FINAL PERFECT + FULL WHITELIST + SEMANTIC RECOVERY) ====================
 import json
-import re
 import hashlib
 import logging
+import warnings
+import time
+from functools import wraps
 from typing import Dict, List, Optional, Tuple
+import os
+import requests
 from django.conf import settings
-from concurrent.futures import ThreadPoolExecutor, as_completed
 from sentence_transformers import SentenceTransformer, util
-import numpy as np
 
-# ⚡ IMPORT CACHE CONFIGURATION
+# --------------------------------------------------------------------------- #
+# CACHE CONFIG
+# --------------------------------------------------------------------------- #
 from .cache_config import (
     is_caching_enabled,
     ENABLE_ATTRIBUTE_EXTRACTION_CACHE,
     ENABLE_EMBEDDING_CACHE,
     ATTRIBUTE_CACHE_MAX_SIZE,
-    EMBEDDING_CACHE_MAX_SIZE
+    EMBEDDING_CACHE_MAX_SIZE,
 )
 
 logger = logging.getLogger(__name__)
 
-# ⚡ CRITICAL FIX: Initialize embedding model ONCE at module level
-print("Loading sentence transformer model (one-time initialization)...")
+# --------------------------------------------------------------------------- #
+# ONE-TIME MODEL LOAD
+# --------------------------------------------------------------------------- #
+print("Loading sentence transformer model (semantic recovery)...")
 model_embedder = SentenceTransformer("all-MiniLM-L6-v2")
-# Disable progress bars to prevent "Batches: 100%" spam
-import os
-os.environ['TOKENIZERS_PARALLELISM'] = 'false'
-print("✓ Model loaded successfully")
-
-
-# ==================== CACHING CLASSES ====================
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+print("Model loaded")
 
+# --------------------------------------------------------------------------- #
+# CACHES
+# --------------------------------------------------------------------------- #
 class SimpleCache:
-    """In-memory cache for attribute extraction results."""
     _cache = {}
     _max_size = ATTRIBUTE_CACHE_MAX_SIZE
-    
+
     @classmethod
     def get(cls, key: str) -> Optional[Dict]:
-        """Get value from cache. Returns None if caching is disabled."""
-        if not ENABLE_ATTRIBUTE_EXTRACTION_CACHE:
-            return None
+        if not ENABLE_ATTRIBUTE_EXTRACTION_CACHE: return None
         return cls._cache.get(key)
-    
+
     @classmethod
     def set(cls, key: str, value: Dict):
-        """Set value in cache. Does nothing if caching is disabled."""
-        if not ENABLE_ATTRIBUTE_EXTRACTION_CACHE:
-            return
-            
+        if not ENABLE_ATTRIBUTE_EXTRACTION_CACHE: return
         if len(cls._cache) >= cls._max_size:
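+            # Evict the oldest 20% of entries (dicts keep insertion order)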
             items = list(cls._cache.items())
             cls._cache = dict(items[int(cls._max_size * 0.2):])
         cls._cache[key] = value
-    
+
     @classmethod
-    def clear(cls):
-        """Clear the cache."""
-        cls._cache.clear()
-    
+    def clear(cls): cls._cache.clear()
+
     @classmethod
     def get_stats(cls) -> Dict:
-        """Get cache statistics."""
         return {
             "enabled": ENABLE_ATTRIBUTE_EXTRACTION_CACHE,
             "size": len(cls._cache),
             "max_size": cls._max_size,
-            "usage_percent": round(len(cls._cache) / cls._max_size * 100, 2) if cls._max_size > 0 else 0
+            "usage_percent": round(len(cls._cache)/cls._max_size*100, 2) if cls._max_size else 0
         }
 
-
 class EmbeddingCache:
-    """Cache for sentence transformer embeddings."""
     _cache = {}
     _max_size = EMBEDDING_CACHE_MAX_SIZE
-    _hit_count = 0
-    _miss_count = 0
-    
+    _hit = _miss = 0
+
     @classmethod
     def get_embedding(cls, text: str, model):
-        """Get or compute embedding with optional caching"""
-        # If caching is disabled, always compute fresh
         if not ENABLE_EMBEDDING_CACHE:
-            import warnings
             with warnings.catch_warnings():
                 warnings.simplefilter("ignore")
-                embedding = model.encode(text, convert_to_tensor=True, show_progress_bar=False)
-            return embedding
-        
-        # Caching is enabled, check cache first
+                return model.encode(text, convert_to_tensor=True, show_progress_bar=False)
         if text in cls._cache:
-            cls._hit_count += 1
+            cls._hit += 1
             return cls._cache[text]
-        
-        cls._miss_count += 1
-        
+        cls._miss += 1
         if len(cls._cache) >= cls._max_size:
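+            # Evict the oldest 30% (insertion-ordered dict gives rough FIFO)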
             items = list(cls._cache.items())
             cls._cache = dict(items[int(cls._max_size * 0.3):])
-        
-        # Compute embedding
-        import warnings
         with warnings.catch_warnings():
             warnings.simplefilter("ignore")
-            embedding = model.encode(text, convert_to_tensor=True, show_progress_bar=False)
-        
-        cls._cache[text] = embedding
-        return embedding
-    
+            emb = model.encode(text, convert_to_tensor=True, show_progress_bar=False)
+        cls._cache[text] = emb
+        return emb
+
     @classmethod
     def clear(cls):
-        """Clear the cache and reset statistics."""
         cls._cache.clear()
-        cls._hit_count = 0
-        cls._miss_count = 0
-    
+        cls._hit = cls._miss = 0
+
     @classmethod
     def get_stats(cls) -> Dict:
-        """Get cache statistics."""
-        total = cls._hit_count + cls._miss_count
-        hit_rate = (cls._hit_count / total * 100) if total > 0 else 0
+        total = cls._hit + cls._miss
+        rate = (cls._hit / total * 100) if total else 0
         return {
             "enabled": ENABLE_EMBEDDING_CACHE,
             "size": len(cls._cache),
-            "hits": cls._hit_count,
-            "misses": cls._miss_count,
-            "hit_rate_percent": round(hit_rate, 2)
+            "hits": cls._hit,
+            "misses": cls._miss,
+            "hit_rate_percent": round(rate, 2),
         }
 
-
-# ==================== MAIN SERVICE CLASS ====================
-
+# --------------------------------------------------------------------------- #
+# RETRY DECORATOR
+# --------------------------------------------------------------------------- #
+def retry(max_attempts=3, delay=1.0):
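+    # Retries with exponential backoff: waits delay, 2*delay, 4*delay, ...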
+    def decorator(f):
+        @wraps(f)
+        def wrapper(*args, **kwargs):
+            last_exc = None
+            for i in range(max_attempts):
+                try:
+                    return f(*args, **kwargs)
+                except Exception as e:
+                    last_exc = e
+                    if i < max_attempts - 1:
+                        wait = delay * (2 ** i)
+                        logger.warning(f"Retry {i+1}/{max_attempts} after {wait}s: {e}")
+                        time.sleep(wait)
+            raise last_exc or RuntimeError("Retry failed")
+        return wrapper
+    return decorator
+
+# --------------------------------------------------------------------------- #
+# MAIN SERVICE
+# --------------------------------------------------------------------------- #
 class ProductAttributeService:
-    """Service class for extracting product attributes using Groq LLM."""
-
-    @staticmethod
-    def _generate_cache_key(product_text: str, mandatory_attrs: Dict) -> str:
-        """Generate cache key from product text and attributes."""
-        attrs_str = json.dumps(mandatory_attrs, sort_keys=True)
-        content = f"{product_text}:{attrs_str}"
-        return f"attr_{hashlib.md5(content.encode()).hexdigest()}"
-
-    @staticmethod
-    def normalize_dimension_text(text: str) -> str:
-        """Normalize dimension text to format like '16x20'."""
-        if not text:
-            return ""
-        
-        text = text.lower()
-        text = re.sub(r'\s*(inches|inch|in|cm|centimeters|mm|millimeters)\s*', '', text, flags=re.IGNORECASE)
-        
-        numbers = re.findall(r'\d+\.?\d*', text)
-        if not numbers:
-            return ""
-        
-        float_numbers = []
-        for num in numbers:
-            try:
-                float_numbers.append(float(num))
-            except:
-                continue
-        
-        if len(float_numbers) < 2:
-            return ""
-        
-        if len(float_numbers) == 3:
-            float_numbers = [float_numbers[0], float_numbers[2]]
-        elif len(float_numbers) > 3:
-            float_numbers = sorted(float_numbers)[-2:]
-        else:
-            float_numbers = float_numbers[:2]
-        
-        formatted_numbers = []
-        for num in float_numbers:
-            if num.is_integer():
-                formatted_numbers.append(str(int(num)))
-            else:
-                formatted_numbers.append(f"{num:.1f}")
-        
-        formatted_numbers.sort(key=lambda x: float(x))
-        return f"{formatted_numbers[0]}x{formatted_numbers[1]}"
-    
-    @staticmethod
-    def normalize_value_for_matching(value: str, attr_name: str = "") -> str:
-        """Normalize a value based on its attribute type."""
-        dimension_keywords = ['dimension', 'size', 'measurement']
-        if any(keyword in attr_name.lower() for keyword in dimension_keywords):
-            normalized = ProductAttributeService.normalize_dimension_text(value)
-            if normalized:
-                return normalized
-        return value.strip()
-
     @staticmethod
-    def combine_product_text(
-        title: Optional[str] = None,
-        short_desc: Optional[str] = None,
-        long_desc: Optional[str] = None,
-        ocr_text: Optional[str] = None
-    ) -> Tuple[str, Dict[str, str]]:
-        """Combine product metadata into a single text block."""
+    def combine_product_text(title=None, short_desc=None, long_desc=None, ocr_text=None) -> Tuple[str, Dict[str, str]]:
         parts = []
         source_map = {}
-        
         if title:
-            title_str = str(title).strip()
-            parts.append(f"Title: {title_str}")
-            source_map['title'] = title_str
+            t = str(title).strip()
+            parts.append(f"Title: {t}")
+            source_map["title"] = t
         if short_desc:
-            short_str = str(short_desc).strip()
-            parts.append(f"Description: {short_str}")
-            source_map['short_desc'] = short_str
+            s = str(short_desc).strip()
+            parts.append(f"Description: {s}")
+            source_map["short_desc"] = s
         if long_desc:
-            long_str = str(long_desc).strip()
-            parts.append(f"Details: {long_str}")
-            source_map['long_desc'] = long_str
+            l = str(long_desc).strip()
+            parts.append(f"Details: {l}")
+            source_map["long_desc"] = l
         if ocr_text:
             parts.append(f"OCR Text: {ocr_text}")
-            source_map['ocr_text'] = ocr_text
-        
+            source_map["ocr_text"] = ocr_text
         combined = "\n".join(parts).strip()
-        
-        if not combined:
-            return "No product information available", {}
-        
-        return combined, source_map
+        return (combined or "No product information", source_map)
+
+    @staticmethod
+    def _cache_key(product_text: str, mandatory_attrs: Dict, extract_additional: bool, multiple: List[str]) -> str:
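+        # Key covers text, allowed values, and call shape, so calls differing
+        # only in `multiple` or `extract_additional` never collide.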
+        payload = {"text": product_text, "attrs": mandatory_attrs, "extra": extract_additional, "multiple": sorted(multiple)}
+        return f"attr_{hashlib.md5(json.dumps(payload, sort_keys=True).encode()).hexdigest()}"
+
+    @staticmethod
+    def _clean_json(text: str) -> str:
+        # Strip markdown fences first, then slice to the outermost braces
+        if "```json" in text:
+            text = text.split("```json", 1)[1].split("```", 1)[0]
+        elif "```" in text:
+            text = text.split("```", 1)[1].split("```", 1)[0]
+            text = text.strip()
+            if text.startswith("json"):
+                text = text[4:]
+        start = text.find("{")
+        end = text.rfind("}") + 1
+        if start != -1 and end > start:
+            text = text[start:end]
+        return text.strip()
+
+    @staticmethod
+    def _lexical_evidence(product_text: str, label: str) -> float:
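+        # Fraction of the label's tokens found verbatim in the product text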
+        pt = product_text.lower()
+        tokens = [t for t in label.lower().replace("-", " ").split() if t]
+        if not tokens: return 0.0
+        hits = sum(1 for t in tokens if t in pt)
+        return hits / len(tokens)
 
     @staticmethod
-    def find_value_source(value: str, source_map: Dict[str, str], attr_name: str = "") -> str:
-        """Find which source(s) contain the given value."""
+    def _find_source(value: str, source_map: Dict[str, str]) -> str:
         value_lower = value.lower()
-        value_tokens = set(value_lower.replace("-", " ").replace("x", " ").split())
-        
-        is_dimension_attr = any(keyword in attr_name.lower() for keyword in ['dimension', 'size', 'measurement'])
-        
-        sources_found = []
-        source_scores = {}
-        
-        for source_name, source_text in source_map.items():
-            source_lower = source_text.lower()
-            
-            if value_lower in source_lower:
-                source_scores[source_name] = 1.0
-                continue
-            
-            if is_dimension_attr:
-                normalized_value = ProductAttributeService.normalize_dimension_text(value)
-                if not normalized_value:
-                    normalized_value = value.replace("x", " ").strip()
-                
-                normalized_source = ProductAttributeService.normalize_dimension_text(source_text)
-                
-                if normalized_value == normalized_source:
-                    source_scores[source_name] = 0.95
-                    continue
-                
-                dim_parts = normalized_value.split("x") if "x" in normalized_value else []
-                if len(dim_parts) == 2:
-                    if all(part in source_text for part in dim_parts):
-                        source_scores[source_name] = 0.85
-                        continue
-            
-            token_matches = sum(1 for token in value_tokens if token and token in source_lower)
-            if token_matches > 0 and len(value_tokens) > 0:
-                source_scores[source_name] = token_matches / len(value_tokens)
-        
-        if source_scores:
-            max_score = max(source_scores.values())
-            sources_found = [s for s, score in source_scores.items() if score == max_score]
-            
-            priority = ['title', 'short_desc', 'long_desc', 'ocr_text']
-            for p in priority:
-                if p in sources_found:
-                    return p
-            
-            return sources_found[0] if sources_found else "Not found"
-        
-        return "Not found"
+        for src_key, text in source_map.items():
+            if value_lower in text.lower():
+                return src_key
+        return "not_found"
 
     @staticmethod
     def format_visual_attributes(visual_attributes: Dict) -> Dict:
-        """Convert visual attributes to array format with source tracking."""
         formatted = {}
-        
         for key, value in visual_attributes.items():
             if isinstance(value, list):
                 formatted[key] = [{"value": str(item), "source": "image"} for item in value]
             elif isinstance(value, dict):
-                nested_formatted = {}
-                for nested_key, nested_value in value.items():
-                    if isinstance(nested_value, list):
-                        nested_formatted[nested_key] = [{"value": str(item), "source": "image"} for item in nested_value]
+                nested = {}
+                for sub_key, sub_val in value.items():
+                    if isinstance(sub_val, list):
+                        nested[sub_key] = [{"value": str(v), "source": "image"} for v in sub_val]
                     else:
-                        nested_formatted[nested_key] = [{"value": str(nested_value), "source": "image"}]
-                formatted[key] = nested_formatted
+                        nested[sub_key] = [{"value": str(sub_val), "source": "image"}]
+                formatted[key] = nested
             else:
                 formatted[key] = [{"value": str(value), "source": "image"}]
-        
         return formatted
 
     @staticmethod
-    def extract_attributes_from_ocr(ocr_results: Dict, model: str = None) -> Dict:
-        """Extract structured attributes from OCR text using LLM."""
-        if model is None:
-            model = settings.SUPPORTED_MODELS[0]
-        
-        detected_text = ocr_results.get('detected_text', [])
-        if not detected_text:
-            return {}
-        
-        ocr_text = "\n".join([f"Text: {item['text']}, Confidence: {item['confidence']:.2f}" 
-                              for item in detected_text])
-        
-        prompt = f"""
-You are an AI model that extracts structured attributes from OCR text detected on product images.
-Given the OCR detections below, infer the possible product attributes and return them as a clean JSON object.
-
-OCR Text:
-{ocr_text}
-
-Extract relevant attributes like:
-- brand
-- model_number
-- size (waist_size, length, etc.)
-- collection
-- any other relevant product information
-
-Return a JSON object with only the attributes you can confidently identify.
-If an attribute is not present, do not include it in the response.
-"""
-        
-        payload = {
-            "model": model,
-            "messages": [
-                {
-                    "role": "system",
-                    "content": "You are a helpful AI that extracts structured data from OCR output. Return only valid JSON."
-                },
-                {"role": "user", "content": prompt}
-            ],
-            "temperature": 0.2,
-            "max_tokens": 500
-        }
-        
-        headers = {
-            "Authorization": f"Bearer {settings.GROQ_API_KEY}",
-            "Content-Type": "application/json",
-        }
-        
-        try:
-            response = requests.post(
-                settings.GROQ_API_URL,
-                headers=headers,
-                json=payload,
-                timeout=30
-            )
-            response.raise_for_status()
-            result_text = response.json()["choices"][0]["message"]["content"].strip()
-            
-            result_text = ProductAttributeService._clean_json_response(result_text)
-            parsed = json.loads(result_text)
-            
-            formatted_attributes = {}
-            for key, value in parsed.items():
-                if key == "error":
-                    continue
-                
-                if isinstance(value, dict):
-                    nested_formatted = {}
-                    for nested_key, nested_value in value.items():
-                        nested_formatted[nested_key] = [{"value": str(nested_value), "source": "image"}]
-                    formatted_attributes[key] = nested_formatted
-                elif isinstance(value, list):
-                    formatted_attributes[key] = [{"value": str(item), "source": "image"} for item in value]
-                else:
-                    formatted_attributes[key] = [{"value": str(value), "source": "image"}]
-            
-            return formatted_attributes
-        except Exception as e:
-            logger.error(f"OCR attribute extraction failed: {str(e)}")
-            return {"error": f"Failed to extract attributes from OCR: {str(e)}"}
-
-    @staticmethod
-    def calculate_attribute_relationships(
-        mandatory_attrs: Dict[str, List[str]],
-        product_text: str
-    ) -> Dict[str, float]:
-        """Calculate semantic relationships between attribute values."""
-        pt_emb = EmbeddingCache.get_embedding(product_text, model_embedder)
-
-        attr_scores = {}
-        for attr, values in mandatory_attrs.items():
-            attr_scores[attr] = {}
-            for val in values:
-                contexts = [val, f"for {val}", f"use in {val}", f"suitable for {val}"]
-                ctx_embs = [EmbeddingCache.get_embedding(c, model_embedder) for c in contexts]
-                sem_sim = max(float(util.cos_sim(pt_emb, ce).item()) for ce in ctx_embs)
-                attr_scores[attr][val] = sem_sim
-
-        relationships = {}
-        attr_list = list(mandatory_attrs.keys())
-
-        for i, attr1 in enumerate(attr_list):
-            for attr2 in attr_list[i+1:]:
-                for val1 in mandatory_attrs[attr1]:
-                    for val2 in mandatory_attrs[attr2]:
-                        emb1 = EmbeddingCache.get_embedding(val1, model_embedder)
-                        emb2 = EmbeddingCache.get_embedding(val2, model_embedder)
-                        sim = float(util.cos_sim(emb1, emb2).item())
-
-                        key1 = f"{attr1}:{val1}->{attr2}:{val2}"
-                        key2 = f"{attr2}:{val2}->{attr1}:{val1}"
-                        relationships[key1] = sim
-                        relationships[key2] = sim
-
-        return relationships
-
-    @staticmethod
-    def calculate_value_clusters(
-        values: List[str],
-        scores: List[Tuple[str, float]],
-        cluster_threshold: float = 0.4
-    ) -> List[List[str]]:
-        """Group values into semantic clusters."""
-        if len(values) <= 1:
-            return [[val] for val, _ in scores]
-
-        embeddings = [EmbeddingCache.get_embedding(val, model_embedder) for val in values]
-
-        similarity_matrix = np.zeros((len(values), len(values)))
-        for i in range(len(values)):
-            for j in range(i+1, len(values)):
-                sim = float(util.cos_sim(embeddings[i], embeddings[j]).item())
-                similarity_matrix[i][j] = sim
-                similarity_matrix[j][i] = sim
-
-        clusters = []
-        visited = set()
-
-        for i, (val, score) in enumerate(scores):
-            if i in visited:
-                continue
-
-            cluster = [val]
-            visited.add(i)
-
-            for j in range(len(values)):
-                if j not in visited and similarity_matrix[i][j] >= cluster_threshold:
-                    cluster.append(values[j])
-                    visited.add(j)
-
-            clusters.append(cluster)
-
-        return clusters
-
-    @staticmethod
-    def get_dynamic_threshold(
-        attr: str,
-        val: str,
-        base_score: float,
-        extracted_attrs: Dict[str, List[Dict[str, str]]],
-        relationships: Dict[str, float],
-        mandatory_attrs: Dict[str, List[str]],
-        base_threshold: float = 0.65,
-        boost_factor: float = 0.15
-    ) -> float:
-        """Calculate dynamic threshold based on relationships."""
-        threshold = base_threshold
-
-        max_relationship = 0.0
-        for other_attr, other_values_list in extracted_attrs.items():
-            if other_attr == attr:
-                continue
-
-            for other_val_dict in other_values_list:
-                other_val = other_val_dict['value']
-                key = f"{attr}:{val}->{other_attr}:{other_val}"
-                if key in relationships:
-                    max_relationship = max(max_relationship, relationships[key])
-
-        if max_relationship > 0.6:
-            threshold = base_threshold - (boost_factor * max_relationship)
-
-        return max(0.3, threshold)
-
-    @staticmethod
-    def get_adaptive_margin(
-        scores: List[Tuple[str, float]],
-        base_margin: float = 0.15,
-        max_margin: float = 0.22
-    ) -> float:
-        """Calculate adaptive margin based on score distribution."""
-        if len(scores) < 2:
-            return base_margin
-
-        score_values = [s for _, s in scores]
-        best_score = score_values[0]
-
-        if best_score < 0.5:
-            top_scores = score_values[:min(4, len(score_values))]
-            score_range = max(top_scores) - min(top_scores)
-
-            if score_range < 0.30:
-                score_factor = (0.5 - best_score) * 0.35
-                adaptive = base_margin + score_factor + (0.30 - score_range) * 0.2
-                return min(adaptive, max_margin)
-
-        return base_margin
-
-    @staticmethod
-    def _lexical_evidence(product_text: str, label: str) -> float:
-        """Calculate lexical overlap between product text and label."""
-        pt = product_text.lower()
-        tokens = [t for t in label.lower().replace("-", " ").split() if t]
-        if not tokens:
-            return 0.0
-        hits = sum(1 for t in tokens if t in pt)
-        return hits / len(tokens)
-
-    @staticmethod
-    def normalize_against_product_text(
-        product_text: str,
-        mandatory_attrs: Dict[str, List[str]],
-        source_map: Dict[str, str],
-        threshold_abs: float = 0.65,
-        margin: float = 0.15,
-        allow_multiple: bool = False,
-        sem_weight: float = 0.8,
-        lex_weight: float = 0.2,
-        extracted_attrs: Optional[Dict[str, List[Dict[str, str]]]] = None,
-        relationships: Optional[Dict[str, float]] = None,
-        use_dynamic_thresholds: bool = True,
-        use_adaptive_margin: bool = True,
-        use_semantic_clustering: bool = True
-    ) -> dict:
-        """Score each allowed value against the product_text."""
-        if extracted_attrs is None:
-            extracted_attrs = {}
-        if relationships is None:
-            relationships = {}
-
-        pt_emb = EmbeddingCache.get_embedding(product_text, model_embedder)
-        extracted = {}
-
-        for attr, allowed_values in mandatory_attrs.items():
-            scores: List[Tuple[str, float]] = []
-            
-            is_dimension_attr = any(keyword in attr.lower() for keyword in ['dimension', 'size', 'measurement'])
-            normalized_product_text = ProductAttributeService.normalize_dimension_text(product_text) if is_dimension_attr else ""
-
-            for val in allowed_values:
-                if is_dimension_attr:
-                    normalized_val = ProductAttributeService.normalize_dimension_text(val)
-                    
-                    if normalized_val and normalized_product_text and normalized_val == normalized_product_text:
-                        scores.append((val, 1.0))
-                        continue
-                    
-                    if normalized_val:
-                        val_numbers = normalized_val.split('x')
-                        text_lower = product_text.lower()
-                        if all(num in text_lower for num in val_numbers):
-                            idx1 = text_lower.find(val_numbers[0])
-                            idx2 = text_lower.find(val_numbers[1])
-                            if idx1 != -1 and idx2 != -1:
-                                distance = abs(idx2 - idx1)
-                                if distance < 20:
-                                    scores.append((val, 0.95))
-                                    continue
-                
-                contexts = [val, f"for {val}", f"use in {val}", f"suitable for {val}", f"{val} room"]
-                ctx_embs = [EmbeddingCache.get_embedding(c, model_embedder) for c in contexts]
-                sem_sim = max(float(util.cos_sim(pt_emb, ce).item()) for ce in ctx_embs)
-
-                lex_score = ProductAttributeService._lexical_evidence(product_text, val)
-                final_score = sem_weight * sem_sim + lex_weight * lex_score
-                scores.append((val, final_score))
-
-            scores.sort(key=lambda x: x[1], reverse=True)
-            best_val, best_score = scores[0]
-
-            effective_margin = margin
-            if allow_multiple and use_adaptive_margin:
-                effective_margin = ProductAttributeService.get_adaptive_margin(scores, margin)
-
-            if is_dimension_attr and best_score >= 0.90:
-                source = ProductAttributeService.find_value_source(best_val, source_map, attr)
-                extracted[attr] = [{"value": best_val, "source": source}]
-                continue
-
-            if not allow_multiple:
-                source = ProductAttributeService.find_value_source(best_val, source_map, attr)
-                extracted[attr] = [{"value": best_val, "source": source}]
-            else:
-                candidates = [best_val]
-                use_base_threshold = best_score >= threshold_abs
-
-                clusters = []
-                if use_semantic_clustering:
-                    clusters = ProductAttributeService.calculate_value_clusters(
-                        allowed_values, scores, cluster_threshold=0.4
-                    )
-                    best_cluster = next((c for c in clusters if best_val in c), [best_val])
-
-                for val, sc in scores[1:]:
-                    min_score = 0.4 if is_dimension_attr else 0.3
-                    if sc < min_score:
-                        continue
-                    
-                    if use_dynamic_thresholds and extracted_attrs:
-                        dynamic_thresh = ProductAttributeService.get_dynamic_threshold(
-                            attr, val, sc, extracted_attrs, relationships,
-                            mandatory_attrs, threshold_abs
-                        )
-                    else:
-                        dynamic_thresh = threshold_abs
-
-                    within_margin = (best_score - sc) <= effective_margin
-                    above_threshold = sc >= dynamic_thresh
-
-                    in_cluster = False
-                    if use_semantic_clustering and clusters:
-                        in_cluster = any(best_val in c and val in c for c in clusters)
-
-                    if use_base_threshold:
-                        if above_threshold and within_margin:
-                            candidates.append(val)
-                        elif in_cluster and within_margin:
-                            candidates.append(val)
-                    else:
-                        if within_margin:
-                            candidates.append(val)
-                        elif in_cluster and (best_score - sc) <= effective_margin * 2.0:
-                            candidates.append(val)
-
-                extracted[attr] = []
-                for candidate in candidates:
-                    source = ProductAttributeService.find_value_source(candidate, source_map, attr)
-                    extracted[attr].append({"value": candidate, "source": source})
-
-        return extracted
+    @retry(max_attempts=3, delay=1.0)
+    def _call_llm(payload: dict) -> str:
+        headers = {"Authorization": f"Bearer {settings.GROQ_API_KEY}", "Content-Type": "application/json"}
+        resp = requests.post(settings.GROQ_API_URL, headers=headers, json=payload, timeout=30)
+        resp.raise_for_status()
+        return resp.json()["choices"][0]["message"]["content"]
 
     @staticmethod
     def extract_attributes(
@@ -649,315 +220,157 @@ If an attribute is not present, do not include it in the response.
         model: str = None,
         extract_additional: bool = True,
         multiple: Optional[List[str]] = None,
-        threshold_abs: float = 0.65,
-        margin: float = 0.15,
-        use_dynamic_thresholds: bool = True,
-        use_adaptive_margin: bool = True,
-        use_semantic_clustering: bool = True,
-        use_cache: bool = None  # ⚡ NEW: Can override global setting
+        use_cache: Optional[bool] = None,
     ) -> dict:
-        """Extract attributes from product text using Groq LLM."""
-        
-        if model is None:
-            model = settings.SUPPORTED_MODELS[0]
-
-        if multiple is None:
-            multiple = []
-
-        if source_map is None:
-            source_map = {}
-
-        # ⚡ CACHE CONTROL: use parameter if provided, otherwise use global setting
-        if use_cache is None:
-            use_cache = ENABLE_ATTRIBUTE_EXTRACTION_CACHE
-        
-        # If caching is globally disabled, force use_cache to False
-        if not is_caching_enabled():
-            use_cache = False
-
-        if not product_text or product_text == "No product information available":
-            return ProductAttributeService._create_error_response(
-                "No product information provided",
-                mandatory_attrs,
-                extract_additional
-            )
-
-        # ⚡ CHECK CACHE FIRST (only if enabled)
-        if use_cache:
-            cache_key = ProductAttributeService._generate_cache_key(product_text, mandatory_attrs)
-            cached_result = SimpleCache.get(cache_key)
-            if cached_result:
-                logger.info(f"✓ Cache hit (caching enabled)")
-                return cached_result
-        else:
-            logger.info(f"⚠ Cache disabled - processing fresh")
-
-        mandatory_attr_list = []
-        for attr_name, allowed_values in mandatory_attrs.items():
-            mandatory_attr_list.append(f"{attr_name}: {', '.join(allowed_values)}")
-        mandatory_attr_text = "\n".join(mandatory_attr_list)
-
-        additional_instruction = ""
-        if extract_additional:
-            additional_instruction = """
-2. Extract ADDITIONAL attributes: Identify any other relevant attributes from the product text 
-   that are NOT in the mandatory list. Only include attributes where you can find actual values
-   in the product text. Do NOT include attributes with "Not Specified" or empty values.
-   
-   Examples of attributes to look for (only if present): Brand, Material, Size, Color, Dimensions,
-   Weight, Features, Style, Theme, Pattern, Finish, Care Instructions, etc."""
-
-        output_format = {
-            "mandatory": {attr: "value or list of values" for attr in mandatory_attrs.keys()},
-        }
+        if model is None: model = settings.SUPPORTED_MODELS[0]
+        if multiple is None: multiple = []
+        if source_map is None: source_map = {}
 
-        if extract_additional:
-            output_format["additional"] = {
-                "example_attribute_1": "actual value found",
-                "example_attribute_2": "actual value found"
-            }
-            output_format["additional"]["_note"] = "Only include attributes with actual values found in text"
+        if use_cache is None: use_cache = ENABLE_ATTRIBUTE_EXTRACTION_CACHE
+        if not is_caching_enabled(): use_cache = False
+
+        cache_key = None
+        if use_cache:
+            cache_key = ProductAttributeService._cache_key(product_text, mandatory_attrs, extract_additional, multiple)
+            cached = SimpleCache.get(cache_key)
+            if cached:
+                logger.info(f"CACHE HIT {cache_key[:16]}...")
+                return cached
+
+        # --------------------------- PROMPT WITH FULL WHITELIST ---------------------------
+        allowed_lines = [f"{attr}: {', '.join(vals)}" for attr, vals in mandatory_attrs.items()]
+        allowed_text = "\n".join(allowed_lines)
+        allowed_sources = list(source_map.keys()) + ["not_found"]
+        source_hint = "|".join(allowed_sources)
+        multiple_text = f"\nMULTIPLE ALLOWED FOR: {', '.join(multiple)}" if multiple else ""
 
         prompt = f"""
-You are an intelligent product attribute extractor that works with ANY product type.
+You are a product-attribute classifier.
+Pick **exactly one** value from the list below for each attribute
+(for attributes listed under MULTIPLE ALLOWED, you may return several).
+If nothing matches, return "Not Specified".
 
-TASK:
-1. Extract MANDATORY attributes: For each mandatory attribute, select the most appropriate value(s)
-   from the provided list. Choose the value(s) that best match the product description.
-{additional_instruction}
+ALLOWED VALUES:
+{allowed_text}
+{multiple_text}
 
-Product Text:
+PRODUCT TEXT:
 {product_text}
 
-Mandatory Attribute Lists (MUST select from these allowed values):
-{mandatory_attr_text}
-
-CRITICAL INSTRUCTIONS:
-- Return ONLY valid JSON, nothing else
-- No explanations, no markdown, no text before or after the JSON
-- For mandatory attributes, choose the value(s) from the provided list that best match
-- If a mandatory attribute cannot be determined from the product text, use "Not Specified"
-- Prefer exact matches from the allowed values list over generic synonyms
-- If multiple values are plausible, you MAY return more than one
-{f"- For additional attributes: ONLY include attributes where you found actual values in the product text. DO NOT include attributes with 'Not Specified', 'None', 'N/A', or empty values. If you cannot find a value for an attribute, simply don't include that attribute." if extract_additional else ""}
-- Be precise and only extract information that is explicitly stated or clearly implied
-
-Required Output Format:
-{json.dumps(output_format, indent=2)}
-        """
+OUTPUT (strict JSON only):
+{{
+  "mandatory": {{
+    "<attr>": [{{"value":"<chosen>", "source":"<{source_hint}>"}}]
+  }},
+  "additional": {{}}
+}}
+
+RULES:
+- Pick from allowed values only
+- If not found: "Not Specified" + "not_found"
+- Source must be: {source_hint}
+- Return ONLY JSON
+"""
 
         payload = {
             "model": model,
             "messages": [
-                {
-                    "role": "system",
-                    "content": f"You are a precise attribute extraction model. Return ONLY valid JSON with {'mandatory and additional' if extract_additional else 'mandatory'} sections. No explanations, no markdown, no other text."
-                },
-                {"role": "user", "content": prompt}
+                {"role": "system", "content": "You are a JSON-only extractor."},
+                {"role": "user", "content": prompt},
             ],
             "temperature": 0.0,
-            "max_tokens": 1500
-        }
-
-        headers = {
-            "Authorization": f"Bearer {settings.GROQ_API_KEY}",
-            "Content-Type": "application/json",
+            "max_tokens": 1200,
         }
 
         try:
-            response = requests.post(
-                settings.GROQ_API_URL,
-                headers=headers,
-                json=payload,
-                timeout=30
-            )
-            response.raise_for_status()
-            result_text = response.json()["choices"][0]["message"]["content"].strip()
-
-            result_text = ProductAttributeService._clean_json_response(result_text)
-
-            parsed = json.loads(result_text)
-
-            parsed = ProductAttributeService._validate_response_structure(
-                parsed, mandatory_attrs, extract_additional, source_map
-            )
-
-            if extract_additional and "additional" in parsed:
-                cleaned_additional = {}
-                for k, v in parsed["additional"].items():
-                    if v and v not in ["Not Specified", "None", "N/A", "", "not specified", "none", "n/a"]:
-                        if not (isinstance(v, str) and v.lower() in ["not specified", "none", "n/a", ""]):
-                            if isinstance(v, list):
-                                cleaned_additional[k] = []
-                                for item in v:
-                                    if isinstance(item, dict) and "value" in item:
-                                        if "source" not in item:
-                                            item["source"] = ProductAttributeService.find_value_source(
-                                                item["value"], source_map, k
-                                            )
-                                        cleaned_additional[k].append(item)
-                                    else:
-                                        source = ProductAttributeService.find_value_source(str(item), source_map, k)
-                                        cleaned_additional[k].append({"value": str(item), "source": source})
-                            else:
-                                source = ProductAttributeService.find_value_source(str(v), source_map, k)
-                                cleaned_additional[k] = [{"value": str(v), "source": source}]
-                parsed["additional"] = cleaned_additional
-
-            relationships = {}
-            if use_dynamic_thresholds:
-                relationships = ProductAttributeService.calculate_attribute_relationships(
-                    mandatory_attrs, product_text
-                )
-
-            extracted_so_far = {}
-            for attr in mandatory_attrs.keys():
-                allow_multiple = attr in multiple
-
-                result = ProductAttributeService.normalize_against_product_text(
-                    product_text=product_text,
-                    mandatory_attrs={attr: mandatory_attrs[attr]},
-                    source_map=source_map,
-                    threshold_abs=threshold_abs,
-                    margin=margin,
-                    allow_multiple=allow_multiple,
-                    extracted_attrs=extracted_so_far,
-                    relationships=relationships,
-                    use_dynamic_thresholds=use_dynamic_thresholds,
-                    use_adaptive_margin=use_adaptive_margin,
-                    use_semantic_clustering=use_semantic_clustering
-                )
-
-                parsed["mandatory"][attr] = result[attr]
-                extracted_so_far[attr] = result[attr]
-
-            # ⚡ CACHE THE RESULT (only if enabled)
-            if use_cache:
-                SimpleCache.set(cache_key, parsed)
-                logger.info(f"✓ Result cached")
-
-            return parsed
-
-        except requests.exceptions.RequestException as e:
-            logger.error(f"Request exception: {str(e)}")
-            return ProductAttributeService._create_error_response(
-                str(e), mandatory_attrs, extract_additional
-            )
-        except json.JSONDecodeError as e:
-            logger.error(f"JSON decode error: {str(e)}")
-            return ProductAttributeService._create_error_response(
-                f"Invalid JSON: {str(e)}", mandatory_attrs, extract_additional, result_text
-            )
-        except Exception as e:
-            logger.error(f"Unexpected error: {str(e)}")
-            return ProductAttributeService._create_error_response(
-                str(e), mandatory_attrs, extract_additional
-            )
-
-    @staticmethod
-    def _clean_json_response(text: str) -> str:
-        """Clean LLM response to extract valid JSON."""
-        start_idx = text.find('{')
-        end_idx = text.rfind('}')
+            raw = ProductAttributeService._call_llm(payload)
+            cleaned = ProductAttributeService._clean_json(raw)
+            parsed = json.loads(cleaned)
+        except Exception as exc:
+            logger.error(f"LLM failed: {exc}")
+            error_response = {
+                "mandatory": {a: [{"value": "Not Specified", "source": "llm_error"}] for a in mandatory_attrs},
+                "error": str(exc),
+            }
+            if extract_additional:
+                error_response["additional"] = {}
+            return error_response
 
-        if start_idx != -1 and end_idx != -1:
-            text = text[start_idx:end_idx + 1]
+        # --------------------------- VALIDATION + SMART RECOVERY ---------------------------
+        pt_emb = EmbeddingCache.get_embedding(product_text, model_embedder)
 
-        if "```json" in text:
-            text = text.split("```json")[1].split("```")[0].strip()
-        elif "```" in text:
-            text = text.split("```")[1].split("```")[0].strip()
-            if text.startswith("json"):
-                text = text[4:].strip()
+        def _sanitize(section: dict, allowed: Dict):
+            sanitized = {}
+            for attr, items in section.items():
+                if attr not in allowed:
+                    continue
+                chosen = []
+                for it in (items if isinstance(items, list) else [items]):
+                    if not isinstance(it, dict):
+                        it = {"value": str(it), "source": "not_found"}
+                    val = str(it.get("value", "")).strip()
+                    src = str(it.get("source", "not_found")).lower()
+
+                    # --- LLM SAYS "Not Specified" → SMART RECOVERY ---
+                    if not val or val.lower() == "not specified":
+                        # 1. Lexical recovery: accept an allowed value with strong literal evidence
+                        for av in allowed[attr]:
+                            if ProductAttributeService._lexical_evidence(product_text, av) > 0.6:
+                                src = ProductAttributeService._find_source(av, source_map)
+                                chosen.append({"value": av, "source": src})
+                                break
+                        else:
+                            # 2. Semantic recovery: fall back to embedding similarity
+                            scored = [
+                                (av, float(util.cos_sim(pt_emb, EmbeddingCache.get_embedding(av, model_embedder)).item()))
+                                for av in allowed[attr]
+                            ]
+                            best_val, best_score = max(scored, key=lambda x: x[1]) if scored else ("", 0.0)
+                            if best_score > 0.75:
+                                src = ProductAttributeService._find_source(best_val, source_map)
+                                chosen.append({"value": best_val, "source": src})
+                            else:
+                                chosen.append({"value": "Not Specified", "source": "not_found"})
+                        continue
 
-        return text
+                    # --- VALIDATE LLM CHOICE ---
+                    if val not in allowed[attr]:
+                        continue
+                    if ProductAttributeService._lexical_evidence(product_text, val) < 0.2:
+                        continue
+                    if src not in source_map and src != "not_found":
+                        src = "not_found"
+                    chosen.append({"value": val, "source": src})
+
+                sanitized[attr] = chosen or [{"value": "Not Specified", "source": "not_found"}]
+            return sanitized
+
+        parsed["mandatory"] = _sanitize(parsed.get("mandatory", {}), mandatory_attrs)
+
+        # --- ADDITIONAL ATTRIBUTES ---
+        if extract_additional and "additional" in parsed:
+            additional = {}
+            for attr, items in parsed["additional"].items():
+                good = []
+                for it in (items if isinstance(items, list) else [items]):
+                    if not isinstance(it, dict):
+                        it = {"value": str(it), "source": "not_found"}
+                    val = str(it.get("value", "")).strip()
+                    src = str(it.get("source", "not_found")).lower()
+                    if src not in source_map and src != "not_found":
+                        src = "not_found"
+                    if val:
+                        good.append({"value": val, "source": src})
+                if good:
+                    additional[attr] = good
+            parsed["additional"] = additional
+        else:
+            parsed.pop("additional", None)
 
-    @staticmethod
-    def _validate_response_structure(
-        parsed: dict,
-        mandatory_attrs: Dict[str, List[str]],
-        extract_additional: bool,
-        source_map: Dict[str, str] = None
-    ) -> dict:
-        """Validate and fix the response structure."""
-        if source_map is None:
-            source_map = {}
-        
-        expected_sections = ["mandatory"]
-        if extract_additional:
-            expected_sections.append("additional")
-
-        if not all(section in parsed for section in expected_sections):
-            if isinstance(parsed, dict):
-                mandatory_keys = set(mandatory_attrs.keys())
-                mandatory = {k: v for k, v in parsed.items() if k in mandatory_keys}
-                additional = {k: v for k, v in parsed.items() if k not in mandatory_keys}
-
-                result = {"mandatory": mandatory}
-                if extract_additional:
-                    result["additional"] = additional
-                parsed = result
-            else:
-                return ProductAttributeService._create_error_response(
-                    "Invalid response structure",
-                    mandatory_attrs,
-                    extract_additional,
-                    str(parsed)
-                )
-
-        if "mandatory" in parsed:
-            converted_mandatory = {}
-            for attr, value in parsed["mandatory"].items():
-                if isinstance(value, list):
-                    converted_mandatory[attr] = []
-                    for item in value:
-                        if isinstance(item, dict) and "value" in item:
-                            if "source" not in item:
-                                item["source"] = ProductAttributeService.find_value_source(
-                                    item["value"], source_map, attr
-                                )
-                            converted_mandatory[attr].append(item)
-                        else:
-                            source = ProductAttributeService.find_value_source(str(item), source_map, attr)
-                            converted_mandatory[attr].append({"value": str(item), "source": source})
-                else:
-                    source = ProductAttributeService.find_value_source(str(value), source_map, attr)
-                    converted_mandatory[attr] = [{"value": str(value), "source": source}]
-            
-            parsed["mandatory"] = converted_mandatory
+        if use_cache and cache_key:
+            SimpleCache.set(cache_key, parsed)
+            logger.info(f"CACHE SET {cache_key[:16]}...")
 
         return parsed
 
-    @staticmethod
-    def _create_error_response(
-        error: str,
-        mandatory_attrs: Dict[str, List[str]],
-        extract_additional: bool,
-        raw_output: Optional[str] = None
-    ) -> dict:
-        """Create a standardized error response."""
-        response = {
-            "mandatory": {attr: [{"value": "Not Specified", "source": "error"}] for attr in mandatory_attrs.keys()},
-            "error": error
-        }
-        if extract_additional:
-            response["additional"] = {}
-        if raw_output:
-            response["raw_output"] = raw_output
-        return response
-
     @staticmethod
     def get_cache_stats() -> Dict:
-        """Get statistics for all caches including global status."""
         return {
-            "global_caching_enabled": is_caching_enabled(),
-            "simple_cache": SimpleCache.get_stats(),
-            "embedding_cache": EmbeddingCache.get_stats()
+            "global_enabled": is_caching_enabled(),
+            "result_cache": SimpleCache.get_stats(),
+            "embedding_cache": EmbeddingCache.get_stats(),
         }
 
     @staticmethod
     def clear_all_caches():
-        """Clear all caches."""
         SimpleCache.clear()
         EmbeddingCache.clear()
         logger.info("All caches cleared")

+ 6 - 0
attr_extraction/urls.py

@@ -13,6 +13,8 @@ from .views import (
     # ProductAttributeValueUploadExcelView,
     ProductListWithAttributesView
 )
+from .views import CacheManagementView, CacheStatsView
+
 
 urlpatterns = [
     # Existing endpoints
@@ -32,8 +34,12 @@ urlpatterns = [
 
     path('attribute-values/', ProductAttributeValueView.as_view(), name='attribute-values'),
     path('attribute-values/bulk/', BulkProductAttributeValueView.as_view(), name='attribute-values-bulk'),
+    path('cache/management/', CacheManagementView.as_view(), name='cache-management'),
+    path('cache/stats/', CacheStatsView.as_view(), name='cache-stats'),
+
 ]
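
CacheManagementView and CacheStatsView are imported here but their bodies are not part of this commit. A minimal DRF sketch consistent with the service-layer API shown in services.py (get_cache_stats / clear_all_caches); the module path and HTTP verbs are assumptions:

    from rest_framework.views import APIView
    from rest_framework.response import Response

    from .services import ProductAttributeService

    class CacheStatsView(APIView):
        def get(self, request):
            # Read-only snapshot of result- and embedding-cache statistics.
            return Response(ProductAttributeService.get_cache_stats())

    class CacheManagementView(APIView):
        def delete(self, request):
            # Destructive: flush every cache layer.
            ProductAttributeService.clear_all_caches()
            return Response({"status": "all caches cleared"})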
 
 
 
 

+ 9 - 9
attr_extraction/views.py

@@ -183,7 +183,7 @@ class BatchExtractProductAttributesView(APIView):
     def post(self, request):
         import time
         start_time = time.time()
-        
+
         serializer = BatchProductRequestSerializer(data=request.data)
         if not serializer.is_valid():
             return Response({"error": serializer.errors}, status=status.HTTP_400_BAD_REQUEST)
@@ -222,9 +222,9 @@ class BatchExtractProductAttributesView(APIView):
         multiple = validated_data.get("multiple", [])
         threshold_abs = validated_data.get("threshold_abs", 0.65)
         margin = validated_data.get("margin", 0.15)
-        use_dynamic_thresholds = validated_data.get("use_dynamic_thresholds", True)
-        use_adaptive_margin = validated_data.get("use_adaptive_margin", True)
-        use_semantic_clustering = validated_data.get("use_semantic_clustering", True)
+        use_dynamic_thresholds = validated_data.get("use_dynamic_thresholds", False)
+        use_adaptive_margin = validated_data.get("use_adaptive_margin", False)
+        use_semantic_clustering = validated_data.get("use_semantic_clustering", False)
         
         results = []
         successful = 0
@@ -309,11 +309,11 @@ class BatchExtractProductAttributesView(APIView):
                     model=model,
                     extract_additional=extract_additional,
                     multiple=multiple,
-                    threshold_abs=threshold_abs,
-                    margin=margin,
-                    use_dynamic_thresholds=use_dynamic_thresholds,
-                    use_adaptive_margin=use_adaptive_margin,
-                    use_semantic_clustering=use_semantic_clustering,
+                    # threshold_abs=threshold_abs,
+                    # margin=margin,
+                    # use_dynamic_thresholds=use_dynamic_thresholds,
+                    # use_adaptive_margin=use_adaptive_margin,
+                    # use_semantic_clustering=use_semantic_clustering,
                     use_cache=True  # ⚡ CRITICAL: Enable caching
                 )
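
Two things to note in this hunk: the batch view now forces use_cache=True, and the tuning values (threshold_abs, margin, the three use_* flags) are still read from the serializer but no longer forwarded, so the service falls back to its own defaults. The cache_key construction itself is outside this diff; a plausible shape is a deterministic digest over every input that affects the output — the helper name and field set below are hypothetical:

    import hashlib
    import json

    def make_cache_key(product_text, mandatory_attrs, model, extract_additional):
        # Same inputs → same key; any change to text, schema, or model busts the cache.
        blob = json.dumps(
            {"text": product_text, "attrs": mandatory_attrs,
             "model": model, "additional": extract_additional},
            sort_keys=True,
        )
        return hashlib.sha256(blob.encode("utf-8")).hexdigest()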
 

+ 16 - 22
attr_extraction/visual_processing_service.py

@@ -128,29 +128,19 @@ class VisualProcessingService:
             cls._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
             logger.info(f"Visual Processing using device: {cls._device}")
         return cls._device
-    
+
     @classmethod
     def _get_clip_model(cls):
         """
-        Lazy load CLIP model with optional class-level caching.
-        ⚡ If caching is disabled, model is still loaded but not persisted at class level.
+        🔥 ALWAYS cache CLIP model (ignores global cache setting).
+        This is a 400MB model that takes 30-60s to load.
         """
-        # ⚡ CACHE CONTROL: If caching is disabled, always reload (no persistence)
-        if not ENABLE_CLIP_MODEL_CACHE:
-            logger.info("⚠ CLIP model caching is DISABLED - loading fresh instance")
-            model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
-            processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
-            
-            device = cls._get_device()
-            model.to(device)
-            model.eval()
-            
-            logger.info("✓ CLIP model loaded (no caching)")
-            return model, processor
-        
-        # Caching is enabled - use class-level cache
         if cls._clip_model is None:
-            logger.info("Loading CLIP model (this may take a few minutes on first use)...")
+            import time
+            start = time.time()
+            logger.info("📥 Loading CLIP model from HuggingFace...")
+            
             cls._clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
             cls._clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
             
@@ -158,12 +148,16 @@ class VisualProcessingService:
             cls._clip_model.to(device)
             cls._clip_model.eval()
             
-            logger.info("✓ CLIP model loaded and cached successfully")
+            load_time = time.time() - start
+            logger.info(f"✓ CLIP model loaded in {load_time:.1f}s and cached in memory")
         else:
-            logger.info("✓ Using cached CLIP model")
-            
+            logger.debug("✓ Using cached CLIP model")
+        
         return cls._clip_model, cls._clip_processor
-    
+
     @classmethod
     def clear_clip_cache(cls):
         """Clear the cached CLIP model to free memory."""

BIN
content_quality_tool/__pycache__/settings.cpython-313.pyc


+ 27 - 1
content_quality_tool/settings.py

@@ -36,7 +36,9 @@ INSTALLED_APPS = [
     'django.contrib.staticfiles',
     'core',
     'rest_framework',
-    'attr_extraction',
+    # 'attr_extraction',
+    'attr_extraction.apps.AttrExtractionConfig',
+    # 'attr_extraction.apps.ProductAttributesConfig',  # Full path
 ]
 MIDDLEWARE = [
     'django.middleware.security.SecurityMiddleware',
@@ -125,6 +127,9 @@ MESSAGE_TAGS = {
 }
 
 
+# SECURITY: never commit API keys — the keys previously hard-coded here are exposed
+# in git history and must be rotated. Load secrets from the environment instead.
+import os  # assumed not already imported in this settings module
+
+OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
+
 
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")  # key redacted; rotate the previously committed value
 GROQ_API_URL = "https://api.groq.com/openai/v1/chat/completions"
@@ -132,3 +137,24 @@ SUPPORTED_MODELS = ["llama-3.1-8b-instant", "llama-3.3-70b-versatile", "mixtral-
 MAX_BATCH_SIZE = 100  # Maximum products per batch request
 
 
+# ==================== settings.py ====================
+LOGGING = {
+    'version': 1,
+    'disable_existing_loggers': False,
+    'handlers': {
+        'console': {
+            'class': 'logging.StreamHandler',
+        },
+    },
+    'root': {
+        'handlers': ['console'],
+        'level': 'INFO',
+    },
+    'loggers': {
+        'attr_extraction': {
+            'handlers': ['console'],
+            'level': 'INFO',
+            'propagate': False,
+        },
+    },
+}
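
With this LOGGING block, anything logged under the attr_extraction namespace reaches the console at INFO. The module-level pattern the services above already use picks this up automatically:

    import logging

    logger = logging.getLogger(__name__)  # e.g. "attr_extraction.services"
    logger.info("cache hit")              # routed by the 'attr_extraction' logger config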

BIN
core/__pycache__/views.cpython-313.pyc


BIN
core/services/__pycache__/gemini_service.cpython-313.pyc


BIN
db.sqlite3
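
The BIN entries above show compiled bytecode caches and the development SQLite database being committed alongside the code. A conventional Django .gitignore keeps them out of future commits (files already tracked must also be removed with `git rm --cached`):

    # .gitignore
    __pycache__/
    *.pyc
    db.sqlite3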