||
- from rest_framework.views import APIView
- from rest_framework.response import Response
- from rest_framework import status
- from rest_framework.parsers import MultiPartParser, FormParser
- from django.db import transaction
- import pandas as pd
- from .models import Product, ProductType, ProductAttribute, AttributePossibleValue
- from .serializers import (
- SingleProductRequestSerializer,
- BatchProductRequestSerializer,
- ProductAttributeResultSerializer,
- BatchProductResponseSerializer,
- ProductSerializer,
- ProductTypeSerializer,
- ProductAttributeSerializer,
- AttributePossibleValueSerializer
- )
- from .services import ProductAttributeService
- from .ocr_service import OCRService
- # Sample test images (publicly available)
- SAMPLE_IMAGES = {
- "tshirt": "https://images.unsplash.com/photo-1521572163474-6864f9cf17ab",
- "dress": "https://images.unsplash.com/photo-1595777457583-95e059d581b8",
- "jeans": "https://images.unsplash.com/photo-1542272604-787c3835535d"
- }
- # ==================== Updated views.py ====================
- from rest_framework.views import APIView
- from rest_framework.response import Response
- from rest_framework import status
- from .models import Product
- from .services import ProductAttributeService
- from .ocr_service import OCRService
- from .visual_processing_service import VisualProcessingService
- class ExtractProductAttributesView(APIView):
- """
- API endpoint to extract product attributes for a single product by item_id.
- Fetches product details from database with source tracking.
- Returns attributes in array format: [{"value": "...", "source": "..."}]
- Includes OCR and Visual Processing results.
- """
- def post(self, request):
- serializer = SingleProductRequestSerializer(data=request.data)
- if not serializer.is_valid():
- return Response({"error": serializer.errors}, status=status.HTTP_400_BAD_REQUEST)
- validated_data = serializer.validated_data
- item_id = validated_data.get("item_id")
- # Fetch product from DB
- try:
- product = Product.objects.get(item_id=item_id)
- except Product.DoesNotExist:
- return Response(
- {"error": f"Product with item_id '{item_id}' not found."},
- status=status.HTTP_404_NOT_FOUND
- )
- # Extract product details
- title = product.product_name
- short_desc = product.product_short_description
- long_desc = product.product_long_description
- image_url = product.image_path
- # Process image for OCR if required
- ocr_results = None
- ocr_text = None
- visual_results = None
- if validated_data.get("process_image", True) and image_url:
- # OCR Processing
- ocr_service = OCRService()
- ocr_results = ocr_service.process_image(image_url)
- if ocr_results and ocr_results.get("detected_text"):
- ocr_attrs = ProductAttributeService.extract_attributes_from_ocr(
- ocr_results, validated_data.get("model")
- )
- ocr_results["extracted_attributes"] = ocr_attrs
- ocr_text = "\n".join([
- f"{item['text']} (confidence: {item['confidence']:.2f})"
- for item in ocr_results["detected_text"]
- ])
-
- # Visual Processing
- visual_service = VisualProcessingService()
- product_type_hint = product.product_type if hasattr(product, 'product_type') else None
- visual_results = visual_service.process_image(image_url, product_type_hint)
- # Combine all product text with source tracking
- product_text, source_map = ProductAttributeService.combine_product_text(
- title=title,
- short_desc=short_desc,
- long_desc=long_desc,
- ocr_text=ocr_text
- )
- # Extract attributes with enhanced features and source tracking
- result = ProductAttributeService.extract_attributes(
- product_text=product_text,
- mandatory_attrs=validated_data["mandatory_attrs"],
- source_map=source_map,
- model=validated_data.get("model"),
- extract_additional=validated_data.get("extract_additional", True),
- multiple=validated_data.get("multiple", []),
- threshold_abs=validated_data.get("threshold_abs", 0.65),
- margin=validated_data.get("margin", 0.15),
- use_dynamic_thresholds=validated_data.get("use_dynamic_thresholds", True),
- use_adaptive_margin=validated_data.get("use_adaptive_margin", True),
- use_semantic_clustering=validated_data.get("use_semantic_clustering", True)
- )
- # Attach OCR results if available
- if ocr_results:
- result["ocr_results"] = ocr_results
-
- # Attach Visual Processing results if available
- if visual_results:
- result["visual_results"] = visual_results
- response_serializer = ProductAttributeResultSerializer(data=result)
- if response_serializer.is_valid():
- return Response(response_serializer.data, status=status.HTTP_200_OK)
- return Response(result, status=status.HTTP_200_OK)
- # class BatchExtractProductAttributesView(APIView):
- # """
- # API endpoint to extract product attributes for multiple products in batch.
- # Uses item-specific mandatory_attrs with source tracking.
- # Returns attributes in array format: [{"value": "...", "source": "..."}]
- # Includes OCR and Visual Processing results.
- # """
- # def post(self, request):
- # serializer = BatchProductRequestSerializer(data=request.data)
- # if not serializer.is_valid():
- # return Response({"error": serializer.errors}, status=status.HTTP_400_BAD_REQUEST)
- # validated_data = serializer.validated_data
-
- # # DEBUG: Print what we received
- # print("\n" + "="*80)
- # print("BATCH REQUEST - RECEIVED DATA")
- # print("="*80)
- # print(f"Raw request data keys: {request.data.keys()}")
- # print(f"Multiple field in request: {request.data.get('multiple')}")
- # print(f"Validated multiple field: {validated_data.get('multiple')}")
- # print("="*80 + "\n")
-
- # # Get batch-level settings
- # product_list = validated_data.get("products", [])
- # model = validated_data.get("model")
- # extract_additional = validated_data.get("extract_additional", True)
- # process_image = validated_data.get("process_image", True)
- # multiple = validated_data.get("multiple", [])
- # threshold_abs = validated_data.get("threshold_abs", 0.65)
- # margin = validated_data.get("margin", 0.15)
- # use_dynamic_thresholds = validated_data.get("use_dynamic_thresholds", True)
- # use_adaptive_margin = validated_data.get("use_adaptive_margin", True)
- # use_semantic_clustering = validated_data.get("use_semantic_clustering", True)
-
- # # DEBUG: Print extracted settings
- # print(f"Extracted multiple parameter: {multiple}")
- # print(f"Type: {type(multiple)}")
-
- # # Extract all item_ids to query the database efficiently
- # item_ids = [p['item_id'] for p in product_list]
-
- # # Fetch all products in one query
- # products_queryset = Product.objects.filter(item_id__in=item_ids)
-
- # # Create a dictionary for easy lookup: item_id -> Product object
- # product_map = {product.item_id: product for product in products_queryset}
- # found_ids = set(product_map.keys())
-
- # results = []
- # successful = 0
- # failed = 0
- # for product_entry in product_list:
- # item_id = product_entry['item_id']
- # # Get item-specific mandatory attributes
- # mandatory_attrs = product_entry['mandatory_attrs']
- # if item_id not in found_ids:
- # failed += 1
- # results.append({
- # "product_id": item_id,
- # "error": "Product not found in database"
- # })
- # continue
- # product = product_map[item_id]
-
- # try:
- # title = product.product_name
- # short_desc = product.product_short_description
- # long_desc = product.product_long_description
- # image_url = product.image_path
- # # image_url = "https://images.unsplash.com/photo-1595777457583-95e059d581b8"
- # ocr_results = None
- # ocr_text = None
- # visual_results = None
- # # Image Processing Logic
- # if process_image and image_url:
- # # OCR Processing
- # ocr_service = OCRService()
- # ocr_results = ocr_service.process_image(image_url)
- # print(f"OCR results for {item_id}: {ocr_results}")
-
- # if ocr_results and ocr_results.get("detected_text"):
- # ocr_attrs = ProductAttributeService.extract_attributes_from_ocr(
- # ocr_results, model
- # )
- # ocr_results["extracted_attributes"] = ocr_attrs
- # ocr_text = "\n".join([
- # f"{item['text']} (confidence: {item['confidence']:.2f})"
- # for item in ocr_results["detected_text"]
- # ])
-
- # # Visual Processing
- # visual_service = VisualProcessingService()
- # product_type_hint = product.product_type if hasattr(product, 'product_type') else None
- # visual_results = visual_service.process_image(image_url, product_type_hint)
- # print(f"Visual results for {item_id}: {visual_results.get('visual_attributes', {})}")
- # # Combine product text with source tracking
- # product_text, source_map = ProductAttributeService.combine_product_text(
- # title=title,
- # short_desc=short_desc,
- # long_desc=long_desc,
- # ocr_text=ocr_text
- # )
- # # DEBUG: Print before extraction
- # print(f"\n>>> Extracting for product {item_id}")
- # print(f" Passing multiple: {multiple}")
- # # Attribute Extraction with source tracking (returns array format)
- # extracted = ProductAttributeService.extract_attributes(
- # product_text=product_text,
- # mandatory_attrs=mandatory_attrs,
- # source_map=source_map,
- # model=model,
- # extract_additional=extract_additional,
- # multiple=multiple,
- # threshold_abs=threshold_abs,
- # margin=margin,
- # use_dynamic_thresholds=use_dynamic_thresholds,
- # use_adaptive_margin=use_adaptive_margin,
- # use_semantic_clustering=use_semantic_clustering
- # )
- # result = {
- # "product_id": product.item_id,
- # "mandatory": extracted.get("mandatory", {}),
- # "additional": extracted.get("additional", {}),
- # }
- # # Attach OCR results if available
- # if ocr_results:
- # result["ocr_results"] = ocr_results
-
- # # Attach Visual Processing results if available
- # if visual_results:
- # result["visual_results"] = visual_results
- # results.append(result)
- # successful += 1
- # except Exception as e:
- # failed += 1
- # results.append({
- # "product_id": item_id,
- # "error": str(e)
- # })
- # batch_result = {
- # "results": results,
- # "total_products": len(product_list),
- # "successful": successful,
- # "failed": failed
- # }
- # response_serializer = BatchProductResponseSerializer(data=batch_result)
- # if response_serializer.is_valid():
- # return Response(response_serializer.data, status=status.HTTP_200_OK)
- # return Response(batch_result, status=status.HTTP_200_OK)
- class BatchExtractProductAttributesView(APIView):
- """
- API endpoint to extract product attributes for multiple products in batch.
- Uses item-specific mandatory_attrs with source tracking.
- Returns attributes in array format: [{"value": "...", "source": "..."}]
- Includes OCR and Visual Processing results.
- """
- def post(self, request):
- serializer = BatchProductRequestSerializer(data=request.data)
- if not serializer.is_valid():
- return Response({"error": serializer.errors}, status=status.HTTP_400_BAD_REQUEST)
- validated_data = serializer.validated_data
-
- # DEBUG: Print what we received
- print("\n" + "="*80)
- print("BATCH REQUEST - RECEIVED DATA")
- print("="*80)
- print(f"Raw request data keys: {request.data.keys()}")
- print(f"Multiple field in request: {request.data.get('multiple')}")
- print(f"Validated multiple field: {validated_data.get('multiple')}")
- print("="*80 + "\n")
-
- # Get batch-level settings
- product_list = validated_data.get("products", [])
- model = validated_data.get("model")
- extract_additional = validated_data.get("extract_additional", True)
- process_image = validated_data.get("process_image", True)
- multiple = validated_data.get("multiple", [])
- threshold_abs = validated_data.get("threshold_abs", 0.65)
- margin = validated_data.get("margin", 0.15)
- use_dynamic_thresholds = validated_data.get("use_dynamic_thresholds", True)
- use_adaptive_margin = validated_data.get("use_adaptive_margin", True)
- use_semantic_clustering = validated_data.get("use_semantic_clustering", True)
-
- # DEBUG: Print extracted settings
- print(f"Extracted multiple parameter: {multiple}")
- print(f"Type: {type(multiple)}")
-
- # Extract all item_ids to query the database efficiently
- item_ids = [p['item_id'] for p in product_list]
-
- # Fetch all products in one query
- products_queryset = Product.objects.filter(item_id__in=item_ids)
-
- # Create a dictionary for easy lookup: item_id -> Product object
- product_map = {product.item_id: product for product in products_queryset}
- found_ids = set(product_map.keys())
-
- results = []
- successful = 0
- failed = 0
- for product_entry in product_list:
- item_id = product_entry['item_id']
- # Get item-specific mandatory attributes
- mandatory_attrs = product_entry['mandatory_attrs']
- if item_id not in found_ids:
- failed += 1
- results.append({
- "product_id": item_id,
- "error": "Product not found in database"
- })
- continue
- product = product_map[item_id]
-
- try:
- title = product.product_name
- short_desc = product.product_short_description
- long_desc = product.product_long_description
- image_url = product.image_path
- # image_url = "https://images.unsplash.com/photo-1595777457583-95e059d581b8"
- ocr_results = None
- ocr_text = None
- visual_results = None
- # Image Processing Logic
- if process_image and image_url:
- # OCR Processing
- ocr_service = OCRService()
- ocr_results = ocr_service.process_image(image_url)
- print(f"OCR results for {item_id}: {ocr_results}")
-
- if ocr_results and ocr_results.get("detected_text"):
- ocr_attrs = ProductAttributeService.extract_attributes_from_ocr(
- ocr_results, model
- )
- ocr_results["extracted_attributes"] = ocr_attrs
- ocr_text = "\n".join([
- f"{item['text']} (confidence: {item['confidence']:.2f})"
- for item in ocr_results["detected_text"]
- ])
-
- # Visual Processing
- visual_service = VisualProcessingService()
- product_type_hint = product.product_type if hasattr(product, 'product_type') else None
- visual_results = visual_service.process_image(image_url, product_type_hint)
- print(f"Visual results for {item_id}: {visual_results.get('visual_attributes', {})}")
-
- # Format visual attributes to array format with source tracking
- if visual_results and visual_results.get('visual_attributes'):
- visual_results['visual_attributes'] = ProductAttributeService.format_visual_attributes(
- visual_results['visual_attributes']
- )
- # Combine product text with source tracking
- product_text, source_map = ProductAttributeService.combine_product_text(
- title=title,
- short_desc=short_desc,
- long_desc=long_desc,
- ocr_text=ocr_text
- )
- # DEBUG: Print before extraction
- print(f"\n>>> Extracting for product {item_id}")
- print(f" Passing multiple: {multiple}")
- # Attribute Extraction with source tracking (returns array format)
- extracted = ProductAttributeService.extract_attributes(
- product_text=product_text,
- mandatory_attrs=mandatory_attrs,
- source_map=source_map,
- model=model,
- extract_additional=extract_additional,
- multiple=multiple,
- threshold_abs=threshold_abs,
- margin=margin,
- use_dynamic_thresholds=use_dynamic_thresholds,
- use_adaptive_margin=use_adaptive_margin,
- use_semantic_clustering=use_semantic_clustering
- )
- result = {
- "product_id": product.item_id,
- "mandatory": extracted.get("mandatory", {}),
- "additional": extracted.get("additional", {}),
- }
- # Attach OCR results if available
- if ocr_results:
- result["ocr_results"] = ocr_results
-
- # Attach Visual Processing results if available
- if visual_results:
- result["visual_results"] = visual_results
- results.append(result)
- successful += 1
- except Exception as e:
- failed += 1
- results.append({
- "product_id": item_id,
- "error": str(e)
- })
- batch_result = {
- "results": results,
- "total_products": len(product_list),
- "successful": successful,
- "failed": failed
- }
- response_serializer = BatchProductResponseSerializer(data=batch_result)
- if response_serializer.is_valid():
- return Response(response_serializer.data, status=status.HTTP_200_OK)
- return Response(batch_result, status=status.HTTP_200_OK)
- # class ExtractProductAttributesView(APIView):
- # """
- # API endpoint to extract product attributes for a single product by item_id.
- # Fetches product details from database with source tracking.
- # Returns attributes in array format: [{"value": "...", "source": "..."}]
- # """
- # def post(self, request):
- # serializer = SingleProductRequestSerializer(data=request.data)
- # if not serializer.is_valid():
- # return Response({"error": serializer.errors}, status=status.HTTP_400_BAD_REQUEST)
- # validated_data = serializer.validated_data
- # item_id = validated_data.get("item_id")
- # # Fetch product from DB
- # try:
- # product = Product.objects.get(item_id=item_id)
- # except Product.DoesNotExist:
- # return Response(
- # {"error": f"Product with item_id '{item_id}' not found."},
- # status=status.HTTP_404_NOT_FOUND
- # )
- # # Extract product details
- # title = product.product_name
- # short_desc = product.product_short_description
- # long_desc = product.product_long_description
- # image_url = product.image_path
- # # Process image for OCR if required
- # ocr_results = None
- # ocr_text = None
- # if validated_data.get("process_image", True) and image_url:
- # ocr_service = OCRService()
- # ocr_results = ocr_service.process_image(image_url)
- # if ocr_results and ocr_results.get("detected_text"):
- # ocr_attrs = ProductAttributeService.extract_attributes_from_ocr(
- # ocr_results, validated_data.get("model")
- # )
- # ocr_results["extracted_attributes"] = ocr_attrs
- # ocr_text = "\n".join([
- # f"{item['text']} (confidence: {item['confidence']:.2f})"
- # for item in ocr_results["detected_text"]
- # ])
- # # Combine all product text with source tracking
- # product_text, source_map = ProductAttributeService.combine_product_text(
- # title=title,
- # short_desc=short_desc,
- # long_desc=long_desc,
- # ocr_text=ocr_text
- # )
- # # Extract attributes with enhanced features and source tracking
- # result = ProductAttributeService.extract_attributes(
- # product_text=product_text,
- # mandatory_attrs=validated_data["mandatory_attrs"],
- # source_map=source_map,
- # model=validated_data.get("model"),
- # extract_additional=validated_data.get("extract_additional", True),
- # multiple=validated_data.get("multiple", []),
- # threshold_abs=validated_data.get("threshold_abs", 0.65),
- # margin=validated_data.get("margin", 0.15),
- # use_dynamic_thresholds=validated_data.get("use_dynamic_thresholds", True),
- # use_adaptive_margin=validated_data.get("use_adaptive_margin", True),
- # use_semantic_clustering=validated_data.get("use_semantic_clustering", True)
- # )
- # # Attach OCR results if available
- # if ocr_results:
- # result["ocr_results"] = ocr_results
- # response_serializer = ProductAttributeResultSerializer(data=result)
- # if response_serializer.is_valid():
- # return Response(response_serializer.data, status=status.HTTP_200_OK)
- # return Response(result, status=status.HTTP_200_OK)
- # class BatchExtractProductAttributesView(APIView):
- # """
- # API endpoint to extract product attributes for multiple products in batch.
- # Uses item-specific mandatory_attrs with source tracking.
- # Returns attributes in array format: [{"value": "...", "source": "..."}]
- # """
- # def post(self, request):
- # serializer = BatchProductRequestSerializer(data=request.data)
- # if not serializer.is_valid():
- # return Response({"error": serializer.errors}, status=status.HTTP_400_BAD_REQUEST)
- # validated_data = serializer.validated_data
-
- # # DEBUG: Print what we received
- # print("\n" + "="*80)
- # print("BATCH REQUEST - RECEIVED DATA")
- # print("="*80)
- # print(f"Raw request data keys: {request.data.keys()}")
- # print(f"Multiple field in request: {request.data.get('multiple')}")
- # print(f"Validated multiple field: {validated_data.get('multiple')}")
- # print("="*80 + "\n")
-
- # # Get batch-level settings
- # product_list = validated_data.get("products", [])
- # model = validated_data.get("model")
- # extract_additional = validated_data.get("extract_additional", True)
- # process_image = validated_data.get("process_image", True)
- # multiple = validated_data.get("multiple", [])
- # threshold_abs = validated_data.get("threshold_abs", 0.65)
- # margin = validated_data.get("margin", 0.15)
- # use_dynamic_thresholds = validated_data.get("use_dynamic_thresholds", True)
- # use_adaptive_margin = validated_data.get("use_adaptive_margin", True)
- # use_semantic_clustering = validated_data.get("use_semantic_clustering", True)
-
- # # DEBUG: Print extracted settings
- # print(f"Extracted multiple parameter: {multiple}")
- # print(f"Type: {type(multiple)}")
-
- # # Extract all item_ids to query the database efficiently
- # item_ids = [p['item_id'] for p in product_list]
-
- # # Fetch all products in one query
- # products_queryset = Product.objects.filter(item_id__in=item_ids)
-
- # # Create a dictionary for easy lookup: item_id -> Product object
- # product_map = {product.item_id: product for product in products_queryset}
- # found_ids = set(product_map.keys())
-
- # results = []
- # successful = 0
- # failed = 0
- # for product_entry in product_list:
- # item_id = product_entry['item_id']
- # # Get item-specific mandatory attributes
- # mandatory_attrs = product_entry['mandatory_attrs']
- # if item_id not in found_ids:
- # failed += 1
- # results.append({
- # "product_id": item_id,
- # "error": "Product not found in database"
- # })
- # continue
- # product = product_map[item_id]
-
- # try:
- # title = product.product_name
- # short_desc = product.product_short_description
- # long_desc = product.product_long_description
- # # image_url = product.image_path
- # image_url = "http://localhost:8000/media/products/levi_test_ocr2.jpg"
- # ocr_results = None
- # ocr_text = None
- # # Image Processing Logic
- # if process_image and image_url:
- # ocr_service = OCRService()
- # ocr_results = ocr_service.process_image(image_url)
- # print(f"ocr results are: {ocr_results}")
- # if ocr_results and ocr_results.get("detected_text"):
- # ocr_attrs = ProductAttributeService.extract_attributes_from_ocr(
- # ocr_results, model
- # )
- # ocr_results["extracted_attributes"] = ocr_attrs
- # ocr_text = "\n".join([
- # f"{item['text']} (confidence: {item['confidence']:.2f})"
- # for item in ocr_results["detected_text"]
- # ])
- # # Combine product text with source tracking
- # product_text, source_map = ProductAttributeService.combine_product_text(
- # title=title,
- # short_desc=short_desc,
- # long_desc=long_desc,
- # ocr_text=ocr_text
- # )
- # # DEBUG: Print before extraction
- # print(f"\n>>> Extracting for product {item_id}")
- # print(f" Passing multiple: {multiple}")
- # # Attribute Extraction with source tracking (returns array format)
- # extracted = ProductAttributeService.extract_attributes(
- # product_text=product_text,
- # mandatory_attrs=mandatory_attrs,
- # source_map=source_map,
- # model=model,
- # extract_additional=extract_additional,
- # multiple=multiple, # Make sure this is passed!
- # threshold_abs=threshold_abs,
- # margin=margin,
- # use_dynamic_thresholds=use_dynamic_thresholds,
- # use_adaptive_margin=use_adaptive_margin,
- # use_semantic_clustering=use_semantic_clustering
- # )
- # result = {
- # "product_id": product.item_id,
- # "mandatory": extracted.get("mandatory", {}),
- # "additional": extracted.get("additional", {}),
- # }
- # if ocr_results:
- # result["ocr_results"] = ocr_results
- # results.append(result)
- # successful += 1
- # except Exception as e:
- # failed += 1
- # results.append({
- # "product_id": item_id,
- # "error": str(e)
- # })
- # batch_result = {
- # "results": results,
- # "total_products": len(product_list),
- # "successful": successful,
- # "failed": failed
- # }
- # response_serializer = BatchProductResponseSerializer(data=batch_result)
- # if response_serializer.is_valid():
- # return Response(response_serializer.data, status=status.HTTP_200_OK)
- # return Response(batch_result, status=status.HTTP_200_OK)
- class ProductListView(APIView):
- """
- GET API to list all products with details
- """
- def get(self, request):
- products = Product.objects.all()
- serializer = ProductSerializer(products, many=True)
- return Response(serializer.data, status=status.HTTP_200_OK)
- # class ProductUploadExcelView(APIView):
- # """
- # POST API to upload an Excel file and add data to Product model (skip duplicates)
- # """
- # parser_classes = (MultiPartParser, FormParser)
- # def post(self, request, *args, **kwargs):
- # file_obj = request.FILES.get('file')
- # if not file_obj:
- # return Response({'error': 'No file provided'}, status=status.HTTP_400_BAD_REQUEST)
- # try:
- # df = pd.read_excel(file_obj)
- # df.columns = [c.strip().lower().replace(' ', '_') for c in df.columns]
- # expected_cols = {
- # 'item_id',
- # 'product_name',
- # 'product_long_description',
- # 'product_short_description',
- # 'product_type',
- # 'image_path'
- # }
- # if not expected_cols.issubset(df.columns):
- # return Response({
- # 'error': 'Missing required columns',
- # 'required_columns': list(expected_cols)
- # }, status=status.HTTP_400_BAD_REQUEST)
- # created_count = 0
- # skipped_count = 0
- # for _, row in df.iterrows():
- # item_id = row.get('item_id', '')
- # # Check if this item already exists
- # if Product.objects.filter(item_id=item_id).exists():
- # skipped_count += 1
- # continue
- # Product.objects.create(
- # item_id=item_id,
- # product_name=row.get('product_name', ''),
- # product_long_description=row.get('product_long_description', ''),
- # product_short_description=row.get('product_short_description', ''),
- # product_type=row.get('product_type', ''),
- # image_path=row.get('image_path', ''),
- # )
- # created_count += 1
- # return Response({
- # 'message': f'Successfully uploaded {created_count} products.',
- # 'skipped': f'Skipped {skipped_count} duplicates.'
- # }, status=status.HTTP_201_CREATED)
- # except Exception as e:
- # return Response({'error': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
- from rest_framework.views import APIView
- from rest_framework.response import Response
- from rest_framework import status
- from rest_framework.parsers import MultiPartParser, FormParser
- import pandas as pd
- from .models import Product
- class ProductUploadExcelView(APIView):
- """
- POST API to upload an Excel file and add/update data in Product model.
- - Creates new records if they don't exist.
- - Updates existing ones (e.g., when image_path or other fields change).
- """
- parser_classes = (MultiPartParser, FormParser)
- def post(self, request, *args, **kwargs):
- file_obj = request.FILES.get('file')
- if not file_obj:
- return Response({'error': 'No file provided'}, status=status.HTTP_400_BAD_REQUEST)
- try:
- # Read Excel into DataFrame
- df = pd.read_excel(file_obj)
- df.columns = [c.strip().lower().replace(' ', '_') for c in df.columns]
- expected_cols = {
- 'item_id',
- 'product_name',
- 'product_long_description',
- 'product_short_description',
- 'product_type',
- 'image_path'
- }
- # Check required columns
- if not expected_cols.issubset(df.columns):
- return Response({
- 'error': 'Missing required columns',
- 'required_columns': list(expected_cols)
- }, status=status.HTTP_400_BAD_REQUEST)
- created_count = 0
- updated_count = 0
- # Loop through rows and update or create
- for _, row in df.iterrows():
- item_id = str(row.get('item_id', '')).strip()
- if not item_id:
- continue # Skip rows without an item_id
- defaults = {
- 'product_name': row.get('product_name', ''),
- 'product_long_description': row.get('product_long_description', ''),
- 'product_short_description': row.get('product_short_description', ''),
- 'product_type': row.get('product_type', ''),
- 'image_path': row.get('image_path', ''),
- }
- obj, created = Product.objects.update_or_create(
- item_id=item_id,
- defaults=defaults
- )
- if created:
- created_count += 1
- else:
- updated_count += 1
- return Response({
- 'message': f'Upload successful.',
- 'created': f'{created_count} new records added.',
- 'updated': f'{updated_count} existing records updated.'
- }, status=status.HTTP_201_CREATED)
- except Exception as e:
- return Response({'error': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
- class ProductAttributesUploadView(APIView):
- """
- POST API to upload an Excel file and add mandatory/additional attributes
- for product types with possible values.
- """
- parser_classes = (MultiPartParser, FormParser)
- def post(self, request):
- file_obj = request.FILES.get('file')
- if not file_obj:
- return Response({"error": "No file provided."}, status=status.HTTP_400_BAD_REQUEST)
- try:
- df = pd.read_excel(file_obj)
- required_columns = {'product_type', 'attribute_name', 'is_mandatory', 'possible_values'}
- if not required_columns.issubset(df.columns):
- return Response({
- "error": f"Missing required columns. Found: {list(df.columns)}"
- }, status=status.HTTP_400_BAD_REQUEST)
- for _, row in df.iterrows():
- product_type_name = str(row['product_type']).strip()
- attr_name = str(row['attribute_name']).strip()
- is_mandatory = str(row['is_mandatory']).strip().lower() in ['yes', 'true', '1']
- possible_values = str(row.get('possible_values', '')).strip()
- # Get or create product type
- product_type, _ = ProductType.objects.get_or_create(name=product_type_name)
- # Get or create attribute
- attribute, _ = ProductAttribute.objects.get_or_create(
- product_type=product_type,
- name=attr_name,
- defaults={'is_mandatory': is_mandatory}
- )
- attribute.is_mandatory = is_mandatory
- attribute.save()
- # Handle possible values
- AttributePossibleValue.objects.filter(attribute=attribute).delete()
- if possible_values:
- for val in [v.strip() for v in possible_values.split(',') if v.strip()]:
- AttributePossibleValue.objects.create(attribute=attribute, value=val)
- return Response({"message": "Attributes uploaded successfully."}, status=status.HTTP_201_CREATED)
- except Exception as e:
- return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
- class ProductTypeAttributesView(APIView):
- """
- API to view, create, update, and delete product type attributes and their possible values.
- Also supports dynamic product type creation.
- """
- def get(self, request):
- """
- Retrieve all product types with their attributes and possible values.
- """
- product_types = ProductType.objects.all()
- serializer = ProductTypeSerializer(product_types, many=True)
-
- # Transform the serialized data into the requested format
- result = []
- for pt in serializer.data:
- for attr in pt['attributes']:
- result.append({
- 'product_type': pt['name'],
- 'attribute_name': attr['name'],
- 'is_mandatory': 'Yes' if attr['is_mandatory'] else 'No',
- 'possible_values': ', '.join([pv['value'] for pv in attr['possible_values']])
- })
-
- return Response(result, status=status.HTTP_200_OK)
- def post(self, request):
- """
- Create a new product type or attribute with possible values.
- Expected payload example:
- {
- "product_type": "Hardware Screws",
- "attribute_name": "Material",
- "is_mandatory": "Yes",
- "possible_values": "Steel, Zinc Plated, Stainless Steel"
- }
- """
- try:
- product_type_name = request.data.get('product_type')
- attribute_name = request.data.get('attribute_name', '')
- is_mandatory = request.data.get('is_mandatory', '').lower() in ['yes', 'true', '1']
- possible_values = request.data.get('possible_values', '')
- if not product_type_name:
- return Response({
- "error": "product_type is required"
- }, status=status.HTTP_400_BAD_REQUEST)
- with transaction.atomic():
- # Get or create product type
- product_type, created = ProductType.objects.get_or_create(name=product_type_name)
- if created and not attribute_name:
- return Response({
- "message": f"Product type '{product_type_name}' created successfully",
- "data": {"product_type": product_type_name}
- }, status=status.HTTP_201_CREATED)
- if attribute_name:
- # Create attribute
- attribute, attr_created = ProductAttribute.objects.get_or_create(
- product_type=product_type,
- name=attribute_name,
- defaults={'is_mandatory': is_mandatory}
- )
-
- if not attr_created:
- return Response({
- "error": f"Attribute '{attribute_name}' already exists for product type '{product_type_name}'"
- }, status=status.HTTP_400_BAD_REQUEST)
- # Handle possible values
- if possible_values:
- for val in [v.strip() for v in possible_values.split(',') if v.strip()]:
- AttributePossibleValue.objects.create(attribute=attribute, value=val)
- return Response({
- "message": "Attribute created successfully",
- "data": {
- "product_type": product_type_name,
- "attribute_name": attribute_name,
- "is_mandatory": "Yes" if is_mandatory else "No",
- "possible_values": possible_values
- }
- }, status=status.HTTP_201_CREATED)
- return Response({
- "message": f"Product type '{product_type_name}' already exists",
- "data": {"product_type": product_type_name}
- }, status=status.HTTP_200_OK)
- except Exception as e:
- return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
- def put(self, request):
- """
- Update an existing product type attribute and its possible values.
- Expected payload example:
- {
- "product_type": "Hardware Screws",
- "attribute_name": "Material",
- "is_mandatory": "Yes",
- "possible_values": "Steel, Zinc Plated, Stainless Steel, Brass"
- }
- """
- try:
- product_type_name = request.data.get('product_type')
- attribute_name = request.data.get('attribute_name')
- is_mandatory = request.data.get('is_mandatory', '').lower() in ['yes', 'true', '1']
- possible_values = request.data.get('possible_values', '')
- if not all([product_type_name, attribute_name]):
- return Response({
- "error": "product_type and attribute_name are required"
- }, status=status.HTTP_400_BAD_REQUEST)
- with transaction.atomic():
- try:
- product_type = ProductType.objects.get(name=product_type_name)
- attribute = ProductAttribute.objects.get(
- product_type=product_type,
- name=attribute_name
- )
- except ProductType.DoesNotExist:
- return Response({
- "error": f"Product type '{product_type_name}' not found"
- }, status=status.HTTP_404_NOT_FOUND)
- except ProductAttribute.DoesNotExist:
- return Response({
- "error": f"Attribute '{attribute_name}' not found for product type '{product_type_name}'"
- }, status=status.HTTP_404_NOT_FOUND)
- # Update attribute
- attribute.is_mandatory = is_mandatory
- attribute.save()
- # Update possible values
- AttributePossibleValue.objects.filter(attribute=attribute).delete()
- if possible_values:
- for val in [v.strip() for v in possible_values.split(',') if v.strip()]:
- AttributePossibleValue.objects.create(attribute=attribute, value=val)
- return Response({
- "message": "Attribute updated successfully",
- "data": {
- "product_type": product_type_name,
- "attribute_name": attribute_name,
- "is_mandatory": "Yes" if is_mandatory else "No",
- "possible_values": possible_values
- }
- }, status=status.HTTP_200_OK)
- except Exception as e:
- return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
- def delete(self, request):
- """
- Delete a product type or a specific attribute.
- Expected payload example:
- {
- "product_type": "Hardware Screws",
- "attribute_name": "Material"
- }
- """
- try:
- product_type_name = request.data.get('product_type')
- attribute_name = request.data.get('attribute_name', '')
- if not product_type_name:
- return Response({
- "error": "product_type is required"
- }, status=status.HTTP_400_BAD_REQUEST)
- with transaction.atomic():
- try:
- product_type = ProductType.objects.get(name=product_type_name)
- except ProductType.DoesNotExist:
- return Response({
- "error": f"Product type '{product_type_name}' not found"
- }, status=status.HTTP_404_NOT_FOUND)
- if attribute_name:
- # Delete specific attribute
- try:
- attribute = ProductAttribute.objects.get(
- product_type=product_type,
- name=attribute_name
- )
- attribute.delete()
- return Response({
- "message": f"Attribute '{attribute_name}' deleted successfully from product type '{product_type_name}'"
- }, status=status.HTTP_200_OK)
- except ProductAttribute.DoesNotExist:
- return Response({
- "error": f"Attribute '{attribute_name}' not found for product type '{product_type_name}'"
- }, status=status.HTTP_404_NOT_FOUND)
- else:
- # Delete entire product type
- product_type.delete()
- return Response({
- "message": f"Product type '{product_type_name}' and all its attributes deleted successfully"
- }, status=status.HTTP_200_OK)
- except Exception as e:
- return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
- class ProductTypeListView(APIView):
- """
- GET API to list all product types (only names).
- """
- def get(self, request):
- product_types = ProductType.objects.values_list('name', flat=True)
- return Response({"product_types": list(product_types)}, status=status.HTTP_200_OK)
|