Harshit Pathak 3 months ago
parent
commit
afae1b63f2
4 changed files with 1204 additions and 196 deletions
  1. attr_extraction/apps.py (+119 -13)
  2. attr_extraction/services.py (+821 -156)
  3. attr_extraction/views.py (+264 -27)
  4. db.sqlite3 (BIN)

+ 119 - 13
attr_extraction/apps.py

@@ -268,6 +268,121 @@
 
 
 
+# # ==================== attr_extraction/apps.py ====================
+# from django.apps import AppConfig
+# import logging
+# import sys
+# import os
+# import threading
+
+# from django.core.cache import cache  # ✅ Import Django cache
+
+# logger = logging.getLogger(__name__)
+
+
+# class AttrExtractionConfig(AppConfig):
+#     default_auto_field = 'django.db.models.BigAutoField'
+#     name = 'attr_extraction'
+    
+#     models_loaded = False
+    
+#     def ready(self):
+#         """
+#         🔥 Pre-load all heavy ML models during Django startup.
+#         Also clears Django cache once when the server starts.
+#         """
+#         # Skip during migrations/management commands
+#         if any(cmd in sys.argv for cmd in ['migrate', 'makemigrations', 'test', 'collectstatic', 'shell']):
+#             return
+        
+#         # Skip in Django autoreloader parent process
+#         if os.environ.get('RUN_MAIN') != 'true':
+#             logger.info("⏭️  Skipping model loading in autoreloader parent process")
+#             return
+        
+#         # ✅ Clear cache once per startup
+#         try:
+#             cache.clear()
+#             logger.info("🧹 Django cache cleared successfully on startup.")
+#         except Exception as e:
+#             logger.warning(f"⚠️  Failed to clear cache: {e}")
+        
+#         # Prevent double loading
+#         if AttrExtractionConfig.models_loaded:
+#             logger.info("⏭️  Models already loaded, skipping...")
+#             return
+        
+#         AttrExtractionConfig.models_loaded = True
+        
+#         # Load models in background thread (non-blocking)
+#         thread = threading.Thread(target=self._load_models, daemon=True)
+#         thread.start()
+        
+#         logger.info("🔄 Model loading started in background...")
+    
+#     def _load_models(self):
+#         """Background thread to load heavy models."""
+#         import time
+        
+#         logger.info("=" * 70)
+#         logger.info("🔥 WARMING UP ML MODELS (background process)")
+#         logger.info("=" * 70)
+        
+#         startup_time = time.time()
+#         total_loaded = 0
+        
+#         # 1. Sentence Transformer
+#         # try:
+#         #     logger.info("📥 Loading Sentence Transformer...")
+#         #     st_start = time.time()
+#         #     from .services import model_embedder
+#         #     st_time = time.time() - st_start
+#         #     logger.info(f"✓ Sentence Transformer ready ({st_time:.1f}s)")
+#         #     total_loaded += 1
+#         # except Exception as e:
+#         #     logger.error(f"❌ Sentence Transformer failed: {e}")
+        
+#         # 2. Pre-load CLIP model
+#         try:
+#             logger.info("📥 Loading CLIP model (20-30s)...")
+#             clip_start = time.time()
+#             from .visual_processing_service import VisualProcessingService
+#             VisualProcessingService._get_clip_model()
+#             clip_time = time.time() - clip_start
+#             logger.info(f"✓ CLIP model cached ({clip_time:.1f}s)")
+#             total_loaded += 1
+#         except Exception as e:
+#             logger.error(f"❌ CLIP model failed: {e}")
+        
+#         # 3. Pre-load OCR model
+#         try:
+#             logger.info("📥 Loading EasyOCR model...")
+#             ocr_start = time.time()
+#             from .ocr_service import OCRService
+#             ocr_service = OCRService()
+#             ocr_service._get_reader()
+#             ocr_time = time.time() - ocr_start
+#             logger.info(f"✓ OCR model cached ({ocr_time:.1f}s)")
+#             total_loaded += 1
+#         except Exception as e:
+#             logger.error(f"❌ OCR model failed: {e}")
+        
+#         total_time = time.time() - startup_time
+        
+#         logger.info("=" * 70)
+#         logger.info(f"🎉 {total_loaded}/3 MODELS LOADED in {total_time:.1f}s")
+#         logger.info("⚡ API requests are now FAST (2-5 seconds)")
+#         logger.info("=" * 70)
+
+
+
+
+
+
+
+
+
+
 # ==================== attr_extraction/apps.py ====================
 from django.apps import AppConfig
 import logging
@@ -331,18 +446,9 @@ class AttrExtractionConfig(AppConfig):
         startup_time = time.time()
         total_loaded = 0
         
-        # 1. Sentence Transformer
-        try:
-            logger.info("📥 Loading Sentence Transformer...")
-            st_start = time.time()
-            from .services import model_embedder
-            st_time = time.time() - st_start
-            logger.info(f"✓ Sentence Transformer ready ({st_time:.1f}s)")
-            total_loaded += 1
-        except Exception as e:
-            logger.error(f"❌ Sentence Transformer failed: {e}")
+        # REMOVED: Sentence Transformer (no longer used in services.py)
         
-        # 2. Pre-load CLIP model
+        # 1. Pre-load CLIP model
         try:
             logger.info("📥 Loading CLIP model (20-30s)...")
             clip_start = time.time()
@@ -354,7 +460,7 @@ class AttrExtractionConfig(AppConfig):
         except Exception as e:
             logger.error(f"❌ CLIP model failed: {e}")
         
-        # 3. Pre-load OCR model
+        # 2. Pre-load OCR model
         try:
             logger.info("📥 Loading EasyOCR model...")
             ocr_start = time.time()
@@ -370,6 +476,6 @@ class AttrExtractionConfig(AppConfig):
         total_time = time.time() - startup_time
         
         logger.info("=" * 70)
-        logger.info(f"🎉 {total_loaded}/3 MODELS LOADED in {total_time:.1f}s")
+        logger.info(f"🎉 {total_loaded}/2 MODELS LOADED in {total_time:.1f}s")
         logger.info("⚡ API requests are now FAST (2-5 seconds)")
         logger.info("=" * 70)

+ 821 - 156
attr_extraction/services.py

File diff suppressed because the file is too large

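The services.py diff itself is suppressed, but the updated call site in views.py below passes use_cache and a new user_entered_values argument into ProductAttributeService.extract_attributes. A rough sketch of the signature this implies, reconstructed only from that call site and not from the actual services.py code:

# Assumed interface, inferred from the views.py call site below; the real
# services.py implementation (821 added lines) is not shown in this diff.
class ProductAttributeService:

    @staticmethod
    def extract_attributes(product_text, mandatory_attrs, source_map=None,
                           model=None, extract_additional=True, multiple=None,
                           use_cache=True, user_entered_values=None):
        """
        Extract mandatory/additional attributes from the combined product text.
        user_entered_values maps attribute name -> original value, so the
        extraction can return a per-attribute "user_value" plus reasoning
        (see the views.py notes below).
        """
        extracted = {"mandatory": {}, "additional": {}}
        # ... matching / LLM reasoning happens here in the real implementation ...
        return extracted
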

+ 264 - 27
attr_extraction/views.py

@@ -174,10 +174,260 @@ class ExtractProductAttributesView(APIView):
 
 # ==================== OPTIMIZED BATCH VIEW ====================
 
+# class BatchExtractProductAttributesView(APIView):
+#     """
+#     ⚡ PERFORMANCE OPTIMIZED: Batch extraction with intelligent parallelization
+#     Expected performance: 10 products in 30-60 seconds (with image processing)
+#     """
+
+#     def post(self, request):
+#         import time
+#         start_time = time.time()
+
+#         serializer = BatchProductRequestSerializer(data=request.data)
+#         if not serializer.is_valid():
+#             return Response({"error": serializer.errors}, status=status.HTTP_400_BAD_REQUEST)
+
+#         validated_data = serializer.validated_data
+#         product_list = validated_data.get("products", [])
+        
+#         logger.info(f"🚀 Starting batch processing for {len(product_list)} products")
+        
+#         # ==================== OPTIMIZATION 1: Bulk DB Query ====================
+#         item_ids = [p['item_id'] for p in product_list]
+#         products_queryset = Product.objects.filter(
+#             item_id__in=item_ids
+#         ).prefetch_related('attribute_values')
+        
+#         product_map = {product.item_id: product for product in products_queryset}
+        
+#         # Prefetch ALL original attribute values in ONE query
+#         original_values_qs = ProductAttributeValue.objects.filter(
+#             product__item_id__in=item_ids
+#         ).select_related('product')
+        
+#         original_values_map = {}
+#         for attr_val in original_values_qs:
+#             item_id = attr_val.product.item_id
+#             if item_id not in original_values_map:
+#                 original_values_map[item_id] = {}
+#             original_values_map[item_id][attr_val.attribute_name] = attr_val.original_value
+        
+#         logger.info(f"✓ Loaded {len(product_map)} products from database")
+        
+#         # Extract settings
+#         model = validated_data.get("model")
+#         extract_additional = validated_data.get("extract_additional", True)
+#         process_image = validated_data.get("process_image", True)
+#         multiple = validated_data.get("multiple", [])
+#         threshold_abs = validated_data.get("threshold_abs", 0.65)
+#         margin = validated_data.get("margin", 0.15)
+#         use_dynamic_thresholds = validated_data.get("use_dynamic_thresholds", False)
+#         use_adaptive_margin = validated_data.get("use_adaptive_margin", False)
+#         use_semantic_clustering = validated_data.get("use_semantic_clustering", False)
+        
+#         results = []
+#         successful = 0
+#         failed = 0
+        
+#         # ==================== OPTIMIZATION 2: Conditional Service Init ====================
+#         # Only initialize if processing images
+#         ocr_service = None
+#         visual_service = None
+        
+#         if process_image:
+#             from .ocr_service import OCRService
+#             from .visual_processing_service import VisualProcessingService
+#             ocr_service = OCRService()
+#             visual_service = VisualProcessingService()
+#             logger.info("✓ Image processing services initialized")
+
+#         # ==================== OPTIMIZATION 3: Smart Parallelization ====================
+#         def process_single_product(product_entry):
+#             """Process a single product (runs in parallel)"""
+#             import time
+#             product_start = time.time()
+            
+#             item_id = product_entry['item_id']
+#             mandatory_attrs = product_entry['mandatory_attrs']
+
+#             if item_id not in product_map:
+#                 return {
+#                     "product_id": item_id,
+#                     "error": "Product not found in database"
+#                 }, False
+
+#             product = product_map[item_id]
+            
+#             try:
+#                 title = product.product_name
+#                 short_desc = product.product_short_description
+#                 long_desc = product.product_long_description
+#                 image_url = product.image_path
+                
+#                 ocr_results = None
+#                 ocr_text = None
+#                 visual_results = None
+
+#                 # ⚡ SKIP IMAGE PROCESSING IF DISABLED (HUGE TIME SAVER)
+#                 if process_image and image_url:
+#                     if ocr_service:
+#                         ocr_results = ocr_service.process_image(image_url)
+                        
+#                         if ocr_results and ocr_results.get("detected_text"):
+#                             ocr_attrs = ProductAttributeService.extract_attributes_from_ocr(
+#                                 ocr_results, model
+#                             )
+#                             ocr_results["extracted_attributes"] = ocr_attrs
+#                             ocr_text = "\n".join([
+#                                 f"{item['text']} (confidence: {item['confidence']:.2f})"
+#                                 for item in ocr_results["detected_text"]
+#                             ])
+                    
+#                     if visual_service:
+#                         product_type_hint = product.product_type if hasattr(product, 'product_type') else None
+#                         visual_results = visual_service.process_image(image_url, product_type_hint)
+                        
+#                         if visual_results and visual_results.get('visual_attributes'):
+#                             visual_results['visual_attributes'] = ProductAttributeService.format_visual_attributes(
+#                                 visual_results['visual_attributes']
+#                             )
+
+#                 # Combine product text with source tracking
+#                 product_text, source_map = ProductAttributeService.combine_product_text(
+#                     title=title,
+#                     short_desc=short_desc,
+#                     long_desc=long_desc,
+#                     ocr_text=ocr_text
+#                 )
+
+#                 # ⚡ EXTRACT ATTRIBUTES WITH CACHING ENABLED
+#                 extracted = ProductAttributeService.extract_attributes(
+#                     product_text=product_text,
+#                     mandatory_attrs=mandatory_attrs,
+#                     source_map=source_map,
+#                     model=model,
+#                     extract_additional=extract_additional,
+#                     multiple=multiple,
+#                     # threshold_abs=threshold_abs,
+#                     # margin=margin,
+#                     # use_dynamic_thresholds=use_dynamic_thresholds,
+#                     # use_adaptive_margin=use_adaptive_margin,
+#                     # use_semantic_clustering=use_semantic_clustering,
+#                     use_cache=True  # ⚡ CRITICAL: Enable caching
+#                 )
+
+#                 # Add original values
+#                 original_attrs = original_values_map.get(item_id, {})
+                
+#                 for attr_name, attr_values in extracted.get("mandatory", {}).items():
+#                     if isinstance(attr_values, list):
+#                         for attr_obj in attr_values:
+#                             if isinstance(attr_obj, dict):
+#                                 attr_obj["original_value"] = original_attrs.get(attr_name, "")
+                
+#                 for attr_name, attr_values in extracted.get("additional", {}).items():
+#                     if isinstance(attr_values, list):
+#                         for attr_obj in attr_values:
+#                             if isinstance(attr_obj, dict):
+#                                 attr_obj["original_value"] = original_attrs.get(attr_name, "")
+
+#                 result = {
+#                     "product_id": product.item_id,
+#                     "mandatory": extracted.get("mandatory", {}),
+#                     "additional": extracted.get("additional", {}),
+#                 }
+
+#                 if ocr_results:
+#                     result["ocr_results"] = ocr_results
+                
+#                 if visual_results:
+#                     result["visual_results"] = visual_results
+                
+#                 processing_time = time.time() - product_start
+#                 logger.info(f"✓ Processed {item_id} in {processing_time:.2f}s")
+
+#                 return result, True
+
+#             except Exception as e:
+#                 logger.error(f"❌ Error processing {item_id}: {str(e)}")
+#                 return {
+#                     "product_id": item_id,
+#                     "error": str(e)
+#                 }, False
+
+#         # ==================== OPTIMIZATION 4: Parallel Execution ====================
+#         # Adjust workers based on whether image processing is enabled
+#         max_workers = min(3 if process_image else 10, len(product_list))
+        
+#         logger.info(f"⚡ Using {max_workers} parallel workers")
+        
+#         with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+#             # Submit all tasks
+#             future_to_product = {
+#                 executor.submit(process_single_product, product): product
+#                 for product in product_list
+#             }
+            
+#             # Collect results as they complete
+#             for future in concurrent.futures.as_completed(future_to_product):
+#                 try:
+#                     result, success = future.result()
+#                     results.append(result)
+#                     if success:
+#                         successful += 1
+#                     else:
+#                         failed += 1
+#                 except Exception as e:
+#                     failed += 1
+#                     logger.error(f"❌ Future execution error: {str(e)}")
+#                     results.append({
+#                         "product_id": "unknown",
+#                         "error": str(e)
+#                     })
+
+#         total_time = time.time() - start_time
+        
+#         # Get cache statistics
+#         cache_stats = ProductAttributeService.get_cache_stats()
+        
+#         logger.info(f"""
+# 🎉 BATCH PROCESSING COMPLETE
+# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
+#   Total products: {len(product_list)}
+#   Successful: {successful}
+#   Failed: {failed}
+#   Total time: {total_time:.2f}s
+#   Avg time/product: {total_time/len(product_list):.2f}s
+# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
+#         """)
+
+#         batch_result = {
+#             "results": results,
+#             "total_products": len(product_list),
+#             "successful": successful,
+#             "failed": failed,
+#             "performance": {
+#                 "total_time_seconds": round(total_time, 2),
+#                 "avg_time_per_product": round(total_time / len(product_list), 2),
+#                 "workers_used": max_workers
+#             },
+#             "cache_stats": cache_stats
+#         }
+
+#         response_serializer = BatchProductResponseSerializer(data=batch_result)
+#         if response_serializer.is_valid():
+#             return Response(response_serializer.data, status=status.HTTP_200_OK)
+
+#         return Response(batch_result, status=status.HTTP_200_OK)
+
+
+# VERSION WITH PARALLELIZATION
 class BatchExtractProductAttributesView(APIView):
     """
     ⚡ PERFORMANCE OPTIMIZED: Batch extraction with intelligent parallelization
     Expected performance: 10 products in 30-60 seconds (with image processing)
+    NOW WITH USER VALUE REASONING
     """
 
     def post(self, request):
@@ -214,6 +464,7 @@ class BatchExtractProductAttributesView(APIView):
             original_values_map[item_id][attr_val.attribute_name] = attr_val.original_value
         
         logger.info(f"✓ Loaded {len(product_map)} products from database")
+        logger.info(f"✓ Loaded user values for {len(original_values_map)} products")
         
         # Extract settings
         model = validated_data.get("model")
@@ -231,7 +482,6 @@ class BatchExtractProductAttributesView(APIView):
         failed = 0
         
         # ==================== OPTIMIZATION 2: Conditional Service Init ====================
-        # Only initialize if processing images
         ocr_service = None
         visual_service = None
         
@@ -269,7 +519,7 @@ class BatchExtractProductAttributesView(APIView):
                 ocr_text = None
                 visual_results = None
 
-                # ⚡ SKIP IMAGE PROCESSING IF DISABLED (HUGE TIME SAVER)
+                # ⚡ SKIP IMAGE PROCESSING IF DISABLED
                 if process_image and image_url:
                     if ocr_service:
                         ocr_results = ocr_service.process_image(image_url)
@@ -301,7 +551,13 @@ class BatchExtractProductAttributesView(APIView):
                     ocr_text=ocr_text
                 )
 
-                # ⚡ EXTRACT ATTRIBUTES WITH CACHING ENABLED
+                # 🆕 GET USER-ENTERED VALUES FOR THIS PRODUCT
+                user_entered_values = original_values_map.get(item_id, {})
+                print("user entered values are ")
+                print(user_entered_values)
+                logger.info(f"Processing {item_id} with {len(user_entered_values)} user-entered values")
+
+                # ⚡ EXTRACT ATTRIBUTES WITH USER VALUES AND REASONING
                 extracted = ProductAttributeService.extract_attributes(
                     product_text=product_text,
                     mandatory_attrs=mandatory_attrs,
@@ -309,29 +565,13 @@ class BatchExtractProductAttributesView(APIView):
                     model=model,
                     extract_additional=extract_additional,
                     multiple=multiple,
-                    # threshold_abs=threshold_abs,
-                    # margin=margin,
-                    # use_dynamic_thresholds=use_dynamic_thresholds,
-                    # use_adaptive_margin=use_adaptive_margin,
-                    # use_semantic_clustering=use_semantic_clustering,
-                    use_cache=True  # ⚡ CRITICAL: Enable caching
+                    use_cache=True,
+                    user_entered_values=user_entered_values  # 🆕 PASS USER VALUES
                 )
 
-                # Add original values
-                original_attrs = original_values_map.get(item_id, {})
-                
-                for attr_name, attr_values in extracted.get("mandatory", {}).items():
-                    if isinstance(attr_values, list):
-                        for attr_obj in attr_values:
-                            if isinstance(attr_obj, dict):
-                                attr_obj["original_value"] = original_attrs.get(attr_name, "")
+                # NOTE: Original values are now part of LLM response with reasoning
+                # No need to add them separately - they're already in the "user_value" field
                 
-                for attr_name, attr_values in extracted.get("additional", {}).items():
-                    if isinstance(attr_values, list):
-                        for attr_obj in attr_values:
-                            if isinstance(attr_obj, dict):
-                                attr_obj["original_value"] = original_attrs.get(attr_name, "")
-
                 result = {
                     "product_id": product.item_id,
                     "mandatory": extracted.get("mandatory", {}),
@@ -357,19 +597,16 @@ class BatchExtractProductAttributesView(APIView):
                 }, False
 
         # ==================== OPTIMIZATION 4: Parallel Execution ====================
-        # Adjust workers based on whether image processing is enabled
         max_workers = min(3 if process_image else 10, len(product_list))
         
         logger.info(f"⚡ Using {max_workers} parallel workers")
         
         with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
-            # Submit all tasks
             future_to_product = {
                 executor.submit(process_single_product, product): product
                 for product in product_list
             }
             
-            # Collect results as they complete
             for future in concurrent.futures.as_completed(future_to_product):
                 try:
                     result, success = future.result()
@@ -399,7 +636,6 @@ class BatchExtractProductAttributesView(APIView):
   Failed: {failed}
   Total time: {total_time:.2f}s
   Avg time/product: {total_time/len(product_list):.2f}s
-  Cache hit rate: {cache_stats['embedding_cache']['hit_rate_percent']:.1f}%
 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
         """)
 
@@ -423,6 +659,7 @@ class BatchExtractProductAttributesView(APIView):
         return Response(batch_result, status=status.HTTP_200_OK)
 
 
+
 class ProductListView(APIView):
     """
     GET API to list all products with details

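Per the note in the views.py hunk above, the user-entered original values are no longer attached in the view; the extraction itself is expected to return them in a "user_value" field with reasoning. An illustrative, assumed shape for one extracted attribute entry in the batch response:

# Illustrative example only; apart from "user_value" (named in the diff comment
# above), the field names here are assumptions -- the serializer and services.py
# changes that define the real shape are not shown in this diff.
example_attribute_entry = {
    "value": "Cotton",                   # value extracted from title/description/OCR
    "user_value": "100% cotton",         # the user-entered original value (ProductAttributeValue.original_value)
    "reasoning": "Short description says '100% cotton', which matches the user-entered value.",
    "source": "short_description",       # tracked via combine_product_text's source_map
}
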
BIN
db.sqlite3


Some files were not shown because too many files changed in this diff