views.py 36 KB


  1. from rest_framework.views import APIView
  2. from rest_framework.response import Response
  3. from rest_framework import status
  4. from rest_framework.parsers import MultiPartParser, FormParser
  5. from django.db import transaction
  6. import pandas as pd
  7. from .models import Product, ProductType, ProductAttribute, AttributePossibleValue
  8. from .serializers import (
  9. SingleProductRequestSerializer,
  10. BatchProductRequestSerializer,
  11. ProductAttributeResultSerializer,
  12. BatchProductResponseSerializer,
  13. ProductSerializer,
  14. ProductTypeSerializer,
  15. ProductAttributeSerializer,
  16. AttributePossibleValueSerializer
  17. )
  18. from .services import ProductAttributeService
  19. from .ocr_service import OCRService
  # Publicly available sample product photos (Unsplash), keyed by garment type.
  SAMPLE_IMAGES = dict(
      tshirt="https://images.unsplash.com/photo-1521572163474-6864f9cf17ab",
      dress="https://images.unsplash.com/photo-1595777457583-95e059d581b8",
      jeans="https://images.unsplash.com/photo-1542272604-787c3835535d",
  )
  26. # ==================== Updated views.py ====================
  27. from rest_framework.views import APIView
  28. from rest_framework.response import Response
  29. from rest_framework import status
  30. from .models import Product
  31. from .services import ProductAttributeService
  32. from .ocr_service import OCRService
  33. from .visual_processing_service import VisualProcessingService
  class ExtractProductAttributesView(APIView):
      """
      Extract attributes for a single product, looked up by ``item_id``.

      The product's name and short/long descriptions are combined (together with
      OCR text read from its image, when available) and handed to
      ``ProductAttributeService.extract_attributes`` with per-source tracking.
      When ``process_image`` is truthy (default True) and the product has an
      ``image_path``, OCR and visual processing run on that image and their
      results are attached to the response as ``ocr_results`` / ``visual_results``
      if non-empty.

      Attributes come back in array format: [{"value": "...", "source": "..."}]
      Responds 400 on an invalid request and 404 when the item_id is unknown.
      """

      @staticmethod
      def _ocr_lines_to_text(detections):
          """Render OCR detections as one ``text (confidence: x.xx)`` line each."""
          return "\n".join(
              f"{entry['text']} (confidence: {entry['confidence']:.2f})"
              for entry in detections
          )

      def _analyse_image(self, product, image_url, model):
          """
          Run OCR and then visual processing on ``image_url``.

          Returns ``(ocr_results, ocr_text, visual_results)``. ``ocr_text`` stays
          None unless the OCR payload contains non-empty ``detected_text``; in that
          case the payload is also annotated with ``extracted_attributes``.
          """
          ocr_text = None
          ocr_results = OCRService().process_image(image_url)
          if ocr_results and ocr_results.get("detected_text"):
              ocr_results["extracted_attributes"] = (
                  ProductAttributeService.extract_attributes_from_ocr(ocr_results, model)
              )
              ocr_text = self._ocr_lines_to_text(ocr_results["detected_text"])

          visual_service = VisualProcessingService()
          visual_results = visual_service.process_image(
              image_url, getattr(product, "product_type", None)
          )
          return ocr_results, ocr_text, visual_results

      def post(self, request):
          request_serializer = SingleProductRequestSerializer(data=request.data)
          if not request_serializer.is_valid():
              return Response({"error": request_serializer.errors}, status=status.HTTP_400_BAD_REQUEST)
          params = request_serializer.validated_data
          item_id = params.get("item_id")

          try:
              product = Product.objects.get(item_id=item_id)
          except Product.DoesNotExist:
              return Response(
                  {"error": f"Product with item_id '{item_id}' not found."},
                  status=status.HTTP_404_NOT_FOUND,
              )

          image_url = product.image_path
          if params.get("process_image", True) and image_url:
              ocr_results, ocr_text, visual_results = self._analyse_image(
                  product, image_url, params.get("model")
              )
          else:
              ocr_results = ocr_text = visual_results = None

          # Merge every text source; source_map lets the service attribute each value.
          product_text, source_map = ProductAttributeService.combine_product_text(
              title=product.product_name,
              short_desc=product.product_short_description,
              long_desc=product.product_long_description,
              ocr_text=ocr_text,
          )

          result = ProductAttributeService.extract_attributes(
              product_text=product_text,
              mandatory_attrs=params["mandatory_attrs"],
              source_map=source_map,
              model=params.get("model"),
              extract_additional=params.get("extract_additional", True),
              multiple=params.get("multiple", []),
              threshold_abs=params.get("threshold_abs", 0.65),
              margin=params.get("margin", 0.15),
              use_dynamic_thresholds=params.get("use_dynamic_thresholds", True),
              use_adaptive_margin=params.get("use_adaptive_margin", True),
              use_semantic_clustering=params.get("use_semantic_clustering", True),
          )

          # Only attach image-derived sections that actually produced something.
          for section, payload in (("ocr_results", ocr_results), ("visual_results", visual_results)):
              if payload:
                  result[section] = payload

          # Fall back to the raw dict when it does not fit the response serializer.
          response_serializer = ProductAttributeResultSerializer(data=result)
          body = response_serializer.data if response_serializer.is_valid() else result
          return Response(body, status=status.HTTP_200_OK)
  import logging

  logger = logging.getLogger(__name__)


  class BatchExtractProductAttributesView(APIView):
      """
      API endpoint to extract product attributes for multiple products in batch.

      Each entry of ``products`` supplies its own ``item_id`` and
      ``mandatory_attrs``; every other setting (model, ``multiple``, thresholds,
      margins, clustering flags, ``process_image``) applies to the whole batch.
      All requested products are fetched from the database in a single query.

      A product that is not in the database, or whose processing raises, is
      reported inline as ``{"product_id": ..., "error": ...}`` and counted in
      ``failed``; the remaining products are still processed.

      Returns attributes in array format: [{"value": "...", "source": "..."}]
      Includes OCR and Visual Processing results when they are non-empty.
      """

      @staticmethod
      def _extract_for_product(product, mandatory_attrs, process_image, extraction_options):
          """Run optional image processing plus attribute extraction for one product and return its result dict."""
          item_id = product.item_id
          image_url = product.image_path
          ocr_results = ocr_text = visual_results = None

          if process_image and image_url:
              # NOTE(review): both services are still constructed per product, as
              # before. If their constructors are expensive, consider one instance
              # per request -- confirm first that they keep no per-image state.
              ocr_results = OCRService().process_image(image_url)
              logger.debug("OCR results for %s: %s", item_id, ocr_results)
              if ocr_results and ocr_results.get("detected_text"):
                  ocr_results["extracted_attributes"] = (
                      ProductAttributeService.extract_attributes_from_ocr(
                          ocr_results, extraction_options["model"]
                      )
                  )
                  ocr_text = "\n".join(
                      f"{detection['text']} (confidence: {detection['confidence']:.2f})"
                      for detection in ocr_results["detected_text"]
                  )

              visual_results = VisualProcessingService().process_image(
                  image_url, getattr(product, "product_type", None)
              )
              # Guarded: an empty/None visual result must not turn a debug log
              # into an AttributeError that fails the whole product.
              if visual_results:
                  logger.debug(
                      "Visual results for %s: %s",
                      item_id,
                      visual_results.get("visual_attributes", {}),
                  )

          # Merge every text source; source_map lets the service attribute each value.
          product_text, source_map = ProductAttributeService.combine_product_text(
              title=product.product_name,
              short_desc=product.product_short_description,
              long_desc=product.product_long_description,
              ocr_text=ocr_text,
          )
          extracted = ProductAttributeService.extract_attributes(
              product_text=product_text,
              mandatory_attrs=mandatory_attrs,
              source_map=source_map,
              **extraction_options,
          )

          result = {
              "product_id": item_id,
              "mandatory": extracted.get("mandatory", {}),
              "additional": extracted.get("additional", {}),
          }
          if ocr_results:
              result["ocr_results"] = ocr_results
          if visual_results:
              result["visual_results"] = visual_results
          return result

      def post(self, request):
          serializer = BatchProductRequestSerializer(data=request.data)
          if not serializer.is_valid():
              return Response({"error": serializer.errors}, status=status.HTTP_400_BAD_REQUEST)
          validated_data = serializer.validated_data

          # Batch-level settings shared by every product.
          product_list = validated_data.get("products", [])
          process_image = validated_data.get("process_image", True)
          extraction_options = {
              "model": validated_data.get("model"),
              "extract_additional": validated_data.get("extract_additional", True),
              "multiple": validated_data.get("multiple", []),
              "threshold_abs": validated_data.get("threshold_abs", 0.65),
              "margin": validated_data.get("margin", 0.15),
              "use_dynamic_thresholds": validated_data.get("use_dynamic_thresholds", True),
              "use_adaptive_margin": validated_data.get("use_adaptive_margin", True),
              "use_semantic_clustering": validated_data.get("use_semantic_clustering", True),
          }
          logger.debug(
              "Batch extraction request: keys=%s, multiple (raw)=%r, multiple (used)=%r",
              list(request.data.keys()),
              request.data.get("multiple"),
              extraction_options["multiple"],
          )

          # One query for all products, then O(1) lookup per entry.
          item_ids = [entry["item_id"] for entry in product_list]
          product_map = {
              product.item_id: product
              for product in Product.objects.filter(item_id__in=item_ids)
          }

          results = []
          successful = 0
          failed = 0
          for entry in product_list:
              item_id = entry["item_id"]
              product = product_map.get(item_id)
              if product is None:
                  failed += 1
                  results.append({
                      "product_id": item_id,
                      "error": "Product not found in database"
                  })
                  continue

              try:
                  result = self._extract_for_product(
                      product, entry["mandatory_attrs"], process_image, extraction_options
                  )
              except Exception as exc:
                  # Best-effort batch: one bad product must not sink the others,
                  # but keep the traceback in the server log instead of dropping it.
                  logger.exception("Attribute extraction failed for product %s", item_id)
                  failed += 1
                  results.append({
                      "product_id": item_id,
                      "error": str(exc)
                  })
              else:
                  results.append(result)
                  successful += 1

          batch_result = {
              "results": results,
              "total_products": len(product_list),
              "successful": successful,
              "failed": failed
          }
          # Fall back to the raw dict when it does not fit the response serializer.
          response_serializer = BatchProductResponseSerializer(data=batch_result)
          if response_serializer.is_valid():
              return Response(response_serializer.data, status=status.HTTP_200_OK)
          return Response(batch_result, status=status.HTTP_200_OK)
  250. # class ExtractProductAttributesView(APIView):
  251. # """
  252. # API endpoint to extract product attributes for a single product by item_id.
  253. # Fetches product details from database with source tracking.
  254. # Returns attributes in array format: [{"value": "...", "source": "..."}]
  255. # """
  256. # def post(self, request):
  257. # serializer = SingleProductRequestSerializer(data=request.data)
  258. # if not serializer.is_valid():
  259. # return Response({"error": serializer.errors}, status=status.HTTP_400_BAD_REQUEST)
  260. # validated_data = serializer.validated_data
  261. # item_id = validated_data.get("item_id")
  262. # # Fetch product from DB
  263. # try:
  264. # product = Product.objects.get(item_id=item_id)
  265. # except Product.DoesNotExist:
  266. # return Response(
  267. # {"error": f"Product with item_id '{item_id}' not found."},
  268. # status=status.HTTP_404_NOT_FOUND
  269. # )
  270. # # Extract product details
  271. # title = product.product_name
  272. # short_desc = product.product_short_description
  273. # long_desc = product.product_long_description
  274. # image_url = product.image_path
  275. # # Process image for OCR if required
  276. # ocr_results = None
  277. # ocr_text = None
  278. # if validated_data.get("process_image", True) and image_url:
  279. # ocr_service = OCRService()
  280. # ocr_results = ocr_service.process_image(image_url)
  281. # if ocr_results and ocr_results.get("detected_text"):
  282. # ocr_attrs = ProductAttributeService.extract_attributes_from_ocr(
  283. # ocr_results, validated_data.get("model")
  284. # )
  285. # ocr_results["extracted_attributes"] = ocr_attrs
  286. # ocr_text = "\n".join([
  287. # f"{item['text']} (confidence: {item['confidence']:.2f})"
  288. # for item in ocr_results["detected_text"]
  289. # ])
  290. # # Combine all product text with source tracking
  291. # product_text, source_map = ProductAttributeService.combine_product_text(
  292. # title=title,
  293. # short_desc=short_desc,
  294. # long_desc=long_desc,
  295. # ocr_text=ocr_text
  296. # )
  297. # # Extract attributes with enhanced features and source tracking
  298. # result = ProductAttributeService.extract_attributes(
  299. # product_text=product_text,
  300. # mandatory_attrs=validated_data["mandatory_attrs"],
  301. # source_map=source_map,
  302. # model=validated_data.get("model"),
  303. # extract_additional=validated_data.get("extract_additional", True),
  304. # multiple=validated_data.get("multiple", []),
  305. # threshold_abs=validated_data.get("threshold_abs", 0.65),
  306. # margin=validated_data.get("margin", 0.15),
  307. # use_dynamic_thresholds=validated_data.get("use_dynamic_thresholds", True),
  308. # use_adaptive_margin=validated_data.get("use_adaptive_margin", True),
  309. # use_semantic_clustering=validated_data.get("use_semantic_clustering", True)
  310. # )
  311. # # Attach OCR results if available
  312. # if ocr_results:
  313. # result["ocr_results"] = ocr_results
  314. # response_serializer = ProductAttributeResultSerializer(data=result)
  315. # if response_serializer.is_valid():
  316. # return Response(response_serializer.data, status=status.HTTP_200_OK)
  317. # return Response(result, status=status.HTTP_200_OK)
  318. # class BatchExtractProductAttributesView(APIView):
  319. # """
  320. # API endpoint to extract product attributes for multiple products in batch.
  321. # Uses item-specific mandatory_attrs with source tracking.
  322. # Returns attributes in array format: [{"value": "...", "source": "..."}]
  323. # """
  324. # def post(self, request):
  325. # serializer = BatchProductRequestSerializer(data=request.data)
  326. # if not serializer.is_valid():
  327. # return Response({"error": serializer.errors}, status=status.HTTP_400_BAD_REQUEST)
  328. # validated_data = serializer.validated_data
  329. # # DEBUG: Print what we received
  330. # print("\n" + "="*80)
  331. # print("BATCH REQUEST - RECEIVED DATA")
  332. # print("="*80)
  333. # print(f"Raw request data keys: {request.data.keys()}")
  334. # print(f"Multiple field in request: {request.data.get('multiple')}")
  335. # print(f"Validated multiple field: {validated_data.get('multiple')}")
  336. # print("="*80 + "\n")
  337. # # Get batch-level settings
  338. # product_list = validated_data.get("products", [])
  339. # model = validated_data.get("model")
  340. # extract_additional = validated_data.get("extract_additional", True)
  341. # process_image = validated_data.get("process_image", True)
  342. # multiple = validated_data.get("multiple", [])
  343. # threshold_abs = validated_data.get("threshold_abs", 0.65)
  344. # margin = validated_data.get("margin", 0.15)
  345. # use_dynamic_thresholds = validated_data.get("use_dynamic_thresholds", True)
  346. # use_adaptive_margin = validated_data.get("use_adaptive_margin", True)
  347. # use_semantic_clustering = validated_data.get("use_semantic_clustering", True)
  348. # # DEBUG: Print extracted settings
  349. # print(f"Extracted multiple parameter: {multiple}")
  350. # print(f"Type: {type(multiple)}")
  351. # # Extract all item_ids to query the database efficiently
  352. # item_ids = [p['item_id'] for p in product_list]
  353. # # Fetch all products in one query
  354. # products_queryset = Product.objects.filter(item_id__in=item_ids)
  355. # # Create a dictionary for easy lookup: item_id -> Product object
  356. # product_map = {product.item_id: product for product in products_queryset}
  357. # found_ids = set(product_map.keys())
  358. # results = []
  359. # successful = 0
  360. # failed = 0
  361. # for product_entry in product_list:
  362. # item_id = product_entry['item_id']
  363. # # Get item-specific mandatory attributes
  364. # mandatory_attrs = product_entry['mandatory_attrs']
  365. # if item_id not in found_ids:
  366. # failed += 1
  367. # results.append({
  368. # "product_id": item_id,
  369. # "error": "Product not found in database"
  370. # })
  371. # continue
  372. # product = product_map[item_id]
  373. # try:
  374. # title = product.product_name
  375. # short_desc = product.product_short_description
  376. # long_desc = product.product_long_description
  377. # # image_url = product.image_path
  378. # image_url = "http://localhost:8000/media/products/levi_test_ocr2.jpg"
  379. # ocr_results = None
  380. # ocr_text = None
  381. # # Image Processing Logic
  382. # if process_image and image_url:
  383. # ocr_service = OCRService()
  384. # ocr_results = ocr_service.process_image(image_url)
  385. # print(f"ocr results are: {ocr_results}")
  386. # if ocr_results and ocr_results.get("detected_text"):
  387. # ocr_attrs = ProductAttributeService.extract_attributes_from_ocr(
  388. # ocr_results, model
  389. # )
  390. # ocr_results["extracted_attributes"] = ocr_attrs
  391. # ocr_text = "\n".join([
  392. # f"{item['text']} (confidence: {item['confidence']:.2f})"
  393. # for item in ocr_results["detected_text"]
  394. # ])
  395. # # Combine product text with source tracking
  396. # product_text, source_map = ProductAttributeService.combine_product_text(
  397. # title=title,
  398. # short_desc=short_desc,
  399. # long_desc=long_desc,
  400. # ocr_text=ocr_text
  401. # )
  402. # # DEBUG: Print before extraction
  403. # print(f"\n>>> Extracting for product {item_id}")
  404. # print(f" Passing multiple: {multiple}")
  405. # # Attribute Extraction with source tracking (returns array format)
  406. # extracted = ProductAttributeService.extract_attributes(
  407. # product_text=product_text,
  408. # mandatory_attrs=mandatory_attrs,
  409. # source_map=source_map,
  410. # model=model,
  411. # extract_additional=extract_additional,
  412. # multiple=multiple, # Make sure this is passed!
  413. # threshold_abs=threshold_abs,
  414. # margin=margin,
  415. # use_dynamic_thresholds=use_dynamic_thresholds,
  416. # use_adaptive_margin=use_adaptive_margin,
  417. # use_semantic_clustering=use_semantic_clustering
  418. # )
  419. # result = {
  420. # "product_id": product.item_id,
  421. # "mandatory": extracted.get("mandatory", {}),
  422. # "additional": extracted.get("additional", {}),
  423. # }
  424. # if ocr_results:
  425. # result["ocr_results"] = ocr_results
  426. # results.append(result)
  427. # successful += 1
  428. # except Exception as e:
  429. # failed += 1
  430. # results.append({
  431. # "product_id": item_id,
  432. # "error": str(e)
  433. # })
  434. # batch_result = {
  435. # "results": results,
  436. # "total_products": len(product_list),
  437. # "successful": successful,
  438. # "failed": failed
  439. # }
  440. # response_serializer = BatchProductResponseSerializer(data=batch_result)
  441. # if response_serializer.is_valid():
  442. # return Response(response_serializer.data, status=status.HTTP_200_OK)
  443. # return Response(batch_result, status=status.HTTP_200_OK)
  class ProductListView(APIView):
      """
      GET API returning every product, serialized with ``ProductSerializer``.
      """

      def get(self, request):
          payload = ProductSerializer(Product.objects.all(), many=True).data
          return Response(payload, status=status.HTTP_200_OK)
  class ProductUploadExcelView(APIView):
      """
      POST API to upload an Excel file and add data to Product model (skip duplicates).

      Expects a multipart ``file`` field. Column headers are normalised (trimmed,
      lower-cased, spaces -> underscores) and must include every column in
      ``expected_cols``. Rows whose ``item_id`` already exists are skipped, rows
      with a blank ``item_id`` are ignored, and the whole import runs in one
      transaction so an error part-way through leaves no partial data behind.

      Responses: 201 with created/skipped counts, 400 for a missing file or
      missing columns, 500 with the error text for anything else.
      """

      parser_classes = (MultiPartParser, FormParser)

      def post(self, request, *args, **kwargs):
          file_obj = request.FILES.get('file')
          if not file_obj:
              return Response({'error': 'No file provided'}, status=status.HTTP_400_BAD_REQUEST)
          try:
              # dtype=str: read cells as text so numeric IDs are not coerced to
              # floats (12345 -> 12345.0 when the column has gaps).
              # fillna(''): empty cells become '' instead of NaN, which would
              # otherwise be written straight into the model fields.
              df = pd.read_excel(file_obj, dtype=str).fillna('')
              # str() so non-text headers (e.g. numbers) don't crash .strip().
              df.columns = [str(c).strip().lower().replace(' ', '_') for c in df.columns]
              expected_cols = {
                  'item_id',
                  'product_name',
                  'product_long_description',
                  'product_short_description',
                  'product_type',
                  'image_path'
              }
              missing_cols = expected_cols - set(df.columns)
              if missing_cols:
                  return Response({
                      'error': 'Missing required columns',
                      'required_columns': sorted(expected_cols),
                      'missing_columns': sorted(missing_cols),
                  }, status=status.HTTP_400_BAD_REQUEST)

              created_count = 0
              skipped_count = 0
              blank_count = 0
              with transaction.atomic():
                  for row in df.to_dict('records'):
                      item_id = row.get('item_id', '')
                      # A row without an identifier cannot be de-duplicated; ignore it.
                      if not item_id.strip():
                          blank_count += 1
                          continue
                      # Check if this item already exists
                      if Product.objects.filter(item_id=item_id).exists():
                          skipped_count += 1
                          continue
                      Product.objects.create(
                          item_id=item_id,
                          product_name=row.get('product_name', ''),
                          product_long_description=row.get('product_long_description', ''),
                          product_short_description=row.get('product_short_description', ''),
                          product_type=row.get('product_type', ''),
                          image_path=row.get('image_path', ''),
                      )
                      created_count += 1
              return Response({
                  'message': f'Successfully uploaded {created_count} products.',
                  'skipped': f'Skipped {skipped_count} duplicates.',
                  'ignored': f'Ignored {blank_count} rows without item_id.',
              }, status=status.HTTP_201_CREATED)
          except Exception as e:
              return Response({'error': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
  class ProductAttributesUploadView(APIView):
      """
      POST API to upload an Excel file and add mandatory/additional attributes
      for product types with possible values.

      Required columns: product_type, attribute_name, is_mandatory,
      possible_values (headers are trimmed, lower-cased and spaces become
      underscores before the check). For each row, the product type is created
      if needed. The attribute is then created, or its ``is_mandatory`` flag is
      updated. Finally, its possible values are replaced by the row's
      comma-separated list. Rows without a product type or attribute name are
      skipped. The upload is applied in a single transaction.
      """

      parser_classes = (MultiPartParser, FormParser)

      def post(self, request):
          file_obj = request.FILES.get('file')
          if not file_obj:
              return Response({"error": "No file provided."}, status=status.HTTP_400_BAD_REQUEST)
          try:
              # Read as text and blank out empty cells. Without this an empty cell
              # is NaN, str(NaN) == 'nan', and a bogus 'nan' value gets stored.
              df = pd.read_excel(file_obj, dtype=str).fillna('')
              # Same header normalisation as the product upload view.
              df.columns = [str(c).strip().lower().replace(' ', '_') for c in df.columns]
              required_columns = {'product_type', 'attribute_name', 'is_mandatory', 'possible_values'}
              if not required_columns.issubset(df.columns):
                  return Response({
                      "error": f"Missing required columns. Found: {list(df.columns)}"
                  }, status=status.HTTP_400_BAD_REQUEST)

              with transaction.atomic():
                  for row in df.to_dict('records'):
                      product_type_name = row['product_type'].strip()
                      attr_name = row['attribute_name'].strip()
                      if not product_type_name or not attr_name:
                          continue  # blank or incomplete row
                      is_mandatory = row['is_mandatory'].strip().lower() in ['yes', 'true', '1']
                      possible_values = [
                          v.strip() for v in row['possible_values'].split(',') if v.strip()
                      ]

                      product_type, _ = ProductType.objects.get_or_create(name=product_type_name)
                      attribute, _ = ProductAttribute.objects.update_or_create(
                          product_type=product_type,
                          name=attr_name,
                          defaults={'is_mandatory': is_mandatory},
                      )
                      # Existing possible values are replaced by this row's list.
                      AttributePossibleValue.objects.filter(attribute=attribute).delete()
                      for val in possible_values:
                          AttributePossibleValue.objects.create(attribute=attribute, value=val)
              return Response({"message": "Attributes uploaded successfully."}, status=status.HTTP_201_CREATED)
          except Exception as e:
              return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
  class ProductTypeAttributesView(APIView):
      """
      API to view, create, update, and delete product type attributes and their possible values.
      Also supports dynamic product type creation.

      Payload conventions shared by POST and PUT:
      * ``is_mandatory`` -- "yes"/"true"/"1" in any case (JSON ``true``/``1`` also
        work) means mandatory; anything else, or a missing value, means not.
      * ``possible_values`` -- comma-separated string (a JSON list of strings is
        also accepted); entries are stripped and blanks dropped.
      """

      @staticmethod
      def _parse_is_mandatory(value):
          """Return True when ``value`` reads as yes/true/1 (case-insensitive)."""
          # str() first: JSON payloads may carry a bool/int/None, which have no .lower().
          return str(value).strip().lower() in ('yes', 'true', '1')

      @staticmethod
      def _parse_possible_values(raw):
          """Return the cleaned list of possible values from a comma-separated string or a list."""
          if not raw:
              return []
          items = raw if isinstance(raw, (list, tuple)) else str(raw).split(',')
          return [text for text in (str(item).strip() for item in items) if text]

      def get(self, request):
          """
          Retrieve all product types with their attributes and possible values,
          flattened to one row per (product type, attribute).
          """
          product_types = ProductType.objects.all()
          serializer = ProductTypeSerializer(product_types, many=True)
          result = []
          for pt in serializer.data:
              for attr in pt['attributes']:
                  result.append({
                      'product_type': pt['name'],
                      'attribute_name': attr['name'],
                      'is_mandatory': 'Yes' if attr['is_mandatory'] else 'No',
                      'possible_values': ', '.join([pv['value'] for pv in attr['possible_values']])
                  })
          return Response(result, status=status.HTTP_200_OK)

      def post(self, request):
          """
          Create a new product type or attribute with possible values.
          Expected payload example:
          {
              "product_type": "Hardware Screws",
              "attribute_name": "Material",
              "is_mandatory": "Yes",
              "possible_values": "Steel, Zinc Plated, Stainless Steel"
          }
          Responses: 201 when a type or attribute is created, 200 when only an
          existing product type was named, 400 for a missing product_type or an
          attribute that already exists, 500 on unexpected errors.
          """
          try:
              product_type_name = request.data.get('product_type')
              attribute_name = request.data.get('attribute_name', '')
              is_mandatory = self._parse_is_mandatory(request.data.get('is_mandatory', ''))
              possible_values = request.data.get('possible_values', '')
              if not product_type_name:
                  return Response({
                      "error": "product_type is required"
                  }, status=status.HTTP_400_BAD_REQUEST)
              with transaction.atomic():
                  product_type, created = ProductType.objects.get_or_create(name=product_type_name)
                  if created and not attribute_name:
                      return Response({
                          "message": f"Product type '{product_type_name}' created successfully",
                          "data": {"product_type": product_type_name}
                      }, status=status.HTTP_201_CREATED)
                  if attribute_name:
                      attribute, attr_created = ProductAttribute.objects.get_or_create(
                          product_type=product_type,
                          name=attribute_name,
                          defaults={'is_mandatory': is_mandatory}
                      )
                      if not attr_created:
                          return Response({
                              "error": f"Attribute '{attribute_name}' already exists for product type '{product_type_name}'"
                          }, status=status.HTTP_400_BAD_REQUEST)
                      for val in self._parse_possible_values(possible_values):
                          AttributePossibleValue.objects.create(attribute=attribute, value=val)
                      return Response({
                          "message": "Attribute created successfully",
                          "data": {
                              "product_type": product_type_name,
                              "attribute_name": attribute_name,
                              "is_mandatory": "Yes" if is_mandatory else "No",
                              "possible_values": possible_values
                          }
                      }, status=status.HTTP_201_CREATED)
                  return Response({
                      "message": f"Product type '{product_type_name}' already exists",
                      "data": {"product_type": product_type_name}
                  }, status=status.HTTP_200_OK)
          except Exception as e:
              return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)

      def put(self, request):
          """
          Update an existing product type attribute and its possible values.
          Expected payload example:
          {
              "product_type": "Hardware Screws",
              "attribute_name": "Material",
              "is_mandatory": "Yes",
              "possible_values": "Steel, Zinc Plated, Stainless Steel, Brass"
          }
          The attribute's possible values are replaced wholesale by the new list.
          Responses: 200 on success, 400 when product_type/attribute_name is
          missing, 404 when either does not exist, 500 on unexpected errors.
          """
          try:
              product_type_name = request.data.get('product_type')
              attribute_name = request.data.get('attribute_name')
              is_mandatory = self._parse_is_mandatory(request.data.get('is_mandatory', ''))
              possible_values = request.data.get('possible_values', '')
              if not all([product_type_name, attribute_name]):
                  return Response({
                      "error": "product_type and attribute_name are required"
                  }, status=status.HTTP_400_BAD_REQUEST)
              with transaction.atomic():
                  try:
                      product_type = ProductType.objects.get(name=product_type_name)
                      attribute = ProductAttribute.objects.get(
                          product_type=product_type,
                          name=attribute_name
                      )
                  except ProductType.DoesNotExist:
                      return Response({
                          "error": f"Product type '{product_type_name}' not found"
                      }, status=status.HTTP_404_NOT_FOUND)
                  except ProductAttribute.DoesNotExist:
                      return Response({
                          "error": f"Attribute '{attribute_name}' not found for product type '{product_type_name}'"
                      }, status=status.HTTP_404_NOT_FOUND)
                  attribute.is_mandatory = is_mandatory
                  attribute.save()
                  AttributePossibleValue.objects.filter(attribute=attribute).delete()
                  for val in self._parse_possible_values(possible_values):
                      AttributePossibleValue.objects.create(attribute=attribute, value=val)
                  return Response({
                      "message": "Attribute updated successfully",
                      "data": {
                          "product_type": product_type_name,
                          "attribute_name": attribute_name,
                          "is_mandatory": "Yes" if is_mandatory else "No",
                          "possible_values": possible_values
                      }
                  }, status=status.HTTP_200_OK)
          except Exception as e:
              return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)

      def delete(self, request):
          """
          Delete a product type or a specific attribute.
          Expected payload example:
          {
              "product_type": "Hardware Screws",
              "attribute_name": "Material"
          }
          Without ``attribute_name`` the whole product type is deleted. When the
          request body is empty the same keys are read from the query string
          (useful for clients that do not send a body with DELETE).
          Responses: 200 on success, 400 when product_type is missing, 404 when
          the type/attribute does not exist, 500 on unexpected errors.
          """
          try:
              data = request.data or request.query_params
              product_type_name = data.get('product_type')
              attribute_name = data.get('attribute_name', '')
              if not product_type_name:
                  return Response({
                      "error": "product_type is required"
                  }, status=status.HTTP_400_BAD_REQUEST)
              with transaction.atomic():
                  try:
                      product_type = ProductType.objects.get(name=product_type_name)
                  except ProductType.DoesNotExist:
                      return Response({
                          "error": f"Product type '{product_type_name}' not found"
                      }, status=status.HTTP_404_NOT_FOUND)
                  if attribute_name:
                      try:
                          attribute = ProductAttribute.objects.get(
                              product_type=product_type,
                              name=attribute_name
                          )
                      except ProductAttribute.DoesNotExist:
                          return Response({
                              "error": f"Attribute '{attribute_name}' not found for product type '{product_type_name}'"
                          }, status=status.HTTP_404_NOT_FOUND)
                      attribute.delete()
                      return Response({
                          "message": f"Attribute '{attribute_name}' deleted successfully from product type '{product_type_name}'"
                      }, status=status.HTTP_200_OK)
                  product_type.delete()
                  return Response({
                      "message": f"Product type '{product_type_name}' and all its attributes deleted successfully"
                  }, status=status.HTTP_200_OK)
          except Exception as e:
              return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
  class ProductTypeListView(APIView):
      """
      GET API returning only the names of all product types.
      """

      def get(self, request):
          names = list(ProductType.objects.values_list('name', flat=True))
          return Response({"product_types": names}, status=status.HTTP_200_OK)