Explorar el Código

Updated API to take only item_id

Harshit Pathak hace 3 meses
padre
commit
220a6d2b1e
Se han modificado 2 ficheros con 274 adiciones y 56 borrados
  1. + 43 - 8
      attr_extraction/serializers.py
  2. + 231 - 48
      attr_extraction/views.py

+ 43 - 8
attr_extraction/serializers.py

@@ -98,10 +98,11 @@ class ProductInputSerializer(serializers.Serializer):
 
 class SingleProductRequestSerializer(serializers.Serializer):
     """Serializer for single product extraction request."""
-    title = serializers.CharField(required=False, allow_blank=True, allow_null=True)
-    short_desc = serializers.CharField(required=False, allow_blank=True, allow_null=True)
-    long_desc = serializers.CharField(required=False, allow_blank=True, allow_null=True)
-    image_url = serializers.URLField(required=False, allow_blank=True, allow_null=True)
+    # title = serializers.CharField(required=False, allow_blank=True, allow_null=True)
+    # short_desc = serializers.CharField(required=False, allow_blank=True, allow_null=True)
+    # long_desc = serializers.CharField(required=False, allow_blank=True, allow_null=True)
+    # image_url = serializers.URLField(required=False, allow_blank=True, allow_null=True)
+    item_id = serializers.CharField(required=True)
     mandatory_attrs = serializers.DictField(
         child=serializers.ListField(child=serializers.CharField()),
         required=True
@@ -119,10 +120,43 @@ class SingleProductRequestSerializer(serializers.Serializer):
         return value
 
 
+# class BatchProductRequestSerializer(serializers.Serializer):
+#     """Serializer for batch product extraction request."""
+#     products = serializers.ListField(
+#         child=ProductInputSerializer(),
+#         required=True,
+#         min_length=1
+#     )
+#     mandatory_attrs = serializers.DictField(
+#         child=serializers.ListField(child=serializers.CharField()),
+#         required=True
+#     )
+#     model = serializers.CharField(required=False, default="llama-3.1-8b-instant")
+#     extract_additional = serializers.BooleanField(required=False, default=True)
+#     process_image = serializers.BooleanField(required=False, default=True)
+
+#     def validate_model(self, value):
+#         from django.conf import settings
+#         if value not in settings.SUPPORTED_MODELS:
+#             raise serializers.ValidationError(
+#                 f"Model must be one of {settings.SUPPORTED_MODELS}"
+#             )
+#         return value
+
+#     def validate_products(self, value):
+#         from django.conf import settings
+#         max_size = getattr(settings, 'MAX_BATCH_SIZE', 100)
+#         if len(value) > max_size:
+#             raise serializers.ValidationError(
+#                 f"Batch size cannot exceed {max_size} products"
+#             )
+#         return value
+
+
 class BatchProductRequestSerializer(serializers.Serializer):
-    """Serializer for batch product extraction request."""
-    products = serializers.ListField(
-        child=ProductInputSerializer(),
+    """Serializer for batch product extraction request (by item_id)."""
+    item_ids = serializers.ListField(
+        child=serializers.CharField(),
         required=True,
         min_length=1
     )
@@ -142,7 +176,7 @@ class BatchProductRequestSerializer(serializers.Serializer):
             )
         return value
 
-    def validate_products(self, value):
+    def validate_item_ids(self, value):
         from django.conf import settings
         max_size = getattr(settings, 'MAX_BATCH_SIZE', 100)
         if len(value) > max_size:
@@ -152,6 +186,7 @@ class BatchProductRequestSerializer(serializers.Serializer):
         return value
 
 
+
 class OCRResultSerializer(serializers.Serializer):
     """Serializer for OCR results."""
     detected_text = serializers.ListField(child=serializers.DictField())

+ 231 - 48
attr_extraction/views.py

@@ -235,103 +235,286 @@ from .services import ProductAttributeService
 from .ocr_service import OCRService
 
 
+# class ExtractProductAttributesView(APIView):
+#     """
+#     API endpoint to extract product attributes for a single product.
+#     Now supports image URL for OCR-based text extraction.
+#     """
+
+#     def post(self, request):
+#         serializer = SingleProductRequestSerializer(data=request.data)
+#         if not serializer.is_valid():
+#             return Response(
+#                 {"error": serializer.errors},
+#                 status=status.HTTP_400_BAD_REQUEST
+#             )
+
+#         validated_data = serializer.validated_data
+        
+#         # Process image if URL provided
+#         ocr_results = None
+#         ocr_text = None
+        
+#         if validated_data.get('process_image', True) and validated_data.get('image_url'):
+#             ocr_service = OCRService()
+#             ocr_results = ocr_service.process_image(validated_data['image_url'])
+            
+#             # Extract attributes from OCR
+#             if ocr_results and ocr_results.get('detected_text'):
+#                 ocr_attrs = ProductAttributeService.extract_attributes_from_ocr(
+#                     ocr_results,
+#                     validated_data.get('model')
+#                 )
+#                 ocr_results['extracted_attributes'] = ocr_attrs
+                
+#                 # Format OCR text
+#                 ocr_text = "\n".join([
+#                     f"{item['text']} (confidence: {item['confidence']:.2f})"
+#                     for item in ocr_results['detected_text']
+#                 ])
+
+#         # Combine all product information
+#         product_text = ProductAttributeService.combine_product_text(
+#             title=validated_data.get('title'),
+#             short_desc=validated_data.get('short_desc'),
+#             long_desc=validated_data.get('long_desc'),
+#             ocr_text=ocr_text
+#         )
+
+#         # Extract attributes
+#         result = ProductAttributeService.extract_attributes(
+#             product_text=product_text,
+#             mandatory_attrs=validated_data['mandatory_attrs'],
+#             model=validated_data.get('model'),
+#             extract_additional=validated_data.get('extract_additional', True)
+#         )
+        
+#         # Add OCR results if available
+#         if ocr_results:
+#             result['ocr_results'] = ocr_results
+
+#         response_serializer = ProductAttributeResultSerializer(data=result)
+#         if response_serializer.is_valid():
+#             return Response(response_serializer.data, status=status.HTTP_200_OK)
+        
+#         return Response(result, status=status.HTTP_200_OK)
+
+
+from .models import Product
+
 class ExtractProductAttributesView(APIView):
     """
-    API endpoint to extract product attributes for a single product.
-    Now supports image URL for OCR-based text extraction.
+    API endpoint to extract product attributes for a single product by item_id.
+    Fetches product details from database.
     """
 
     def post(self, request):
         serializer = SingleProductRequestSerializer(data=request.data)
         if not serializer.is_valid():
+            return Response({"error": serializer.errors}, status=status.HTTP_400_BAD_REQUEST)
+
+        validated_data = serializer.validated_data
+        item_id = validated_data.get("item_id")
+
+        # Fetch product from DB
+        try:
+            product = Product.objects.get(item_id=item_id)
+        except Product.DoesNotExist:
             return Response(
-                {"error": serializer.errors},
-                status=status.HTTP_400_BAD_REQUEST
+                {"error": f"Product with item_id '{item_id}' not found."},
+                status=status.HTTP_404_NOT_FOUND
             )
 
-        validated_data = serializer.validated_data
-        
-        # Process image if URL provided
+        # Extract product details
+        title = product.product_name
+        short_desc = product.product_short_description
+        long_desc = product.product_long_description
+        image_url = product.image_path
+
+        # Process image for OCR if required
         ocr_results = None
         ocr_text = None
-        
-        if validated_data.get('process_image', True) and validated_data.get('image_url'):
+
+        if validated_data.get("process_image", True) and image_url:
             ocr_service = OCRService()
-            ocr_results = ocr_service.process_image(validated_data['image_url'])
-            
-            # Extract attributes from OCR
-            if ocr_results and ocr_results.get('detected_text'):
+            ocr_results = ocr_service.process_image(image_url)
+
+            if ocr_results and ocr_results.get("detected_text"):
                 ocr_attrs = ProductAttributeService.extract_attributes_from_ocr(
-                    ocr_results,
-                    validated_data.get('model')
+                    ocr_results, validated_data.get("model")
                 )
-                ocr_results['extracted_attributes'] = ocr_attrs
-                
-                # Format OCR text
+                ocr_results["extracted_attributes"] = ocr_attrs
+
                 ocr_text = "\n".join([
                     f"{item['text']} (confidence: {item['confidence']:.2f})"
-                    for item in ocr_results['detected_text']
+                    for item in ocr_results["detected_text"]
                 ])
 
-        # Combine all product information
+        # Combine all product text
         product_text = ProductAttributeService.combine_product_text(
-            title=validated_data.get('title'),
-            short_desc=validated_data.get('short_desc'),
-            long_desc=validated_data.get('long_desc'),
+            title=title,
+            short_desc=short_desc,
+            long_desc=long_desc,
             ocr_text=ocr_text
         )
 
         # Extract attributes
         result = ProductAttributeService.extract_attributes(
             product_text=product_text,
-            mandatory_attrs=validated_data['mandatory_attrs'],
-            model=validated_data.get('model'),
-            extract_additional=validated_data.get('extract_additional', True)
+            mandatory_attrs=validated_data["mandatory_attrs"],
+            model=validated_data.get("model"),
+            extract_additional=validated_data.get("extract_additional", True)
         )
-        
-        # Add OCR results if available
+
+        # Attach OCR results if available
         if ocr_results:
-            result['ocr_results'] = ocr_results
+            result["ocr_results"] = ocr_results
 
         response_serializer = ProductAttributeResultSerializer(data=result)
         if response_serializer.is_valid():
             return Response(response_serializer.data, status=status.HTTP_200_OK)
-        
+
         return Response(result, status=status.HTTP_200_OK)
 
 
+# class BatchExtractProductAttributesView(APIView):
+#     """
+#     API endpoint to extract product attributes for multiple products in batch.
+#     Now supports image URLs for OCR-based text extraction.
+#     """
+
+#     def post(self, request):
+#         serializer = BatchProductRequestSerializer(data=request.data)
+#         if not serializer.is_valid():
+#             return Response(
+#                 {"error": serializer.errors},
+#                 status=status.HTTP_400_BAD_REQUEST
+#             )
+
+#         validated_data = serializer.validated_data
+
+#         # Extract attributes for all products in batch
+#         result = ProductAttributeService.extract_attributes_batch(
+#             products=validated_data['products'],
+#             mandatory_attrs=validated_data['mandatory_attrs'],
+#             model=validated_data.get('model'),
+#             extract_additional=validated_data.get('extract_additional', True),
+#             process_image=validated_data.get('process_image', True)
+#         )
+
+#         response_serializer = BatchProductResponseSerializer(data=result)
+#         if response_serializer.is_valid():
+#             return Response(response_serializer.data, status=status.HTTP_200_OK)
+        
+#         return Response(result, status=status.HTTP_200_OK)
+
+
+from .models import Product
+
 class BatchExtractProductAttributesView(APIView):
     """
-    API endpoint to extract product attributes for multiple products in batch.
-    Now supports image URLs for OCR-based text extraction.
+    API endpoint to extract product attributes for multiple products in batch by item_id.
+    Fetches all product details from database automatically.
     """
 
     def post(self, request):
         serializer = BatchProductRequestSerializer(data=request.data)
         if not serializer.is_valid():
-            return Response(
-                {"error": serializer.errors},
-                status=status.HTTP_400_BAD_REQUEST
-            )
+            return Response({"error": serializer.errors}, status=status.HTTP_400_BAD_REQUEST)
 
         validated_data = serializer.validated_data
+        item_ids = validated_data.get("item_ids", [])
+        model = validated_data.get("model")
+        extract_additional = validated_data.get("extract_additional", True)
+        process_image = validated_data.get("process_image", True)
+        mandatory_attrs = validated_data["mandatory_attrs"]
+
+        # Fetch all products in one query
+        products = Product.objects.filter(item_id__in=item_ids)
+        found_ids = set(products.values_list("item_id", flat=True))
+        missing_ids = [pid for pid in item_ids if pid not in found_ids]
+
+        results = []
+        successful = 0
+        failed = 0
+
+        for product in products:
+            try:
+                title = product.product_name
+                short_desc = product.product_short_description
+                long_desc = product.product_long_description
+                image_url = product.image_path
+
+                ocr_results = None
+                ocr_text = None
+
+                if process_image and image_url:
+                    ocr_service = OCRService()
+                    ocr_results = ocr_service.process_image(image_url)
+
+                    if ocr_results and ocr_results.get("detected_text"):
+                        ocr_attrs = ProductAttributeService.extract_attributes_from_ocr(
+                            ocr_results, model
+                        )
+                        ocr_results["extracted_attributes"] = ocr_attrs
+                        ocr_text = "\n".join([
+                            f"{item['text']} (confidence: {item['confidence']:.2f})"
+                            for item in ocr_results["detected_text"]
+                        ])
+
+                product_text = ProductAttributeService.combine_product_text(
+                    title=title,
+                    short_desc=short_desc,
+                    long_desc=long_desc,
+                    ocr_text=ocr_text
+                )
 
-        # Extract attributes for all products in batch
-        result = ProductAttributeService.extract_attributes_batch(
-            products=validated_data['products'],
-            mandatory_attrs=validated_data['mandatory_attrs'],
-            model=validated_data.get('model'),
-            extract_additional=validated_data.get('extract_additional', True),
-            process_image=validated_data.get('process_image', True)
-        )
+                extracted = ProductAttributeService.extract_attributes(
+                    product_text=product_text,
+                    mandatory_attrs=mandatory_attrs,
+                    model=model,
+                    extract_additional=extract_additional
+                )
 
-        response_serializer = BatchProductResponseSerializer(data=result)
+                result = {
+                    "product_id": product.item_id,
+                    "mandatory": extracted.get("mandatory", {}),
+                    "additional": extracted.get("additional", {}),
+                }
+
+                if ocr_results:
+                    result["ocr_results"] = ocr_results
+
+                results.append(result)
+                successful += 1
+
+            except Exception as e:
+                failed += 1
+                results.append({
+                    "product_id": product.item_id,
+                    "error": str(e)
+                })
+
+        # Add missing item_ids as failed entries
+        for mid in missing_ids:
+            failed += 1
+            results.append({
+                "product_id": mid,
+                "error": "Product not found in database"
+            })
+
+        batch_result = {
+            "results": results,
+            "total_products": len(item_ids),
+            "successful": successful,
+            "failed": failed
+        }
+
+        response_serializer = BatchProductResponseSerializer(data=batch_result)
         if response_serializer.is_valid():
             return Response(response_serializer.data, status=status.HTTP_200_OK)
-        
-        return Response(result, status=status.HTTP_200_OK)
-
 
+        return Response(batch_result, status=status.HTTP_200_OK)