Harshit Pathak 3 months ago
parent
commit
afae1b63f2
4 changed files with 1204 additions and 196 deletions
  1. attr_extraction/apps.py  + 119 - 13
  2. attr_extraction/services.py  + 821 - 156
  3. attr_extraction/views.py  + 264 - 27
  4. db.sqlite3  BIN

+ 119 - 13
attr_extraction/apps.py

@@ -268,6 +268,121 @@
 
 
 
+# # ==================== attr_extraction/apps.py ====================
+# from django.apps import AppConfig
+# import logging
+# import sys
+# import os
+# import threading
+
+# from django.core.cache import cache  # ✅ Import Django cache
+
+# logger = logging.getLogger(__name__)
+
+
+# class AttrExtractionConfig(AppConfig):
+#     default_auto_field = 'django.db.models.BigAutoField'
+#     name = 'attr_extraction'
+    
+#     models_loaded = False
+    
+#     def ready(self):
+#         """
+#         🔥 Pre-load all heavy ML models during Django startup.
+#         Also clears Django cache once when the server starts.
+#         """
+#         # Skip during migrations/management commands
+#         if any(cmd in sys.argv for cmd in ['migrate', 'makemigrations', 'test', 'collectstatic', 'shell']):
+#             return
+        
+#         # Skip in Django autoreloader parent process
+#         if os.environ.get('RUN_MAIN') != 'true':
+#             logger.info("⏭️  Skipping model loading in autoreloader parent process")
+#             return
+        
+#         # ✅ Clear cache once per startup
+#         try:
+#             cache.clear()
+#             logger.info("🧹 Django cache cleared successfully on startup.")
+#         except Exception as e:
+#             logger.warning(f"⚠️  Failed to clear cache: {e}")
+        
+#         # Prevent double loading
+#         if AttrExtractionConfig.models_loaded:
+#             logger.info("⏭️  Models already loaded, skipping...")
+#             return
+        
+#         AttrExtractionConfig.models_loaded = True
+        
+#         # Load models in background thread (non-blocking)
+#         thread = threading.Thread(target=self._load_models, daemon=True)
+#         thread.start()
+        
+#         logger.info("🔄 Model loading started in background...")
+    
+#     def _load_models(self):
+#         """Background thread to load heavy models."""
+#         import time
+        
+#         logger.info("=" * 70)
+#         logger.info("🔥 WARMING UP ML MODELS (background process)")
+#         logger.info("=" * 70)
+        
+#         startup_time = time.time()
+#         total_loaded = 0
+        
+#         # 1. Sentence Transformer
+#         # try:
+#         #     logger.info("📥 Loading Sentence Transformer...")
+#         #     st_start = time.time()
+#         #     from .services import model_embedder
+#         #     st_time = time.time() - st_start
+#         #     logger.info(f"✓ Sentence Transformer ready ({st_time:.1f}s)")
+#         #     total_loaded += 1
+#         # except Exception as e:
+#         #     logger.error(f"❌ Sentence Transformer failed: {e}")
+        
+#         # 2. Pre-load CLIP model
+#         try:
+#             logger.info("📥 Loading CLIP model (20-30s)...")
+#             clip_start = time.time()
+#             from .visual_processing_service import VisualProcessingService
+#             VisualProcessingService._get_clip_model()
+#             clip_time = time.time() - clip_start
+#             logger.info(f"✓ CLIP model cached ({clip_time:.1f}s)")
+#             total_loaded += 1
+#         except Exception as e:
+#             logger.error(f"❌ CLIP model failed: {e}")
+        
+#         # 3. Pre-load OCR model
+#         try:
+#             logger.info("📥 Loading EasyOCR model...")
+#             ocr_start = time.time()
+#             from .ocr_service import OCRService
+#             ocr_service = OCRService()
+#             ocr_service._get_reader()
+#             ocr_time = time.time() - ocr_start
+#             logger.info(f"✓ OCR model cached ({ocr_time:.1f}s)")
+#             total_loaded += 1
+#         except Exception as e:
+#             logger.error(f"❌ OCR model failed: {e}")
+        
+#         total_time = time.time() - startup_time
+        
+#         logger.info("=" * 70)
+#         logger.info(f"🎉 {total_loaded}/3 MODELS LOADED in {total_time:.1f}s")
+#         logger.info("⚡ API requests are now FAST (2-5 seconds)")
+#         logger.info("=" * 70)
+
+
+
+
+
+
+
+
+
+
 # ==================== attr_extraction/apps.py ====================
 from django.apps import AppConfig
 import logging
@@ -331,18 +446,9 @@ class AttrExtractionConfig(AppConfig):
         startup_time = time.time()
         total_loaded = 0
         
-        # 1. Sentence Transformer
-        try:
-            logger.info("📥 Loading Sentence Transformer...")
-            st_start = time.time()
-            from .services import model_embedder
-            st_time = time.time() - st_start
-            logger.info(f"✓ Sentence Transformer ready ({st_time:.1f}s)")
-            total_loaded += 1
-        except Exception as e:
-            logger.error(f"❌ Sentence Transformer failed: {e}")
+        # REMOVED: Sentence Transformer (no longer used in services.py)
         
-        # 2. Pre-load CLIP model
+        # 1. Pre-load CLIP model
         try:
             logger.info("📥 Loading CLIP model (20-30s)...")
             clip_start = time.time()
@@ -354,7 +460,7 @@ class AttrExtractionConfig(AppConfig):
         except Exception as e:
             logger.error(f"❌ CLIP model failed: {e}")
         
-        # 3. Pre-load OCR model
+        # 2. Pre-load OCR model
         try:
             logger.info("📥 Loading EasyOCR model...")
             ocr_start = time.time()
@@ -370,6 +476,6 @@ class AttrExtractionConfig(AppConfig):
         total_time = time.time() - startup_time
         
         logger.info("=" * 70)
-        logger.info(f"🎉 {total_loaded}/3 MODELS LOADED in {total_time:.1f}s")
+        logger.info(f"🎉 {total_loaded}/2 MODELS LOADED in {total_time:.1f}s")
         logger.info("⚡ API requests are now FAST (2-5 seconds)")
         logger.info("=" * 70)
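
The warm-up above only pays off if each service caches its model at class level, so that `VisualProcessingService._get_clip_model()` and `OCRService()._get_reader()` are expensive on the first call and free afterwards. Those internals are not part of this diff; below is a minimal sketch of the lazy, thread-safe cache such warm-up code typically relies on. All names (`HeavyModelHolder`, `_expensive_load`) are illustrative stand-ins, not the project's actual API.

```python
import threading
import time


class HeavyModelHolder:
    _model = None                 # cached instance, shared across requests
    _lock = threading.Lock()      # guards first load (warm-up runs in a background thread)

    @classmethod
    def get_model(cls):
        """Return the cached model, loading it on the first call only."""
        if cls._model is None:
            with cls._lock:
                if cls._model is None:          # double-checked locking
                    cls._model = cls._expensive_load()
        return cls._model

    @staticmethod
    def _expensive_load():
        # Stand-in for loading CLIP or EasyOCR weights (20-30s in practice).
        time.sleep(0.1)
        return object()


# The warm-up in apps.py and later request handlers hit the same cache:
HeavyModelHolder.get_model()           # slow, once at startup
model = HeavyModelHolder.get_model()   # effectively instant thereafter
```

Because `ready()` starts this in a daemon thread, the first few API requests may still race the warm-up; the double-checked lock ensures they block on the load rather than triggering a second one.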

+ 821 - 156
attr_extraction/services.py

File diff suppressed because it is too large
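
The services.py diff itself is suppressed, but the views.py changes below imply two things about it: `extract_attributes()` gained a `user_entered_values` keyword, and `get_cache_stats()` previously reported an `embedding_cache` hit rate (the removed log line references `cache_stats['embedding_cache']['hit_rate_percent']`). A hypothetical sketch of such a stats-tracking cache follows; the class and method names are assumptions, not the actual services.py code.

```python
import hashlib
from typing import Any, Callable, Dict


class EmbeddingCache:
    """Memoizes computed values per text, tracking hits and misses."""

    def __init__(self) -> None:
        self._store: Dict[str, Any] = {}
        self._hits = 0
        self._misses = 0

    @staticmethod
    def _key(text: str) -> str:
        # Hash the text so arbitrarily long inputs make fixed-size keys.
        return hashlib.sha256(text.encode("utf-8")).hexdigest()

    def get_or_compute(self, text: str, compute: Callable[[str], Any]) -> Any:
        key = self._key(text)
        if key in self._store:
            self._hits += 1
        else:
            self._misses += 1
            self._store[key] = compute(text)
        return self._store[key]

    def stats(self) -> Dict[str, Any]:
        total = self._hits + self._misses
        rate = (100.0 * self._hits / total) if total else 0.0
        return {"hits": self._hits, "misses": self._misses,
                "hit_rate_percent": round(rate, 1)}


cache = EmbeddingCache()
cache.get_or_compute("red shirt", lambda t: [0.1, 0.2])  # miss: computes
cache.get_or_compute("red shirt", lambda t: [0.1, 0.2])  # hit: cached
print(cache.stats())  # {'hits': 1, 'misses': 1, 'hit_rate_percent': 50.0}
```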


+ 264 - 27
attr_extraction/views.py

@@ -174,10 +174,260 @@ class ExtractProductAttributesView(APIView):
 
 # ==================== OPTIMIZED BATCH VIEW ====================
 
+# class BatchExtractProductAttributesView(APIView):
+#     """
+#     ⚡ PERFORMANCE OPTIMIZED: Batch extraction with intelligent parallelization
+#     Expected performance: 10 products in 30-60 seconds (with image processing)
+#     """
+
+#     def post(self, request):
+#         import time
+#         start_time = time.time()
+
+#         serializer = BatchProductRequestSerializer(data=request.data)
+#         if not serializer.is_valid():
+#             return Response({"error": serializer.errors}, status=status.HTTP_400_BAD_REQUEST)
+
+#         validated_data = serializer.validated_data
+#         product_list = validated_data.get("products", [])
+        
+#         logger.info(f"🚀 Starting batch processing for {len(product_list)} products")
+        
+#         # ==================== OPTIMIZATION 1: Bulk DB Query ====================
+#         item_ids = [p['item_id'] for p in product_list]
+#         products_queryset = Product.objects.filter(
+#             item_id__in=item_ids
+#         ).prefetch_related('attribute_values')
+        
+#         product_map = {product.item_id: product for product in products_queryset}
+        
+#         # Prefetch ALL original attribute values in ONE query
+#         original_values_qs = ProductAttributeValue.objects.filter(
+#             product__item_id__in=item_ids
+#         ).select_related('product')
+        
+#         original_values_map = {}
+#         for attr_val in original_values_qs:
+#             item_id = attr_val.product.item_id
+#             if item_id not in original_values_map:
+#                 original_values_map[item_id] = {}
+#             original_values_map[item_id][attr_val.attribute_name] = attr_val.original_value
+        
+#         logger.info(f"✓ Loaded {len(product_map)} products from database")
+        
+#         # Extract settings
+#         model = validated_data.get("model")
+#         extract_additional = validated_data.get("extract_additional", True)
+#         process_image = validated_data.get("process_image", True)
+#         multiple = validated_data.get("multiple", [])
+#         threshold_abs = validated_data.get("threshold_abs", 0.65)
+#         margin = validated_data.get("margin", 0.15)
+#         use_dynamic_thresholds = validated_data.get("use_dynamic_thresholds", False)
+#         use_adaptive_margin = validated_data.get("use_adaptive_margin", False)
+#         use_semantic_clustering = validated_data.get("use_semantic_clustering", False)
+        
+#         results = []
+#         successful = 0
+#         failed = 0
+        
+#         # ==================== OPTIMIZATION 2: Conditional Service Init ====================
+#         # Only initialize if processing images
+#         ocr_service = None
+#         visual_service = None
+        
+#         if process_image:
+#             from .ocr_service import OCRService
+#             from .visual_processing_service import VisualProcessingService
+#             ocr_service = OCRService()
+#             visual_service = VisualProcessingService()
+#             logger.info("✓ Image processing services initialized")
+
+#         # ==================== OPTIMIZATION 3: Smart Parallelization ====================
+#         def process_single_product(product_entry):
+#             """Process a single product (runs in parallel)"""
+#             import time
+#             product_start = time.time()
+            
+#             item_id = product_entry['item_id']
+#             mandatory_attrs = product_entry['mandatory_attrs']
+
+#             if item_id not in product_map:
+#                 return {
+#                     "product_id": item_id,
+#                     "error": "Product not found in database"
+#                 }, False
+
+#             product = product_map[item_id]
+            
+#             try:
+#                 title = product.product_name
+#                 short_desc = product.product_short_description
+#                 long_desc = product.product_long_description
+#                 image_url = product.image_path
+                
+#                 ocr_results = None
+#                 ocr_text = None
+#                 visual_results = None
+
+#                 # ⚡ SKIP IMAGE PROCESSING IF DISABLED (HUGE TIME SAVER)
+#                 if process_image and image_url:
+#                     if ocr_service:
+#                         ocr_results = ocr_service.process_image(image_url)
+                        
+#                         if ocr_results and ocr_results.get("detected_text"):
+#                             ocr_attrs = ProductAttributeService.extract_attributes_from_ocr(
+#                                 ocr_results, model
+#                             )
+#                             ocr_results["extracted_attributes"] = ocr_attrs
+#                             ocr_text = "\n".join([
+#                                 f"{item['text']} (confidence: {item['confidence']:.2f})"
+#                                 for item in ocr_results["detected_text"]
+#                             ])
+                    
+#                     if visual_service:
+#                         product_type_hint = product.product_type if hasattr(product, 'product_type') else None
+#                         visual_results = visual_service.process_image(image_url, product_type_hint)
+                        
+#                         if visual_results and visual_results.get('visual_attributes'):
+#                             visual_results['visual_attributes'] = ProductAttributeService.format_visual_attributes(
+#                                 visual_results['visual_attributes']
+#                             )
+
+#                 # Combine product text with source tracking
+#                 product_text, source_map = ProductAttributeService.combine_product_text(
+#                     title=title,
+#                     short_desc=short_desc,
+#                     long_desc=long_desc,
+#                     ocr_text=ocr_text
+#                 )
+
+#                 # ⚡ EXTRACT ATTRIBUTES WITH CACHING ENABLED
+#                 extracted = ProductAttributeService.extract_attributes(
+#                     product_text=product_text,
+#                     mandatory_attrs=mandatory_attrs,
+#                     source_map=source_map,
+#                     model=model,
+#                     extract_additional=extract_additional,
+#                     multiple=multiple,
+#                     # threshold_abs=threshold_abs,
+#                     # margin=margin,
+#                     # use_dynamic_thresholds=use_dynamic_thresholds,
+#                     # use_adaptive_margin=use_adaptive_margin,
+#                     # use_semantic_clustering=use_semantic_clustering,
+#                     use_cache=True  # ⚡ CRITICAL: Enable caching
+#                 )
+
+#                 # Add original values
+#                 original_attrs = original_values_map.get(item_id, {})
+                
+#                 for attr_name, attr_values in extracted.get("mandatory", {}).items():
+#                     if isinstance(attr_values, list):
+#                         for attr_obj in attr_values:
+#                             if isinstance(attr_obj, dict):
+#                                 attr_obj["original_value"] = original_attrs.get(attr_name, "")
+                
+#                 for attr_name, attr_values in extracted.get("additional", {}).items():
+#                     if isinstance(attr_values, list):
+#                         for attr_obj in attr_values:
+#                             if isinstance(attr_obj, dict):
+#                                 attr_obj["original_value"] = original_attrs.get(attr_name, "")
+
+#                 result = {
+#                     "product_id": product.item_id,
+#                     "mandatory": extracted.get("mandatory", {}),
+#                     "additional": extracted.get("additional", {}),
+#                 }
+
+#                 if ocr_results:
+#                     result["ocr_results"] = ocr_results
+                
+#                 if visual_results:
+#                     result["visual_results"] = visual_results
+                
+#                 processing_time = time.time() - product_start
+#                 logger.info(f"✓ Processed {item_id} in {processing_time:.2f}s")
+
+#                 return result, True
+
+#             except Exception as e:
+#                 logger.error(f"❌ Error processing {item_id}: {str(e)}")
+#                 return {
+#                     "product_id": item_id,
+#                     "error": str(e)
+#                 }, False
+
+#         # ==================== OPTIMIZATION 4: Parallel Execution ====================
+#         # Adjust workers based on whether image processing is enabled
+#         max_workers = min(3 if process_image else 10, len(product_list))
+        
+#         logger.info(f"⚡ Using {max_workers} parallel workers")
+        
+#         with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+#             # Submit all tasks
+#             future_to_product = {
+#                 executor.submit(process_single_product, product): product
+#                 for product in product_list
+#             }
+            
+#             # Collect results as they complete
+#             for future in concurrent.futures.as_completed(future_to_product):
+#                 try:
+#                     result, success = future.result()
+#                     results.append(result)
+#                     if success:
+#                         successful += 1
+#                     else:
+#                         failed += 1
+#                 except Exception as e:
+#                     failed += 1
+#                     logger.error(f"❌ Future execution error: {str(e)}")
+#                     results.append({
+#                         "product_id": "unknown",
+#                         "error": str(e)
+#                     })
+
+#         total_time = time.time() - start_time
+        
+#         # Get cache statistics
+#         cache_stats = ProductAttributeService.get_cache_stats()
+        
+#         logger.info(f"""
+# 🎉 BATCH PROCESSING COMPLETE
+# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
+#   Total products: {len(product_list)}
+#   Successful: {successful}
+#   Failed: {failed}
+#   Total time: {total_time:.2f}s
+#   Avg time/product: {total_time/len(product_list):.2f}s
+# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
+#         """)
+
+#         batch_result = {
+#             "results": results,
+#             "total_products": len(product_list),
+#             "successful": successful,
+#             "failed": failed,
+#             "performance": {
+#                 "total_time_seconds": round(total_time, 2),
+#                 "avg_time_per_product": round(total_time / len(product_list), 2),
+#                 "workers_used": max_workers
+#             },
+#             "cache_stats": cache_stats
+#         }
+
+#         response_serializer = BatchProductResponseSerializer(data=batch_result)
+#         if response_serializer.is_valid():
+#             return Response(response_serializer.data, status=status.HTTP_200_OK)
+
+#         return Response(batch_result, status=status.HTTP_200_OK)
+
+
+# VERSION WITH PARALLELIZATION
 class BatchExtractProductAttributesView(APIView):
     """
     ⚡ PERFORMANCE OPTIMIZED: Batch extraction with intelligent parallelization
     Expected performance: 10 products in 30-60 seconds (with image processing)
+    NOW WITH USER VALUE REASONING
     """
 
     def post(self, request):
@@ -214,6 +464,7 @@ class BatchExtractProductAttributesView(APIView):
             original_values_map[item_id][attr_val.attribute_name] = attr_val.original_value
         
         logger.info(f"✓ Loaded {len(product_map)} products from database")
+        logger.info(f"✓ Loaded user values for {len(original_values_map)} products")
         
         # Extract settings
         model = validated_data.get("model")
@@ -231,7 +482,6 @@ class BatchExtractProductAttributesView(APIView):
         failed = 0
         
         # ==================== OPTIMIZATION 2: Conditional Service Init ====================
-        # Only initialize if processing images
         ocr_service = None
         visual_service = None
         
@@ -269,7 +519,7 @@ class BatchExtractProductAttributesView(APIView):
                 ocr_text = None
                 visual_results = None
 
-                # ⚡ SKIP IMAGE PROCESSING IF DISABLED (HUGE TIME SAVER)
+                # ⚡ SKIP IMAGE PROCESSING IF DISABLED
                 if process_image and image_url:
                     if ocr_service:
                         ocr_results = ocr_service.process_image(image_url)
@@ -301,7 +551,13 @@ class BatchExtractProductAttributesView(APIView):
                     ocr_text=ocr_text
                 )
 
-                # ⚡ EXTRACT ATTRIBUTES WITH CACHING ENABLED
+                # 🆕 GET USER-ENTERED VALUES FOR THIS PRODUCT
+                user_entered_values = original_values_map.get(item_id, {})
+                print("user entered values are ")
+                print(user_entered_values)
+                logger.info(f"Processing {item_id} with {len(user_entered_values)} user-entered values")
+
+                # ⚡ EXTRACT ATTRIBUTES WITH USER VALUES AND REASONING
                 extracted = ProductAttributeService.extract_attributes(
                     product_text=product_text,
                     mandatory_attrs=mandatory_attrs,
@@ -309,29 +565,13 @@ class BatchExtractProductAttributesView(APIView):
                     model=model,
                     extract_additional=extract_additional,
                     multiple=multiple,
-                    # threshold_abs=threshold_abs,
-                    # margin=margin,
-                    # use_dynamic_thresholds=use_dynamic_thresholds,
-                    # use_adaptive_margin=use_adaptive_margin,
-                    # use_semantic_clustering=use_semantic_clustering,
-                    use_cache=True  # ⚡ CRITICAL: Enable caching
+                    use_cache=True,
+                    user_entered_values=user_entered_values  # 🆕 PASS USER VALUES
                 )
 
-                # Add original values
-                original_attrs = original_values_map.get(item_id, {})
-                
-                for attr_name, attr_values in extracted.get("mandatory", {}).items():
-                    if isinstance(attr_values, list):
-                        for attr_obj in attr_values:
-                            if isinstance(attr_obj, dict):
-                                attr_obj["original_value"] = original_attrs.get(attr_name, "")
+                # NOTE: Original values are now part of LLM response with reasoning
+                # No need to add them separately - they're already in the "user_value" field
                 
-                for attr_name, attr_values in extracted.get("additional", {}).items():
-                    if isinstance(attr_values, list):
-                        for attr_obj in attr_values:
-                            if isinstance(attr_obj, dict):
-                                attr_obj["original_value"] = original_attrs.get(attr_name, "")
-
                 result = {
                     "product_id": product.item_id,
                     "mandatory": extracted.get("mandatory", {}),
@@ -357,19 +597,16 @@ class BatchExtractProductAttributesView(APIView):
                 }, False
 
         # ==================== OPTIMIZATION 4: Parallel Execution ====================
-        # Adjust workers based on whether image processing is enabled
         max_workers = min(3 if process_image else 10, len(product_list))
         
         logger.info(f"⚡ Using {max_workers} parallel workers")
         
         with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
-            # Submit all tasks
             future_to_product = {
                 executor.submit(process_single_product, product): product
                 for product in product_list
             }
             
-            # Collect results as they complete
             for future in concurrent.futures.as_completed(future_to_product):
                 try:
                     result, success = future.result()
@@ -399,7 +636,6 @@ class BatchExtractProductAttributesView(APIView):
   Failed: {failed}
   Total time: {total_time:.2f}s
   Avg time/product: {total_time/len(product_list):.2f}s
-  Cache hit rate: {cache_stats['embedding_cache']['hit_rate_percent']:.1f}%
 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
         """)
 
@@ -423,6 +659,7 @@ class BatchExtractProductAttributesView(APIView):
         return Response(batch_result, status=status.HTTP_200_OK)
 
 
+
 class ProductListView(APIView):
     """
     GET API to list all products with details
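
For reference, the fan-out pattern the batch view settles on distills to a few lines of stdlib code. This standalone sketch uses a trivial stand-in for `process_single_product`; it shows the bounded pool, the future-to-input map, and `as_completed` collection. Keeping that map is also what would let a crashed future report its real item_id rather than the `"unknown"` the view falls back to.

```python
import concurrent.futures

def process_single_product(entry):
    # Stand-in: the real closure does OCR/CLIP/LLM work and returns
    # (result_dict, success_flag).
    return {"product_id": entry["item_id"]}, True

product_list = [{"item_id": f"SKU-{i}"} for i in range(10)]
process_image = False

# Fewer workers when each task is heavy (image processing), more when it is
# text-only; never more workers than tasks.
max_workers = min(3 if process_image else 10, len(product_list))

results, successful, failed = [], 0, 0
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
    future_to_product = {
        executor.submit(process_single_product, p): p for p in product_list
    }
    # Collect in completion order so one slow product never blocks the rest.
    for future in concurrent.futures.as_completed(future_to_product):
        try:
            result, ok = future.result()
            results.append(result)
            successful += ok
            failed += not ok
        except Exception as exc:
            failed += 1
            results.append({
                "product_id": future_to_product[future]["item_id"],
                "error": str(exc),
            })

print(f"{successful} ok, {failed} failed")
```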

BIN
db.sqlite3


Some files were not shown because too many files changed in this diff