views.py 46 KB


import logging

import pandas as pd
from django.db import transaction
from rest_framework import status
from rest_framework.parsers import FormParser, MultiPartParser
from rest_framework.response import Response
from rest_framework.views import APIView

from .models import AttributePossibleValue, Product, ProductAttribute, ProductType
from .ocr_service import OCRService
from .serializers import (
    AttributePossibleValueSerializer,
    BatchProductRequestSerializer,
    BatchProductResponseSerializer,
    ProductAttributeResultSerializer,
    ProductAttributeSerializer,
    ProductSerializer,
    ProductTypeSerializer,
    SingleProductRequestSerializer,
)
from .services import ProductAttributeService
  20. # Sample test images (publicly available)
  21. SAMPLE_IMAGES = {
  22. "tshirt": "https://images.unsplash.com/photo-1521572163474-6864f9cf17ab",
  23. "dress": "https://images.unsplash.com/photo-1595777457583-95e059d581b8",
  24. "jeans": "https://images.unsplash.com/photo-1542272604-787c3835535d"
  25. }
  26. # ==================== Updated views.py ====================
  27. from rest_framework.views import APIView
  28. from rest_framework.response import Response
  29. from rest_framework import status
  30. from .models import Product
  31. from .services import ProductAttributeService
  32. from .ocr_service import OCRService
  33. from .visual_processing_service import VisualProcessingService
  34. class ExtractProductAttributesView(APIView):
  35. """
  36. API endpoint to extract product attributes for a single product by item_id.
  37. Fetches product details from database with source tracking.
  38. Returns attributes in array format: [{"value": "...", "source": "..."}]
  39. Includes OCR and Visual Processing results.
  40. """
  41. def post(self, request):
  42. serializer = SingleProductRequestSerializer(data=request.data)
  43. if not serializer.is_valid():
  44. return Response({"error": serializer.errors}, status=status.HTTP_400_BAD_REQUEST)
  45. validated_data = serializer.validated_data
  46. item_id = validated_data.get("item_id")
  47. # Fetch product from DB
  48. try:
  49. product = Product.objects.get(item_id=item_id)
  50. except Product.DoesNotExist:
  51. return Response(
  52. {"error": f"Product with item_id '{item_id}' not found."},
  53. status=status.HTTP_404_NOT_FOUND
  54. )
  55. # Extract product details
  56. title = product.product_name
  57. short_desc = product.product_short_description
  58. long_desc = product.product_long_description
  59. image_url = product.image_path
  60. # Process image for OCR if required
  61. ocr_results = None
  62. ocr_text = None
  63. visual_results = None
  64. if validated_data.get("process_image", True) and image_url:
  65. # OCR Processing
  66. ocr_service = OCRService()
  67. ocr_results = ocr_service.process_image(image_url)
  68. if ocr_results and ocr_results.get("detected_text"):
  69. ocr_attrs = ProductAttributeService.extract_attributes_from_ocr(
  70. ocr_results, validated_data.get("model")
  71. )
  72. ocr_results["extracted_attributes"] = ocr_attrs
  73. ocr_text = "\n".join([
  74. f"{item['text']} (confidence: {item['confidence']:.2f})"
  75. for item in ocr_results["detected_text"]
  76. ])
  77. # Visual Processing
  78. visual_service = VisualProcessingService()
  79. product_type_hint = product.product_type if hasattr(product, 'product_type') else None
  80. visual_results = visual_service.process_image(image_url, product_type_hint)
  81. # Combine all product text with source tracking
  82. product_text, source_map = ProductAttributeService.combine_product_text(
  83. title=title,
  84. short_desc=short_desc,
  85. long_desc=long_desc,
  86. ocr_text=ocr_text
  87. )
  88. # Extract attributes with enhanced features and source tracking
  89. result = ProductAttributeService.extract_attributes(
  90. product_text=product_text,
  91. mandatory_attrs=validated_data["mandatory_attrs"],
  92. source_map=source_map,
  93. model=validated_data.get("model"),
  94. extract_additional=validated_data.get("extract_additional", True),
  95. multiple=validated_data.get("multiple", []),
  96. threshold_abs=validated_data.get("threshold_abs", 0.65),
  97. margin=validated_data.get("margin", 0.15),
  98. use_dynamic_thresholds=validated_data.get("use_dynamic_thresholds", True),
  99. use_adaptive_margin=validated_data.get("use_adaptive_margin", True),
  100. use_semantic_clustering=validated_data.get("use_semantic_clustering", True)
  101. )
  102. # Attach OCR results if available
  103. if ocr_results:
  104. result["ocr_results"] = ocr_results
  105. # Attach Visual Processing results if available
  106. if visual_results:
  107. result["visual_results"] = visual_results
  108. response_serializer = ProductAttributeResultSerializer(data=result)
  109. if response_serializer.is_valid():
  110. return Response(response_serializer.data, status=status.HTTP_200_OK)
  111. return Response(result, status=status.HTTP_200_OK)
  112. # class BatchExtractProductAttributesView(APIView):
  113. # """
  114. # API endpoint to extract product attributes for multiple products in batch.
  115. # Uses item-specific mandatory_attrs with source tracking.
  116. # Returns attributes in array format: [{"value": "...", "source": "..."}]
  117. # Includes OCR and Visual Processing results.
  118. # """
  119. # def post(self, request):
  120. # serializer = BatchProductRequestSerializer(data=request.data)
  121. # if not serializer.is_valid():
  122. # return Response({"error": serializer.errors}, status=status.HTTP_400_BAD_REQUEST)
  123. # validated_data = serializer.validated_data
  124. # # DEBUG: Print what we received
  125. # print("\n" + "="*80)
  126. # print("BATCH REQUEST - RECEIVED DATA")
  127. # print("="*80)
  128. # print(f"Raw request data keys: {request.data.keys()}")
  129. # print(f"Multiple field in request: {request.data.get('multiple')}")
  130. # print(f"Validated multiple field: {validated_data.get('multiple')}")
  131. # print("="*80 + "\n")
  132. # # Get batch-level settings
  133. # product_list = validated_data.get("products", [])
  134. # model = validated_data.get("model")
  135. # extract_additional = validated_data.get("extract_additional", True)
  136. # process_image = validated_data.get("process_image", True)
  137. # multiple = validated_data.get("multiple", [])
  138. # threshold_abs = validated_data.get("threshold_abs", 0.65)
  139. # margin = validated_data.get("margin", 0.15)
  140. # use_dynamic_thresholds = validated_data.get("use_dynamic_thresholds", True)
  141. # use_adaptive_margin = validated_data.get("use_adaptive_margin", True)
  142. # use_semantic_clustering = validated_data.get("use_semantic_clustering", True)
  143. # # DEBUG: Print extracted settings
  144. # print(f"Extracted multiple parameter: {multiple}")
  145. # print(f"Type: {type(multiple)}")
  146. # # Extract all item_ids to query the database efficiently
  147. # item_ids = [p['item_id'] for p in product_list]
  148. # # Fetch all products in one query
  149. # products_queryset = Product.objects.filter(item_id__in=item_ids)
  150. # # Create a dictionary for easy lookup: item_id -> Product object
  151. # product_map = {product.item_id: product for product in products_queryset}
  152. # found_ids = set(product_map.keys())
  153. # results = []
  154. # successful = 0
  155. # failed = 0
  156. # for product_entry in product_list:
  157. # item_id = product_entry['item_id']
  158. # # Get item-specific mandatory attributes
  159. # mandatory_attrs = product_entry['mandatory_attrs']
  160. # if item_id not in found_ids:
  161. # failed += 1
  162. # results.append({
  163. # "product_id": item_id,
  164. # "error": "Product not found in database"
  165. # })
  166. # continue
  167. # product = product_map[item_id]
  168. # try:
  169. # title = product.product_name
  170. # short_desc = product.product_short_description
  171. # long_desc = product.product_long_description
  172. # image_url = product.image_path
  173. # # image_url = "https://images.unsplash.com/photo-1595777457583-95e059d581b8"
  174. # ocr_results = None
  175. # ocr_text = None
  176. # visual_results = None
  177. # # Image Processing Logic
  178. # if process_image and image_url:
  179. # # OCR Processing
  180. # ocr_service = OCRService()
  181. # ocr_results = ocr_service.process_image(image_url)
  182. # print(f"OCR results for {item_id}: {ocr_results}")
  183. # if ocr_results and ocr_results.get("detected_text"):
  184. # ocr_attrs = ProductAttributeService.extract_attributes_from_ocr(
  185. # ocr_results, model
  186. # )
  187. # ocr_results["extracted_attributes"] = ocr_attrs
  188. # ocr_text = "\n".join([
  189. # f"{item['text']} (confidence: {item['confidence']:.2f})"
  190. # for item in ocr_results["detected_text"]
  191. # ])
  192. # # Visual Processing
  193. # visual_service = VisualProcessingService()
  194. # product_type_hint = product.product_type if hasattr(product, 'product_type') else None
  195. # visual_results = visual_service.process_image(image_url, product_type_hint)
  196. # print(f"Visual results for {item_id}: {visual_results.get('visual_attributes', {})}")
  197. # # Combine product text with source tracking
  198. # product_text, source_map = ProductAttributeService.combine_product_text(
  199. # title=title,
  200. # short_desc=short_desc,
  201. # long_desc=long_desc,
  202. # ocr_text=ocr_text
  203. # )
  204. # # DEBUG: Print before extraction
  205. # print(f"\n>>> Extracting for product {item_id}")
  206. # print(f" Passing multiple: {multiple}")
  207. # # Attribute Extraction with source tracking (returns array format)
  208. # extracted = ProductAttributeService.extract_attributes(
  209. # product_text=product_text,
  210. # mandatory_attrs=mandatory_attrs,
  211. # source_map=source_map,
  212. # model=model,
  213. # extract_additional=extract_additional,
  214. # multiple=multiple,
  215. # threshold_abs=threshold_abs,
  216. # margin=margin,
  217. # use_dynamic_thresholds=use_dynamic_thresholds,
  218. # use_adaptive_margin=use_adaptive_margin,
  219. # use_semantic_clustering=use_semantic_clustering
  220. # )
  221. # result = {
  222. # "product_id": product.item_id,
  223. # "mandatory": extracted.get("mandatory", {}),
  224. # "additional": extracted.get("additional", {}),
  225. # }
  226. # # Attach OCR results if available
  227. # if ocr_results:
  228. # result["ocr_results"] = ocr_results
  229. # # Attach Visual Processing results if available
  230. # if visual_results:
  231. # result["visual_results"] = visual_results
  232. # results.append(result)
  233. # successful += 1
  234. # except Exception as e:
  235. # failed += 1
  236. # results.append({
  237. # "product_id": item_id,
  238. # "error": str(e)
  239. # })
  240. # batch_result = {
  241. # "results": results,
  242. # "total_products": len(product_list),
  243. # "successful": successful,
  244. # "failed": failed
  245. # }
  246. # response_serializer = BatchProductResponseSerializer(data=batch_result)
  247. # if response_serializer.is_valid():
  248. # return Response(response_serializer.data, status=status.HTTP_200_OK)
  249. # return Response(batch_result, status=status.HTTP_200_OK)
  250. class BatchExtractProductAttributesView(APIView):
  251. """
  252. API endpoint to extract product attributes for multiple products in batch.
  253. Uses item-specific mandatory_attrs with source tracking.
  254. Returns attributes in array format: [{"value": "...", "source": "..."}]
  255. Includes OCR and Visual Processing results.
  256. """
  257. def post(self, request):
  258. serializer = BatchProductRequestSerializer(data=request.data)
  259. if not serializer.is_valid():
  260. return Response({"error": serializer.errors}, status=status.HTTP_400_BAD_REQUEST)
  261. validated_data = serializer.validated_data
  262. # DEBUG: Print what we received
  263. print("\n" + "="*80)
  264. print("BATCH REQUEST - RECEIVED DATA")
  265. print("="*80)
  266. print(f"Raw request data keys: {request.data.keys()}")
  267. print(f"Multiple field in request: {request.data.get('multiple')}")
  268. print(f"Validated multiple field: {validated_data.get('multiple')}")
  269. print("="*80 + "\n")
  270. # Get batch-level settings
  271. product_list = validated_data.get("products", [])
  272. model = validated_data.get("model")
  273. extract_additional = validated_data.get("extract_additional", True)
  274. process_image = validated_data.get("process_image", True)
  275. multiple = validated_data.get("multiple", [])
  276. threshold_abs = validated_data.get("threshold_abs", 0.65)
  277. margin = validated_data.get("margin", 0.15)
  278. use_dynamic_thresholds = validated_data.get("use_dynamic_thresholds", True)
  279. use_adaptive_margin = validated_data.get("use_adaptive_margin", True)
  280. use_semantic_clustering = validated_data.get("use_semantic_clustering", True)
  281. # DEBUG: Print extracted settings
  282. print(f"Extracted multiple parameter: {multiple}")
  283. print(f"Type: {type(multiple)}")
  284. # Extract all item_ids to query the database efficiently
  285. item_ids = [p['item_id'] for p in product_list]
  286. # Fetch all products in one query
  287. products_queryset = Product.objects.filter(item_id__in=item_ids)
  288. # Create a dictionary for easy lookup: item_id -> Product object
  289. product_map = {product.item_id: product for product in products_queryset}
  290. found_ids = set(product_map.keys())
  291. results = []
  292. successful = 0
  293. failed = 0
  294. for product_entry in product_list:
  295. item_id = product_entry['item_id']
  296. # Get item-specific mandatory attributes
  297. mandatory_attrs = product_entry['mandatory_attrs']
  298. if item_id not in found_ids:
  299. failed += 1
  300. results.append({
  301. "product_id": item_id,
  302. "error": "Product not found in database"
  303. })
  304. continue
  305. product = product_map[item_id]
  306. try:
  307. title = product.product_name
  308. short_desc = product.product_short_description
  309. long_desc = product.product_long_description
  310. image_url = product.image_path
  311. # image_url = "https://images.unsplash.com/photo-1595777457583-95e059d581b8"
  312. ocr_results = None
  313. ocr_text = None
  314. visual_results = None
  315. # Image Processing Logic
  316. if process_image and image_url:
  317. # OCR Processing
  318. ocr_service = OCRService()
  319. ocr_results = ocr_service.process_image(image_url)
  320. print(f"OCR results for {item_id}: {ocr_results}")
  321. if ocr_results and ocr_results.get("detected_text"):
  322. ocr_attrs = ProductAttributeService.extract_attributes_from_ocr(
  323. ocr_results, model
  324. )
  325. ocr_results["extracted_attributes"] = ocr_attrs
  326. ocr_text = "\n".join([
  327. f"{item['text']} (confidence: {item['confidence']:.2f})"
  328. for item in ocr_results["detected_text"]
  329. ])
  330. # Visual Processing
  331. visual_service = VisualProcessingService()
  332. product_type_hint = product.product_type if hasattr(product, 'product_type') else None
  333. visual_results = visual_service.process_image(image_url, product_type_hint)
  334. print(f"Visual results for {item_id}: {visual_results.get('visual_attributes', {})}")
  335. # Format visual attributes to array format with source tracking
  336. if visual_results and visual_results.get('visual_attributes'):
  337. visual_results['visual_attributes'] = ProductAttributeService.format_visual_attributes(
  338. visual_results['visual_attributes']
  339. )
  340. # Combine product text with source tracking
  341. product_text, source_map = ProductAttributeService.combine_product_text(
  342. title=title,
  343. short_desc=short_desc,
  344. long_desc=long_desc,
  345. ocr_text=ocr_text
  346. )
  347. # DEBUG: Print before extraction
  348. print(f"\n>>> Extracting for product {item_id}")
  349. print(f" Passing multiple: {multiple}")
  350. # Attribute Extraction with source tracking (returns array format)
  351. extracted = ProductAttributeService.extract_attributes(
  352. product_text=product_text,
  353. mandatory_attrs=mandatory_attrs,
  354. source_map=source_map,
  355. model=model,
  356. extract_additional=extract_additional,
  357. multiple=multiple,
  358. threshold_abs=threshold_abs,
  359. margin=margin,
  360. use_dynamic_thresholds=use_dynamic_thresholds,
  361. use_adaptive_margin=use_adaptive_margin,
  362. use_semantic_clustering=use_semantic_clustering
  363. )
  364. result = {
  365. "product_id": product.item_id,
  366. "mandatory": extracted.get("mandatory", {}),
  367. "additional": extracted.get("additional", {}),
  368. }
  369. # Attach OCR results if available
  370. if ocr_results:
  371. result["ocr_results"] = ocr_results
  372. # Attach Visual Processing results if available
  373. if visual_results:
  374. result["visual_results"] = visual_results
  375. results.append(result)
  376. successful += 1
  377. except Exception as e:
  378. failed += 1
  379. results.append({
  380. "product_id": item_id,
  381. "error": str(e)
  382. })
  383. batch_result = {
  384. "results": results,
  385. "total_products": len(product_list),
  386. "successful": successful,
  387. "failed": failed
  388. }
  389. response_serializer = BatchProductResponseSerializer(data=batch_result)
  390. if response_serializer.is_valid():
  391. return Response(response_serializer.data, status=status.HTTP_200_OK)
  392. return Response(batch_result, status=status.HTTP_200_OK)
  393. # class ExtractProductAttributesView(APIView):
  394. # """
  395. # API endpoint to extract product attributes for a single product by item_id.
  396. # Fetches product details from database with source tracking.
  397. # Returns attributes in array format: [{"value": "...", "source": "..."}]
  398. # """
  399. # def post(self, request):
  400. # serializer = SingleProductRequestSerializer(data=request.data)
  401. # if not serializer.is_valid():
  402. # return Response({"error": serializer.errors}, status=status.HTTP_400_BAD_REQUEST)
  403. # validated_data = serializer.validated_data
  404. # item_id = validated_data.get("item_id")
  405. # # Fetch product from DB
  406. # try:
  407. # product = Product.objects.get(item_id=item_id)
  408. # except Product.DoesNotExist:
  409. # return Response(
  410. # {"error": f"Product with item_id '{item_id}' not found."},
  411. # status=status.HTTP_404_NOT_FOUND
  412. # )
  413. # # Extract product details
  414. # title = product.product_name
  415. # short_desc = product.product_short_description
  416. # long_desc = product.product_long_description
  417. # image_url = product.image_path
  418. # # Process image for OCR if required
  419. # ocr_results = None
  420. # ocr_text = None
  421. # if validated_data.get("process_image", True) and image_url:
  422. # ocr_service = OCRService()
  423. # ocr_results = ocr_service.process_image(image_url)
  424. # if ocr_results and ocr_results.get("detected_text"):
  425. # ocr_attrs = ProductAttributeService.extract_attributes_from_ocr(
  426. # ocr_results, validated_data.get("model")
  427. # )
  428. # ocr_results["extracted_attributes"] = ocr_attrs
  429. # ocr_text = "\n".join([
  430. # f"{item['text']} (confidence: {item['confidence']:.2f})"
  431. # for item in ocr_results["detected_text"]
  432. # ])
  433. # # Combine all product text with source tracking
  434. # product_text, source_map = ProductAttributeService.combine_product_text(
  435. # title=title,
  436. # short_desc=short_desc,
  437. # long_desc=long_desc,
  438. # ocr_text=ocr_text
  439. # )
  440. # # Extract attributes with enhanced features and source tracking
  441. # result = ProductAttributeService.extract_attributes(
  442. # product_text=product_text,
  443. # mandatory_attrs=validated_data["mandatory_attrs"],
  444. # source_map=source_map,
  445. # model=validated_data.get("model"),
  446. # extract_additional=validated_data.get("extract_additional", True),
  447. # multiple=validated_data.get("multiple", []),
  448. # threshold_abs=validated_data.get("threshold_abs", 0.65),
  449. # margin=validated_data.get("margin", 0.15),
  450. # use_dynamic_thresholds=validated_data.get("use_dynamic_thresholds", True),
  451. # use_adaptive_margin=validated_data.get("use_adaptive_margin", True),
  452. # use_semantic_clustering=validated_data.get("use_semantic_clustering", True)
  453. # )
  454. # # Attach OCR results if available
  455. # if ocr_results:
  456. # result["ocr_results"] = ocr_results
  457. # response_serializer = ProductAttributeResultSerializer(data=result)
  458. # if response_serializer.is_valid():
  459. # return Response(response_serializer.data, status=status.HTTP_200_OK)
  460. # return Response(result, status=status.HTTP_200_OK)
  461. # class BatchExtractProductAttributesView(APIView):
  462. # """
  463. # API endpoint to extract product attributes for multiple products in batch.
  464. # Uses item-specific mandatory_attrs with source tracking.
  465. # Returns attributes in array format: [{"value": "...", "source": "..."}]
  466. # """
  467. # def post(self, request):
  468. # serializer = BatchProductRequestSerializer(data=request.data)
  469. # if not serializer.is_valid():
  470. # return Response({"error": serializer.errors}, status=status.HTTP_400_BAD_REQUEST)
  471. # validated_data = serializer.validated_data
  472. # # DEBUG: Print what we received
  473. # print("\n" + "="*80)
  474. # print("BATCH REQUEST - RECEIVED DATA")
  475. # print("="*80)
  476. # print(f"Raw request data keys: {request.data.keys()}")
  477. # print(f"Multiple field in request: {request.data.get('multiple')}")
  478. # print(f"Validated multiple field: {validated_data.get('multiple')}")
  479. # print("="*80 + "\n")
  480. # # Get batch-level settings
  481. # product_list = validated_data.get("products", [])
  482. # model = validated_data.get("model")
  483. # extract_additional = validated_data.get("extract_additional", True)
  484. # process_image = validated_data.get("process_image", True)
  485. # multiple = validated_data.get("multiple", [])
  486. # threshold_abs = validated_data.get("threshold_abs", 0.65)
  487. # margin = validated_data.get("margin", 0.15)
  488. # use_dynamic_thresholds = validated_data.get("use_dynamic_thresholds", True)
  489. # use_adaptive_margin = validated_data.get("use_adaptive_margin", True)
  490. # use_semantic_clustering = validated_data.get("use_semantic_clustering", True)
  491. # # DEBUG: Print extracted settings
  492. # print(f"Extracted multiple parameter: {multiple}")
  493. # print(f"Type: {type(multiple)}")
  494. # # Extract all item_ids to query the database efficiently
  495. # item_ids = [p['item_id'] for p in product_list]
  496. # # Fetch all products in one query
  497. # products_queryset = Product.objects.filter(item_id__in=item_ids)
  498. # # Create a dictionary for easy lookup: item_id -> Product object
  499. # product_map = {product.item_id: product for product in products_queryset}
  500. # found_ids = set(product_map.keys())
  501. # results = []
  502. # successful = 0
  503. # failed = 0
  504. # for product_entry in product_list:
  505. # item_id = product_entry['item_id']
  506. # # Get item-specific mandatory attributes
  507. # mandatory_attrs = product_entry['mandatory_attrs']
  508. # if item_id not in found_ids:
  509. # failed += 1
  510. # results.append({
  511. # "product_id": item_id,
  512. # "error": "Product not found in database"
  513. # })
  514. # continue
  515. # product = product_map[item_id]
  516. # try:
  517. # title = product.product_name
  518. # short_desc = product.product_short_description
  519. # long_desc = product.product_long_description
  520. # # image_url = product.image_path
  521. # image_url = "http://localhost:8000/media/products/levi_test_ocr2.jpg"
  522. # ocr_results = None
  523. # ocr_text = None
  524. # # Image Processing Logic
  525. # if process_image and image_url:
  526. # ocr_service = OCRService()
  527. # ocr_results = ocr_service.process_image(image_url)
  528. # print(f"ocr results are: {ocr_results}")
  529. # if ocr_results and ocr_results.get("detected_text"):
  530. # ocr_attrs = ProductAttributeService.extract_attributes_from_ocr(
  531. # ocr_results, model
  532. # )
  533. # ocr_results["extracted_attributes"] = ocr_attrs
  534. # ocr_text = "\n".join([
  535. # f"{item['text']} (confidence: {item['confidence']:.2f})"
  536. # for item in ocr_results["detected_text"]
  537. # ])
  538. # # Combine product text with source tracking
  539. # product_text, source_map = ProductAttributeService.combine_product_text(
  540. # title=title,
  541. # short_desc=short_desc,
  542. # long_desc=long_desc,
  543. # ocr_text=ocr_text
  544. # )
  545. # # DEBUG: Print before extraction
  546. # print(f"\n>>> Extracting for product {item_id}")
  547. # print(f" Passing multiple: {multiple}")
  548. # # Attribute Extraction with source tracking (returns array format)
  549. # extracted = ProductAttributeService.extract_attributes(
  550. # product_text=product_text,
  551. # mandatory_attrs=mandatory_attrs,
  552. # source_map=source_map,
  553. # model=model,
  554. # extract_additional=extract_additional,
  555. # multiple=multiple, # Make sure this is passed!
  556. # threshold_abs=threshold_abs,
  557. # margin=margin,
  558. # use_dynamic_thresholds=use_dynamic_thresholds,
  559. # use_adaptive_margin=use_adaptive_margin,
  560. # use_semantic_clustering=use_semantic_clustering
  561. # )
  562. # result = {
  563. # "product_id": product.item_id,
  564. # "mandatory": extracted.get("mandatory", {}),
  565. # "additional": extracted.get("additional", {}),
  566. # }
  567. # if ocr_results:
  568. # result["ocr_results"] = ocr_results
  569. # results.append(result)
  570. # successful += 1
  571. # except Exception as e:
  572. # failed += 1
  573. # results.append({
  574. # "product_id": item_id,
  575. # "error": str(e)
  576. # })
  577. # batch_result = {
  578. # "results": results,
  579. # "total_products": len(product_list),
  580. # "successful": successful,
  581. # "failed": failed
  582. # }
  583. # response_serializer = BatchProductResponseSerializer(data=batch_result)
  584. # if response_serializer.is_valid():
  585. # return Response(response_serializer.data, status=status.HTTP_200_OK)
  586. # return Response(batch_result, status=status.HTTP_200_OK)
  587. class ProductListView(APIView):
  588. """
  589. GET API to list all products with details
  590. """
  591. def get(self, request):
  592. products = Product.objects.all()
  593. serializer = ProductSerializer(products, many=True)
  594. return Response(serializer.data, status=status.HTTP_200_OK)
  595. # class ProductUploadExcelView(APIView):
  596. # """
  597. # POST API to upload an Excel file and add data to Product model (skip duplicates)
  598. # """
  599. # parser_classes = (MultiPartParser, FormParser)
  600. # def post(self, request, *args, **kwargs):
  601. # file_obj = request.FILES.get('file')
  602. # if not file_obj:
  603. # return Response({'error': 'No file provided'}, status=status.HTTP_400_BAD_REQUEST)
  604. # try:
  605. # df = pd.read_excel(file_obj)
  606. # df.columns = [c.strip().lower().replace(' ', '_') for c in df.columns]
  607. # expected_cols = {
  608. # 'item_id',
  609. # 'product_name',
  610. # 'product_long_description',
  611. # 'product_short_description',
  612. # 'product_type',
  613. # 'image_path'
  614. # }
  615. # if not expected_cols.issubset(df.columns):
  616. # return Response({
  617. # 'error': 'Missing required columns',
  618. # 'required_columns': list(expected_cols)
  619. # }, status=status.HTTP_400_BAD_REQUEST)
  620. # created_count = 0
  621. # skipped_count = 0
  622. # for _, row in df.iterrows():
  623. # item_id = row.get('item_id', '')
  624. # # Check if this item already exists
  625. # if Product.objects.filter(item_id=item_id).exists():
  626. # skipped_count += 1
  627. # continue
  628. # Product.objects.create(
  629. # item_id=item_id,
  630. # product_name=row.get('product_name', ''),
  631. # product_long_description=row.get('product_long_description', ''),
  632. # product_short_description=row.get('product_short_description', ''),
  633. # product_type=row.get('product_type', ''),
  634. # image_path=row.get('image_path', ''),
  635. # )
  636. # created_count += 1
  637. # return Response({
  638. # 'message': f'Successfully uploaded {created_count} products.',
  639. # 'skipped': f'Skipped {skipped_count} duplicates.'
  640. # }, status=status.HTTP_201_CREATED)
  641. # except Exception as e:
  642. # return Response({'error': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
  643. from rest_framework.views import APIView
  644. from rest_framework.response import Response
  645. from rest_framework import status
  646. from rest_framework.parsers import MultiPartParser, FormParser
  647. import pandas as pd
  648. from .models import Product
  649. class ProductUploadExcelView(APIView):
  650. """
  651. POST API to upload an Excel file and add/update data in Product model.
  652. - Creates new records if they don't exist.
  653. - Updates existing ones (e.g., when image_path or other fields change).
  654. """
  655. parser_classes = (MultiPartParser, FormParser)
  656. def post(self, request, *args, **kwargs):
  657. file_obj = request.FILES.get('file')
  658. if not file_obj:
  659. return Response({'error': 'No file provided'}, status=status.HTTP_400_BAD_REQUEST)
  660. try:
  661. # Read Excel into DataFrame
  662. df = pd.read_excel(file_obj)
  663. df.columns = [c.strip().lower().replace(' ', '_') for c in df.columns]
  664. expected_cols = {
  665. 'item_id',
  666. 'product_name',
  667. 'product_long_description',
  668. 'product_short_description',
  669. 'product_type',
  670. 'image_path'
  671. }
  672. # Check required columns
  673. if not expected_cols.issubset(df.columns):
  674. return Response({
  675. 'error': 'Missing required columns',
  676. 'required_columns': list(expected_cols)
  677. }, status=status.HTTP_400_BAD_REQUEST)
  678. created_count = 0
  679. updated_count = 0
  680. # Loop through rows and update or create
  681. for _, row in df.iterrows():
  682. item_id = str(row.get('item_id', '')).strip()
  683. if not item_id:
  684. continue # Skip rows without an item_id
  685. defaults = {
  686. 'product_name': row.get('product_name', ''),
  687. 'product_long_description': row.get('product_long_description', ''),
  688. 'product_short_description': row.get('product_short_description', ''),
  689. 'product_type': row.get('product_type', ''),
  690. 'image_path': row.get('image_path', ''),
  691. }
  692. obj, created = Product.objects.update_or_create(
  693. item_id=item_id,
  694. defaults=defaults
  695. )
  696. if created:
  697. created_count += 1
  698. else:
  699. updated_count += 1
  700. return Response({
  701. 'message': f'Upload successful.',
  702. 'created': f'{created_count} new records added.',
  703. 'updated': f'{updated_count} existing records updated.'
  704. }, status=status.HTTP_201_CREATED)
  705. except Exception as e:
  706. return Response({'error': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
  707. class ProductAttributesUploadView(APIView):
  708. """
  709. POST API to upload an Excel file and add mandatory/additional attributes
  710. for product types with possible values.
  711. """
  712. parser_classes = (MultiPartParser, FormParser)
  713. def post(self, request):
  714. file_obj = request.FILES.get('file')
  715. if not file_obj:
  716. return Response({"error": "No file provided."}, status=status.HTTP_400_BAD_REQUEST)
  717. try:
  718. df = pd.read_excel(file_obj)
  719. required_columns = {'product_type', 'attribute_name', 'is_mandatory', 'possible_values'}
  720. if not required_columns.issubset(df.columns):
  721. return Response({
  722. "error": f"Missing required columns. Found: {list(df.columns)}"
  723. }, status=status.HTTP_400_BAD_REQUEST)
  724. for _, row in df.iterrows():
  725. product_type_name = str(row['product_type']).strip()
  726. attr_name = str(row['attribute_name']).strip()
  727. is_mandatory = str(row['is_mandatory']).strip().lower() in ['yes', 'true', '1']
  728. possible_values = str(row.get('possible_values', '')).strip()
  729. # Get or create product type
  730. product_type, _ = ProductType.objects.get_or_create(name=product_type_name)
  731. # Get or create attribute
  732. attribute, _ = ProductAttribute.objects.get_or_create(
  733. product_type=product_type,
  734. name=attr_name,
  735. defaults={'is_mandatory': is_mandatory}
  736. )
  737. attribute.is_mandatory = is_mandatory
  738. attribute.save()
  739. # Handle possible values
  740. AttributePossibleValue.objects.filter(attribute=attribute).delete()
  741. if possible_values:
  742. for val in [v.strip() for v in possible_values.split(',') if v.strip()]:
  743. AttributePossibleValue.objects.create(attribute=attribute, value=val)
  744. return Response({"message": "Attributes uploaded successfully."}, status=status.HTTP_201_CREATED)
  745. except Exception as e:
  746. return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
  747. class ProductTypeAttributesView(APIView):
  748. """
  749. API to view, create, update, and delete product type attributes and their possible values.
  750. Also supports dynamic product type creation.
  751. """
  752. def get(self, request):
  753. """
  754. Retrieve all product types with their attributes and possible values.
  755. """
  756. product_types = ProductType.objects.all()
  757. serializer = ProductTypeSerializer(product_types, many=True)
  758. # Transform the serialized data into the requested format
  759. result = []
  760. for pt in serializer.data:
  761. for attr in pt['attributes']:
  762. result.append({
  763. 'product_type': pt['name'],
  764. 'attribute_name': attr['name'],
  765. 'is_mandatory': 'Yes' if attr['is_mandatory'] else 'No',
  766. 'possible_values': ', '.join([pv['value'] for pv in attr['possible_values']])
  767. })
  768. return Response(result, status=status.HTTP_200_OK)
  769. def post(self, request):
  770. """
  771. Create a new product type or attribute with possible values.
  772. Expected payload example:
  773. {
  774. "product_type": "Hardware Screws",
  775. "attribute_name": "Material",
  776. "is_mandatory": "Yes",
  777. "possible_values": "Steel, Zinc Plated, Stainless Steel"
  778. }
  779. """
  780. try:
  781. product_type_name = request.data.get('product_type')
  782. attribute_name = request.data.get('attribute_name', '')
  783. is_mandatory = request.data.get('is_mandatory', '').lower() in ['yes', 'true', '1']
  784. possible_values = request.data.get('possible_values', '')
  785. if not product_type_name:
  786. return Response({
  787. "error": "product_type is required"
  788. }, status=status.HTTP_400_BAD_REQUEST)
  789. with transaction.atomic():
  790. # Get or create product type
  791. product_type, created = ProductType.objects.get_or_create(name=product_type_name)
  792. if created and not attribute_name:
  793. return Response({
  794. "message": f"Product type '{product_type_name}' created successfully",
  795. "data": {"product_type": product_type_name}
  796. }, status=status.HTTP_201_CREATED)
  797. if attribute_name:
  798. # Create attribute
  799. attribute, attr_created = ProductAttribute.objects.get_or_create(
  800. product_type=product_type,
  801. name=attribute_name,
  802. defaults={'is_mandatory': is_mandatory}
  803. )
  804. if not attr_created:
  805. return Response({
  806. "error": f"Attribute '{attribute_name}' already exists for product type '{product_type_name}'"
  807. }, status=status.HTTP_400_BAD_REQUEST)
  808. # Handle possible values
  809. if possible_values:
  810. for val in [v.strip() for v in possible_values.split(',') if v.strip()]:
  811. AttributePossibleValue.objects.create(attribute=attribute, value=val)
  812. return Response({
  813. "message": "Attribute created successfully",
  814. "data": {
  815. "product_type": product_type_name,
  816. "attribute_name": attribute_name,
  817. "is_mandatory": "Yes" if is_mandatory else "No",
  818. "possible_values": possible_values
  819. }
  820. }, status=status.HTTP_201_CREATED)
  821. return Response({
  822. "message": f"Product type '{product_type_name}' already exists",
  823. "data": {"product_type": product_type_name}
  824. }, status=status.HTTP_200_OK)
  825. except Exception as e:
  826. return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
  827. def put(self, request):
  828. """
  829. Update an existing product type attribute and its possible values.
  830. Expected payload example:
  831. {
  832. "product_type": "Hardware Screws",
  833. "attribute_name": "Material",
  834. "is_mandatory": "Yes",
  835. "possible_values": "Steel, Zinc Plated, Stainless Steel, Brass"
  836. }
  837. """
  838. try:
  839. product_type_name = request.data.get('product_type')
  840. attribute_name = request.data.get('attribute_name')
  841. is_mandatory = request.data.get('is_mandatory', '').lower() in ['yes', 'true', '1']
  842. possible_values = request.data.get('possible_values', '')
  843. if not all([product_type_name, attribute_name]):
  844. return Response({
  845. "error": "product_type and attribute_name are required"
  846. }, status=status.HTTP_400_BAD_REQUEST)
  847. with transaction.atomic():
  848. try:
  849. product_type = ProductType.objects.get(name=product_type_name)
  850. attribute = ProductAttribute.objects.get(
  851. product_type=product_type,
  852. name=attribute_name
  853. )
  854. except ProductType.DoesNotExist:
  855. return Response({
  856. "error": f"Product type '{product_type_name}' not found"
  857. }, status=status.HTTP_404_NOT_FOUND)
  858. except ProductAttribute.DoesNotExist:
  859. return Response({
  860. "error": f"Attribute '{attribute_name}' not found for product type '{product_type_name}'"
  861. }, status=status.HTTP_404_NOT_FOUND)
  862. # Update attribute
  863. attribute.is_mandatory = is_mandatory
  864. attribute.save()
  865. # Update possible values
  866. AttributePossibleValue.objects.filter(attribute=attribute).delete()
  867. if possible_values:
  868. for val in [v.strip() for v in possible_values.split(',') if v.strip()]:
  869. AttributePossibleValue.objects.create(attribute=attribute, value=val)
  870. return Response({
  871. "message": "Attribute updated successfully",
  872. "data": {
  873. "product_type": product_type_name,
  874. "attribute_name": attribute_name,
  875. "is_mandatory": "Yes" if is_mandatory else "No",
  876. "possible_values": possible_values
  877. }
  878. }, status=status.HTTP_200_OK)
  879. except Exception as e:
  880. return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
  881. def delete(self, request):
  882. """
  883. Delete a product type or a specific attribute.
  884. Expected payload example:
  885. {
  886. "product_type": "Hardware Screws",
  887. "attribute_name": "Material"
  888. }
  889. """
  890. try:
  891. product_type_name = request.data.get('product_type')
  892. attribute_name = request.data.get('attribute_name', '')
  893. if not product_type_name:
  894. return Response({
  895. "error": "product_type is required"
  896. }, status=status.HTTP_400_BAD_REQUEST)
  897. with transaction.atomic():
  898. try:
  899. product_type = ProductType.objects.get(name=product_type_name)
  900. except ProductType.DoesNotExist:
  901. return Response({
  902. "error": f"Product type '{product_type_name}' not found"
  903. }, status=status.HTTP_404_NOT_FOUND)
  904. if attribute_name:
  905. # Delete specific attribute
  906. try:
  907. attribute = ProductAttribute.objects.get(
  908. product_type=product_type,
  909. name=attribute_name
  910. )
  911. attribute.delete()
  912. return Response({
  913. "message": f"Attribute '{attribute_name}' deleted successfully from product type '{product_type_name}'"
  914. }, status=status.HTTP_200_OK)
  915. except ProductAttribute.DoesNotExist:
  916. return Response({
  917. "error": f"Attribute '{attribute_name}' not found for product type '{product_type_name}'"
  918. }, status=status.HTTP_404_NOT_FOUND)
  919. else:
  920. # Delete entire product type
  921. product_type.delete()
  922. return Response({
  923. "message": f"Product type '{product_type_name}' and all its attributes deleted successfully"
  924. }, status=status.HTTP_200_OK)
  925. except Exception as e:
  926. return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
  927. class ProductTypeListView(APIView):
  928. """
  929. GET API to list all product types (only names).
  930. """
  931. def get(self, request):
  932. product_types = ProductType.objects.values_list('name', flat=True)
  933. return Response({"product_types": list(product_types)}, status=status.HTTP_200_OK)