Harshit Pathak пре 3 месеци
родитељ
комит
11960d6f41
2 измењених фајлова са 154 додато и 210 уклоњено
  1. 21 44
      attr_extraction/serializers.py
  2. 133 166
      attr_extraction/views.py

+ 21 - 44
attr_extraction/serializers.py

@@ -95,6 +95,19 @@ class ProductInputSerializer(serializers.Serializer):
     long_desc = serializers.CharField(required=False, allow_blank=True, allow_null=True)
     image_url = serializers.URLField(required=False, allow_blank=True, allow_null=True)
 
+class MandatoryAttrsField(serializers.DictField):
+    """Custom DictField to validate mandatory_attrs structure."""
+    child = serializers.ListField(child=serializers.CharField())
+
+class ProductBatchInputSerializer(serializers.Serializer):
+    """Serializer for an individual product input within the batch request."""
+    item_id = serializers.CharField(required=True)
+    mandatory_attrs = MandatoryAttrsField(
+        required=True,
+        help_text="A dictionary of attribute names and their possible values."
+    )
+    # You can also allow per-product model/flags if needed, but keeping it batch-level for simplicity here.
+
 
 class SingleProductRequestSerializer(serializers.Serializer):
     """Serializer for single product extraction request."""
@@ -120,54 +133,19 @@ class SingleProductRequestSerializer(serializers.Serializer):
         return value
 
 
-# class BatchProductRequestSerializer(serializers.Serializer):
-#     """Serializer for batch product extraction request."""
-#     products = serializers.ListField(
-#         child=ProductInputSerializer(),
-#         required=True,
-#         min_length=1
-#     )
-#     mandatory_attrs = serializers.DictField(
-#         child=serializers.ListField(child=serializers.CharField()),
-#         required=True
-#     )
-#     model = serializers.CharField(required=False, default="llama-3.1-8b-instant")
-#     extract_additional = serializers.BooleanField(required=False, default=True)
-#     process_image = serializers.BooleanField(required=False, default=True)
-
-#     def validate_model(self, value):
-#         from django.conf import settings
-#         if value not in settings.SUPPORTED_MODELS:
-#             raise serializers.ValidationError(
-#                 f"Model must be one of {settings.SUPPORTED_MODELS}"
-#             )
-#         return value
-
-#     def validate_products(self, value):
-#         from django.conf import settings
-#         max_size = getattr(settings, 'MAX_BATCH_SIZE', 100)
-#         if len(value) > max_size:
-#             raise serializers.ValidationError(
-#                 f"Batch size cannot exceed {max_size} products"
-#             )
-#         return value
-
 
 class BatchProductRequestSerializer(serializers.Serializer):
-    """Serializer for batch product extraction request (by item_id)."""
-    item_ids = serializers.ListField(
-        child=serializers.CharField(),
+    """Serializer for batch product extraction request (with item-specific attributes)."""
+    products = serializers.ListField(
+        child=ProductBatchInputSerializer(), # <--- Changed
         required=True,
         min_length=1
     )
-    mandatory_attrs = serializers.DictField(
-        child=serializers.ListField(child=serializers.CharField()),
-        required=True
-    )
     model = serializers.CharField(required=False, default="llama-3.1-8b-instant")
     extract_additional = serializers.BooleanField(required=False, default=True)
     process_image = serializers.BooleanField(required=False, default=True)
-
+    
+    # ... validate_model method ...
     def validate_model(self, value):
         from django.conf import settings
         if value not in settings.SUPPORTED_MODELS:
@@ -175,8 +153,9 @@ class BatchProductRequestSerializer(serializers.Serializer):
                 f"Model must be one of {settings.SUPPORTED_MODELS}"
             )
         return value
-
-    def validate_item_ids(self, value):
+    
+    # ... validate_products method (updated to use products instead of item_ids) ...
+    def validate_products(self, value):
         from django.conf import settings
         max_size = getattr(settings, 'MAX_BATCH_SIZE', 100)
         if len(value) > max_size:
@@ -185,8 +164,6 @@ class BatchProductRequestSerializer(serializers.Serializer):
             )
         return value
 
-
-
 class OCRResultSerializer(serializers.Serializer):
     """Serializer for OCR results."""
     detected_text = serializers.ListField(child=serializers.DictField())

+ 133 - 166
attr_extraction/views.py

@@ -235,71 +235,6 @@ from .services import ProductAttributeService
 from .ocr_service import OCRService
 
 
-# class ExtractProductAttributesView(APIView):
-#     """
-#     API endpoint to extract product attributes for a single product.
-#     Now supports image URL for OCR-based text extraction.
-#     """
-
-#     def post(self, request):
-#         serializer = SingleProductRequestSerializer(data=request.data)
-#         if not serializer.is_valid():
-#             return Response(
-#                 {"error": serializer.errors},
-#                 status=status.HTTP_400_BAD_REQUEST
-#             )
-
-#         validated_data = serializer.validated_data
-        
-#         # Process image if URL provided
-#         ocr_results = None
-#         ocr_text = None
-        
-#         if validated_data.get('process_image', True) and validated_data.get('image_url'):
-#             ocr_service = OCRService()
-#             ocr_results = ocr_service.process_image(validated_data['image_url'])
-            
-#             # Extract attributes from OCR
-#             if ocr_results and ocr_results.get('detected_text'):
-#                 ocr_attrs = ProductAttributeService.extract_attributes_from_ocr(
-#                     ocr_results,
-#                     validated_data.get('model')
-#                 )
-#                 ocr_results['extracted_attributes'] = ocr_attrs
-                
-#                 # Format OCR text
-#                 ocr_text = "\n".join([
-#                     f"{item['text']} (confidence: {item['confidence']:.2f})"
-#                     for item in ocr_results['detected_text']
-#                 ])
-
-#         # Combine all product information
-#         product_text = ProductAttributeService.combine_product_text(
-#             title=validated_data.get('title'),
-#             short_desc=validated_data.get('short_desc'),
-#             long_desc=validated_data.get('long_desc'),
-#             ocr_text=ocr_text
-#         )
-
-#         # Extract attributes
-#         result = ProductAttributeService.extract_attributes(
-#             product_text=product_text,
-#             mandatory_attrs=validated_data['mandatory_attrs'],
-#             model=validated_data.get('model'),
-#             extract_additional=validated_data.get('extract_additional', True)
-#         )
-        
-#         # Add OCR results if available
-#         if ocr_results:
-#             result['ocr_results'] = ocr_results
-
-#         response_serializer = ProductAttributeResultSerializer(data=result)
-#         if response_serializer.is_valid():
-#             return Response(response_serializer.data, status=status.HTTP_200_OK)
-        
-#         return Response(result, status=status.HTTP_200_OK)
-
-
 from .models import Product
 
 class ExtractProductAttributesView(APIView):
@@ -377,44 +312,118 @@ class ExtractProductAttributesView(APIView):
         return Response(result, status=status.HTTP_200_OK)
 
 
+from .models import Product
+
 # class BatchExtractProductAttributesView(APIView):
 #     """
-#     API endpoint to extract product attributes for multiple products in batch.
-#     Now supports image URLs for OCR-based text extraction.
+#     API endpoint to extract product attributes for multiple products in batch by item_id.
+#     Fetches all product details from database automatically.
 #     """
 
 #     def post(self, request):
 #         serializer = BatchProductRequestSerializer(data=request.data)
 #         if not serializer.is_valid():
-#             return Response(
-#                 {"error": serializer.errors},
-#                 status=status.HTTP_400_BAD_REQUEST
-#             )
+#             return Response({"error": serializer.errors}, status=status.HTTP_400_BAD_REQUEST)
 
 #         validated_data = serializer.validated_data
+#         item_ids = validated_data.get("item_ids", [])
+#         model = validated_data.get("model")
+#         extract_additional = validated_data.get("extract_additional", True)
+#         process_image = validated_data.get("process_image", True)
+#         mandatory_attrs = validated_data["mandatory_attrs"]
+
+#         # Fetch all products in one query
+#         products = Product.objects.filter(item_id__in=item_ids)
+#         found_ids = set(products.values_list("item_id", flat=True))
+#         missing_ids = [pid for pid in item_ids if pid not in found_ids]
+
+#         results = []
+#         successful = 0
+#         failed = 0
+
+#         for product in products:
+#             try:
+#                 title = product.product_name
+#                 short_desc = product.product_short_description
+#                 long_desc = product.product_long_description
+#                 image_url = product.image_path
+
+#                 ocr_results = None
+#                 ocr_text = None
+
+#                 if process_image and image_url:
+#                     ocr_service = OCRService()
+#                     ocr_results = ocr_service.process_image(image_url)
+
+#                     if ocr_results and ocr_results.get("detected_text"):
+#                         ocr_attrs = ProductAttributeService.extract_attributes_from_ocr(
+#                             ocr_results, model
+#                         )
+#                         ocr_results["extracted_attributes"] = ocr_attrs
+#                         ocr_text = "\n".join([
+#                             f"{item['text']} (confidence: {item['confidence']:.2f})"
+#                             for item in ocr_results["detected_text"]
+#                         ])
+
+#                 product_text = ProductAttributeService.combine_product_text(
+#                     title=title,
+#                     short_desc=short_desc,
+#                     long_desc=long_desc,
+#                     ocr_text=ocr_text
+#                 )
 
-#         # Extract attributes for all products in batch
-#         result = ProductAttributeService.extract_attributes_batch(
-#             products=validated_data['products'],
-#             mandatory_attrs=validated_data['mandatory_attrs'],
-#             model=validated_data.get('model'),
-#             extract_additional=validated_data.get('extract_additional', True),
-#             process_image=validated_data.get('process_image', True)
-#         )
+#                 extracted = ProductAttributeService.extract_attributes(
+#                     product_text=product_text,
+#                     mandatory_attrs=mandatory_attrs,
+#                     model=model,
+#                     extract_additional=extract_additional
+#                 )
 
-#         response_serializer = BatchProductResponseSerializer(data=result)
+#                 result = {
+#                     "product_id": product.item_id,
+#                     "mandatory": extracted.get("mandatory", {}),
+#                     "additional": extracted.get("additional", {}),
+#                 }
+
+#                 if ocr_results:
+#                     result["ocr_results"] = ocr_results
+
+#                 results.append(result)
+#                 successful += 1
+
+#             except Exception as e:
+#                 failed += 1
+#                 results.append({
+#                     "product_id": product.item_id,
+#                     "error": str(e)
+#                 })
+
+#         # Add missing item_ids as failed entries
+#         for mid in missing_ids:
+#             failed += 1
+#             results.append({
+#                 "product_id": mid,
+#                 "error": "Product not found in database"
+#             })
+
+#         batch_result = {
+#             "results": results,
+#             "total_products": len(item_ids),
+#             "successful": successful,
+#             "failed": failed
+#         }
+
+#         response_serializer = BatchProductResponseSerializer(data=batch_result)
 #         if response_serializer.is_valid():
 #             return Response(response_serializer.data, status=status.HTTP_200_OK)
-        
-#         return Response(result, status=status.HTTP_200_OK)
 
+#         return Response(batch_result, status=status.HTTP_200_OK)
 
-from .models import Product
 
 class BatchExtractProductAttributesView(APIView):
     """
-    API endpoint to extract product attributes for multiple products in batch by item_id.
-    Fetches all product details from database automatically.
+    API endpoint to extract product attributes for multiple products in batch.
+    Uses item-specific mandatory_attrs.
     """
 
     def post(self, request):
@@ -423,22 +432,42 @@ class BatchExtractProductAttributesView(APIView):
             return Response({"error": serializer.errors}, status=status.HTTP_400_BAD_REQUEST)
 
         validated_data = serializer.validated_data
-        item_ids = validated_data.get("item_ids", [])
+        
+        # Get batch-level settings
+        product_list = validated_data.get("products", []) # New: list of {item_id, mandatory_attrs}
         model = validated_data.get("model")
         extract_additional = validated_data.get("extract_additional", True)
         process_image = validated_data.get("process_image", True)
-        mandatory_attrs = validated_data["mandatory_attrs"]
-
+        
+        # Extract all item_ids to query the database efficiently
+        item_ids = [p['item_id'] for p in product_list] 
+        
         # Fetch all products in one query
-        products = Product.objects.filter(item_id__in=item_ids)
-        found_ids = set(products.values_list("item_id", flat=True))
-        missing_ids = [pid for pid in item_ids if pid not in found_ids]
-
+        products_queryset = Product.objects.filter(item_id__in=item_ids)
+        
+        # Create a dictionary for easy lookup: item_id -> Product object
+        product_map = {product.item_id: product for product in products_queryset}
+        found_ids = set(product_map.keys())
+        
         results = []
         successful = 0
         failed = 0
 
-        for product in products:
+        for product_entry in product_list:
+            item_id = product_entry['item_id']
+            # Get item-specific mandatory attributes
+            mandatory_attrs = product_entry['mandatory_attrs'] 
+
+            if item_id not in found_ids:
+                failed += 1
+                results.append({
+                    "product_id": item_id,
+                    "error": "Product not found in database"
+                })
+                continue # Skip to the next product
+
+            product = product_map[item_id]
+            
             try:
                 title = product.product_name
                 short_desc = product.product_short_description
@@ -448,11 +477,14 @@ class BatchExtractProductAttributesView(APIView):
                 ocr_results = None
                 ocr_text = None
 
+                # Image Processing Logic (same as before)
                 if process_image and image_url:
                     ocr_service = OCRService()
                     ocr_results = ocr_service.process_image(image_url)
 
                     if ocr_results and ocr_results.get("detected_text"):
+                        # Ensure the services are designed to handle 'mandatory_attrs'
+                        # for attribute extraction from OCR text
                         ocr_attrs = ProductAttributeService.extract_attributes_from_ocr(
                             ocr_results, model
                         )
@@ -469,9 +501,10 @@ class BatchExtractProductAttributesView(APIView):
                     ocr_text=ocr_text
                 )
 
+                # Attribute Extraction Logic - NOW USING ITEM-SPECIFIC mandatory_attrs
                 extracted = ProductAttributeService.extract_attributes(
                     product_text=product_text,
-                    mandatory_attrs=mandatory_attrs,
+                    mandatory_attrs=mandatory_attrs, # <--- Changed: now item-specific
                     model=model,
                     extract_additional=extract_additional
                 )
@@ -491,21 +524,17 @@ class BatchExtractProductAttributesView(APIView):
             except Exception as e:
                 failed += 1
                 results.append({
-                    "product_id": product.item_id,
+                    "product_id": item_id,
                     "error": str(e)
                 })
 
-        # Add missing item_ids as failed entries
-        for mid in missing_ids:
-            failed += 1
-            results.append({
-                "product_id": mid,
-                "error": "Product not found in database"
-            })
+        # No need for a separate missing_ids loop since we handle it when iterating over product_list
+        # The list comprehension `item_ids = [p['item_id'] for p in product_list]` and the check 
+        # `if item_id not in found_ids:` now correctly handle missing products from the input list.
 
         batch_result = {
             "results": results,
-            "total_products": len(item_ids),
+            "total_products": len(product_list),
             "successful": successful,
             "failed": failed
         }
@@ -536,8 +565,6 @@ class ProductListView(APIView):
 
 
 
-
-
 import pandas as pd
 from rest_framework.parsers import MultiPartParser, FormParser
 from rest_framework.views import APIView
@@ -547,62 +574,6 @@ from .models import Product
 from .serializers import ProductSerializer
 
 
-# class ProductUploadExcelView(APIView):
-#     """
-#     POST API to upload an Excel file and add data to Product model
-#     """
-#     parser_classes = (MultiPartParser, FormParser)
-
-#     def post(self, request, *args, **kwargs):
-#         file_obj = request.FILES.get('file')
-#         if not file_obj:
-#             return Response({'error': 'No file provided'}, status=status.HTTP_400_BAD_REQUEST)
-
-#         try:
-#             # Read the Excel file
-#             df = pd.read_excel(file_obj)
-
-#             # Normalize column names
-#             df.columns = [c.strip().lower().replace(' ', '_') for c in df.columns]
-
-#             # Expected columns
-#             expected_cols = {
-#                 'item_id',
-#                 'product_name',
-#                 'product_long_description',
-#                 'product_short_description',
-#                 'product_type',
-#                 'image_path'
-#             }
-
-#             if not expected_cols.issubset(df.columns):
-#                 return Response({
-#                     'error': 'Missing required columns',
-#                     'required_columns': list(expected_cols)
-#                 }, status=status.HTTP_400_BAD_REQUEST)
-
-#             # Loop through rows and create Product entries
-#             created_count = 0
-#             for _, row in df.iterrows():
-#                 Product.objects.create(
-#                     item_id=row.get('item_id', ''),
-#                     product_name=row.get('product_name', ''),
-#                     product_long_description=row.get('product_long_description', ''),
-#                     product_short_description=row.get('product_short_description', ''),
-#                     product_type=row.get('product_type', ''),
-#                     image_path=row.get('image_path', ''),
-#                 )
-#                 created_count += 1
-
-#             return Response({
-#                 'message': f'Successfully uploaded {created_count} products.'
-#             }, status=status.HTTP_201_CREATED)
-
-#         except Exception as e:
-#             return Response({'error': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
-
-
-
 class ProductUploadExcelView(APIView):
     """
     POST API to upload an Excel file and add data to Product model (skip duplicates)
@@ -665,11 +636,6 @@ class ProductUploadExcelView(APIView):
 
 
 
-
-
-
-
-
 import pandas as pd
 from rest_framework.views import APIView
 from rest_framework.response import Response
@@ -678,6 +644,7 @@ from rest_framework.parsers import MultiPartParser, FormParser
 from .models import ProductType, ProductAttribute, AttributePossibleValue
 
 
+
 class ProductAttributesUploadView(APIView):
     """
     POST API to upload an Excel file and add mandatory/additional attributes