
Merge branch 'master' of https://git.luminad.com/harshit.pathak/content_quality_tool

VISHAL BHANUSHALI 3 months ago
parent
commit
9169fe79b2

+ 904 - 2
attr_extraction/services.py

@@ -948,6 +948,862 @@
 
 
 
+# # ==================== services.py ====================
+# import requests
+# import json
+# from typing import Dict, List, Optional, Tuple
+# from django.conf import settings
+# from concurrent.futures import ThreadPoolExecutor, as_completed
+# from sentence_transformers import SentenceTransformer, util
+# import numpy as np
+# from .ocr_service import OCRService
+
+
+# # Initialize embedding model for normalization
+# model_embedder = SentenceTransformer("all-MiniLM-L6-v2")
+
+
+# class ProductAttributeService:
+#     """Service class for extracting product attributes using Groq LLM."""
+
+#     @staticmethod
+#     def combine_product_text(
+#         title: Optional[str] = None,
+#         short_desc: Optional[str] = None,
+#         long_desc: Optional[str] = None,
+#         ocr_text: Optional[str] = None
+#     ) -> Tuple[str, Dict[str, str]]:
+#         """
+#         Combine product metadata into a single text block.
+#         Returns: (combined_text, source_map) where source_map tracks which text came from where
+#         """
+#         parts = []
+#         source_map = {}
+        
+#         if title:
+#             title_str = str(title).strip()
+#             parts.append(f"Title: {title_str}")
+#             source_map['title'] = title_str
+#         if short_desc:
+#             short_str = str(short_desc).strip()
+#             parts.append(f"Description: {short_str}")
+#             source_map['short_desc'] = short_str
+#         if long_desc:
+#             long_str = str(long_desc).strip()
+#             parts.append(f"Details: {long_str}")
+#             source_map['long_desc'] = long_str
+#         if ocr_text:
+#             parts.append(f"OCR Text: {ocr_text}")
+#             source_map['ocr_text'] = ocr_text
+        
+#         combined = "\n".join(parts).strip()
+        
+#         if not combined:
+#             return "No product information available", {}
+        
+#         return combined, source_map
+
+#     @staticmethod
+#     def find_value_source(value: str, source_map: Dict[str, str]) -> str:
+#         """
+#         Find which source(s) contain the given value.
+#         Returns the source name(s) where the value appears.
+#         """
+#         value_lower = value.lower()
+#         # Split value into tokens for better matching
+#         value_tokens = set(value_lower.replace("-", " ").split())
+        
+#         sources_found = []
+#         source_scores = {}
+        
+#         for source_name, source_text in source_map.items():
+#             source_lower = source_text.lower()
+            
+#             # Check for exact phrase match first
+#             if value_lower in source_lower:
+#                 source_scores[source_name] = 1.0
+#                 continue
+            
+#             # Check for token matches
+#             token_matches = sum(1 for token in value_tokens if token in source_lower)
+#             if token_matches > 0:
+#                 source_scores[source_name] = token_matches / len(value_tokens)
+        
+#         # Return source with highest score, or all sources if multiple have same score
+#         if source_scores:
+#             max_score = max(source_scores.values())
+#             sources_found = [s for s, score in source_scores.items() if score == max_score]
+            
+#             # Prioritize: title > short_desc > long_desc > ocr_text
+#             priority = ['title', 'short_desc', 'long_desc', 'ocr_text']
+#             for p in priority:
+#                 if p in sources_found:
+#                     return p
+            
+#             return sources_found[0] if sources_found else "Not found"
+        
+#         return "Not found"
+
+#     @staticmethod
+#     def extract_attributes_from_ocr(ocr_results: Dict, model: str = None) -> Dict:
+#         """Extract structured attributes from OCR text using LLM."""
+#         if model is None:
+#             model = settings.SUPPORTED_MODELS[0]
+        
+#         detected_text = ocr_results.get('detected_text', [])
+#         if not detected_text:
+#             return {}
+        
+#         # Format OCR text for prompt
+#         ocr_text = "\n".join([f"Text: {item['text']}, Confidence: {item['confidence']:.2f}" 
+#                               for item in detected_text])
+        
+#         prompt = f"""
+# You are an AI model that extracts structured attributes from OCR text detected on product images.
+# Given the OCR detections below, infer the possible product attributes and return them as a clean JSON object.
+
+# OCR Text:
+# {ocr_text}
+
+# Extract relevant attributes like:
+# - brand
+# - model_number
+# - size (waist_size, length, etc.)
+# - collection
+# - any other relevant product information
+
+# Return a JSON object with only the attributes you can confidently identify.
+# If an attribute is not present, do not include it in the response.
+# """
+        
+#         payload = {
+#             "model": model,
+#             "messages": [
+#                 {
+#                     "role": "system",
+#                     "content": "You are a helpful AI that extracts structured data from OCR output. Return only valid JSON."
+#                 },
+#                 {"role": "user", "content": prompt}
+#             ],
+#             "temperature": 0.2,
+#             "max_tokens": 500
+#         }
+        
+#         headers = {
+#             "Authorization": f"Bearer {settings.GROQ_API_KEY}",
+#             "Content-Type": "application/json",
+#         }
+        
+#         try:
+#             response = requests.post(
+#                 settings.GROQ_API_URL,
+#                 headers=headers,
+#                 json=payload,
+#                 timeout=30
+#             )
+#             response.raise_for_status()
+#             result_text = response.json()["choices"][0]["message"]["content"].strip()
+            
+#             # Clean and parse JSON
+#             result_text = ProductAttributeService._clean_json_response(result_text)
+#             parsed = json.loads(result_text)
+            
+#             return parsed
+#         except Exception as e:
+#             return {"error": f"Failed to extract attributes from OCR: {str(e)}"}
+
+#     @staticmethod
+#     def calculate_attribute_relationships(
+#         mandatory_attrs: Dict[str, List[str]],
+#         product_text: str
+#     ) -> Dict[str, float]:
+#         """
+#         Calculate semantic relationships between attribute values across different attributes.
+#         Returns a matrix of cross-attribute value similarities.
+#         """
+#         pt_emb = model_embedder.encode(product_text, convert_to_tensor=True)
+
+#         # Calculate similarities between all attribute values and product text
+#         attr_scores = {}
+#         for attr, values in mandatory_attrs.items():
+#             attr_scores[attr] = {}
+#             for val in values:
+#                 contexts = [val, f"for {val}", f"use in {val}", f"suitable for {val}"]
+#                 ctx_embs = [model_embedder.encode(c, convert_to_tensor=True) for c in contexts]
+#                 sem_sim = max(float(util.cos_sim(pt_emb, ce).item()) for ce in ctx_embs)
+#                 attr_scores[attr][val] = sem_sim
+
+#         # Calculate cross-attribute value relationships
+#         relationships = {}
+#         attr_list = list(mandatory_attrs.keys())
+
+#         for i, attr1 in enumerate(attr_list):
+#             for attr2 in attr_list[i+1:]:
+#                 # Calculate pairwise similarities between values of different attributes
+#                 for val1 in mandatory_attrs[attr1]:
+#                     for val2 in mandatory_attrs[attr2]:
+#                         emb1 = model_embedder.encode(val1, convert_to_tensor=True)
+#                         emb2 = model_embedder.encode(val2, convert_to_tensor=True)
+#                         sim = float(util.cos_sim(emb1, emb2).item())
+
+#                         # Store bidirectional relationships
+#                         key1 = f"{attr1}:{val1}->{attr2}:{val2}"
+#                         key2 = f"{attr2}:{val2}->{attr1}:{val1}"
+#                         relationships[key1] = sim
+#                         relationships[key2] = sim
+
+#         return relationships
+
+#     @staticmethod
+#     def calculate_value_clusters(
+#         values: List[str],
+#         scores: List[Tuple[str, float]],
+#         cluster_threshold: float = 0.4
+#     ) -> List[List[str]]:
+#         """
+#         Group values into semantic clusters based on their similarity to each other.
+#         Returns clusters of related values.
+#         """
+#         if len(values) <= 1:
+#             return [[val] for val, _ in scores]
+
+#         # Get embeddings for all values
+#         embeddings = [model_embedder.encode(val, convert_to_tensor=True) for val in values]
+
+#         # Calculate pairwise similarities
+#         similarity_matrix = np.zeros((len(values), len(values)))
+#         for i in range(len(values)):
+#             for j in range(i+1, len(values)):
+#                 sim = float(util.cos_sim(embeddings[i], embeddings[j]).item())
+#                 similarity_matrix[i][j] = sim
+#                 similarity_matrix[j][i] = sim
+
+#         # Simple clustering: group values with high similarity
+#         clusters = []
+#         visited = set()
+
+#         for i, (val, score) in enumerate(scores):
+#             if i in visited:
+#                 continue
+
+#             cluster = [val]
+#             visited.add(i)
+
+#             # Find similar values
+#             for j in range(len(values)):
+#                 if j not in visited and similarity_matrix[i][j] >= cluster_threshold:
+#                     cluster.append(values[j])
+#                     visited.add(j)
+
+#             clusters.append(cluster)
+
+#         return clusters
+
+#     @staticmethod
+#     def get_dynamic_threshold(
+#         attr: str,
+#         val: str,
+#         base_score: float,
+#         extracted_attrs: Dict[str, List[Dict[str, str]]],
+#         relationships: Dict[str, float],
+#         mandatory_attrs: Dict[str, List[str]],
+#         base_threshold: float = 0.65,
+#         boost_factor: float = 0.15
+#     ) -> float:
+#         """
+#         Calculate dynamic threshold based on relationships with already-extracted attributes.
+#         """
+#         threshold = base_threshold
+
+#         # Check relationships with already extracted attributes
+#         max_relationship = 0.0
+#         for other_attr, other_values_list in extracted_attrs.items():
+#             if other_attr == attr:
+#                 continue
+
+#             for other_val_dict in other_values_list:
+#                 other_val = other_val_dict['value']
+#                 key = f"{attr}:{val}->{other_attr}:{other_val}"
+#                 if key in relationships:
+#                     max_relationship = max(max_relationship, relationships[key])
+
+#         # If strong relationship exists, lower threshold
+#         if max_relationship > 0.6:
+#             threshold = base_threshold - (boost_factor * max_relationship)
+
+#         return max(0.3, threshold)
+
+#     @staticmethod
+#     def get_adaptive_margin(
+#         scores: List[Tuple[str, float]],
+#         base_margin: float = 0.15,
+#         max_margin: float = 0.22
+#     ) -> float:
+#         """
+#         Calculate adaptive margin based on score distribution.
+#         """
+#         if len(scores) < 2:
+#             return base_margin
+
+#         score_values = [s for _, s in scores]
+#         best_score = score_values[0]
+
+#         # If best score is very low, use adaptive margin but be more conservative
+#         if best_score < 0.5:
+#             # Calculate score spread in top 3-4 scores only (more selective)
+#             top_scores = score_values[:min(4, len(score_values))]
+#             score_range = max(top_scores) - min(top_scores)
+
+#             # Very controlled margin increase
+#             if score_range < 0.30:
+#                 # Much more conservative scaling
+#                 score_factor = (0.5 - best_score) * 0.35
+#                 adaptive = base_margin + score_factor + (0.30 - score_range) * 0.2
+#                 return min(adaptive, max_margin)
+
+#         return base_margin
+
+#     @staticmethod
+#     def _lexical_evidence(product_text: str, label: str) -> float:
+#         """Calculate lexical overlap between product text and label."""
+#         pt = product_text.lower()
+#         tokens = [t for t in label.lower().replace("-", " ").split() if t]
+#         if not tokens:
+#             return 0.0
+#         hits = sum(1 for t in tokens if t in pt)
+#         return hits / len(tokens)
+
+#     @staticmethod
+#     def normalize_against_product_text(
+#         product_text: str,
+#         mandatory_attrs: Dict[str, List[str]],
+#         source_map: Dict[str, str],
+#         threshold_abs: float = 0.65,
+#         margin: float = 0.15,
+#         allow_multiple: bool = False,
+#         sem_weight: float = 0.8,
+#         lex_weight: float = 0.2,
+#         extracted_attrs: Optional[Dict[str, List[Dict[str, str]]]] = None,
+#         relationships: Optional[Dict[str, float]] = None,
+#         use_dynamic_thresholds: bool = True,
+#         use_adaptive_margin: bool = True,
+#         use_semantic_clustering: bool = True
+#     ) -> dict:
+#         """
+#         Score each allowed value against the product_text with dynamic thresholds.
+#         Returns dict with values in array format: [{"value": "...", "source": "..."}]
+#         """
+#         if extracted_attrs is None:
+#             extracted_attrs = {}
+#         if relationships is None:
+#             relationships = {}
+
+#         pt_emb = model_embedder.encode(product_text, convert_to_tensor=True)
+#         extracted = {}
+
+#         for attr, allowed_values in mandatory_attrs.items():
+#             scores: List[Tuple[str, float]] = []
+
+#             for val in allowed_values:
+#                 contexts = [val, f"for {val}", f"use in {val}", f"suitable for {val}", f"{val} room"]
+#                 ctx_embs = [model_embedder.encode(c, convert_to_tensor=True) for c in contexts]
+#                 sem_sim = max(float(util.cos_sim(pt_emb, ce).item()) for ce in ctx_embs)
+
+#                 lex_score = ProductAttributeService._lexical_evidence(product_text, val)
+#                 final_score = sem_weight * sem_sim + lex_weight * lex_score
+#                 scores.append((val, final_score))
+
+#             scores.sort(key=lambda x: x[1], reverse=True)
+#             best_val, best_score = scores[0]
+
+#             # Calculate adaptive margin if enabled
+#             effective_margin = margin
+#             if allow_multiple and use_adaptive_margin:
+#                 effective_margin = ProductAttributeService.get_adaptive_margin(scores, margin)
+
+#             if not allow_multiple:
+#                 source = ProductAttributeService.find_value_source(best_val, source_map)
+#                 extracted[attr] = [{"value": best_val, "source": source}]
+#             else:
+#                 candidates = [best_val]
+#                 use_base_threshold = best_score >= threshold_abs
+
+#                 # Get semantic clusters if enabled
+#                 clusters = []
+#                 if use_semantic_clustering:
+#                     clusters = ProductAttributeService.calculate_value_clusters(
+#                         allowed_values, scores, cluster_threshold=0.4
+#                     )
+#                     best_cluster = next((c for c in clusters if best_val in c), [best_val])
+
+#                 for val, sc in scores[1:]:
+#                     # Calculate dynamic threshold for this value
+#                     if use_dynamic_thresholds and extracted_attrs:
+#                         dynamic_thresh = ProductAttributeService.get_dynamic_threshold(
+#                             attr, val, sc, extracted_attrs, relationships,
+#                             mandatory_attrs, threshold_abs
+#                         )
+#                     else:
+#                         dynamic_thresh = threshold_abs
+
+#                     within_margin = (best_score - sc) <= effective_margin
+#                     above_threshold = sc >= dynamic_thresh
+
+#                     # Check if in same semantic cluster as best value
+#                     in_cluster = False
+#                     if use_semantic_clustering and clusters:
+#                         in_cluster = any(best_val in c and val in c for c in clusters)
+
+#                     if use_base_threshold:
+#                         # Best score is good, require threshold OR (cluster + margin)
+#                         if above_threshold and within_margin:
+#                             candidates.append(val)
+#                         elif in_cluster and within_margin:
+#                             candidates.append(val)
+#                     else:
+#                         # Best score is low, use margin OR cluster logic
+#                         if within_margin:
+#                             candidates.append(val)
+#                         elif in_cluster and (best_score - sc) <= effective_margin * 2.0:
+#                             # Extended margin for cluster members
+#                             candidates.append(val)
+
+#                 # Map each candidate to its source and create array format
+#                 extracted[attr] = []
+#                 for candidate in candidates:
+#                     source = ProductAttributeService.find_value_source(candidate, source_map)
+#                     extracted[attr].append({"value": candidate, "source": source})
+
+#         return extracted
+
+#     @staticmethod
+#     def extract_attributes(
+#         product_text: str,
+#         mandatory_attrs: Dict[str, List[str]],
+#         source_map: Dict[str, str] = None,
+#         model: str = None,
+#         extract_additional: bool = True,
+#         multiple: Optional[List[str]] = None,
+#         threshold_abs: float = 0.65,
+#         margin: float = 0.15,
+#         use_dynamic_thresholds: bool = True,
+#         use_adaptive_margin: bool = True,
+#         use_semantic_clustering: bool = True
+#     ) -> dict:
+#         """
+#         Use Groq LLM to extract attributes from any product type with enhanced multi-value selection.
+#         Now returns values in array format: [{"value": "...", "source": "..."}]
+#         """
+        
+#         if model is None:
+#             model = settings.SUPPORTED_MODELS[0]
+
+#         if multiple is None:
+#             multiple = []
+
+#         if source_map is None:
+#             source_map = {}
+
+#         # Check if product text is empty or minimal
+#         if not product_text or product_text == "No product information available":
+#             return ProductAttributeService._create_error_response(
+#                 "No product information provided",
+#                 mandatory_attrs,
+#                 extract_additional
+#             )
+
+#         # Create structured prompt for mandatory attributes
+#         mandatory_attr_list = []
+#         for attr_name, allowed_values in mandatory_attrs.items():
+#             mandatory_attr_list.append(f"{attr_name}: {', '.join(allowed_values)}")
+#         mandatory_attr_text = "\n".join(mandatory_attr_list)
+
+#         additional_instruction = ""
+#         if extract_additional:
+#             additional_instruction = """
+# 2. Extract ADDITIONAL attributes: Identify any other relevant attributes from the product text 
+#    that are NOT in the mandatory list. Only include attributes where you can find actual values
+#    in the product text. Do NOT include attributes with "Not Specified" or empty values.
+   
+#    Examples of attributes to look for (only if present): Brand, Material, Size, Color, Dimensions,
+#    Weight, Features, Style, Theme, Pattern, Finish, Care Instructions, etc."""
+
+#         output_format = {
+#             "mandatory": {attr: "value or list of values" for attr in mandatory_attrs.keys()},
+#         }
+
+#         if extract_additional:
+#             output_format["additional"] = {
+#                 "example_attribute_1": "actual value found",
+#                 "example_attribute_2": "actual value found"
+#             }
+#             output_format["additional"]["_note"] = "Only include attributes with actual values found in text"
+
+#         prompt = f"""
+# You are an intelligent product attribute extractor that works with ANY product type.
+
+# TASK:
+# 1. Extract MANDATORY attributes: For each mandatory attribute, select the most appropriate value(s)
+#    from the provided list. Choose the value(s) that best match the product description.
+# {additional_instruction}
+
+# Product Text:
+# {product_text}
+
+# Mandatory Attribute Lists (MUST select from these allowed values):
+# {mandatory_attr_text}
+
+# CRITICAL INSTRUCTIONS:
+# - Return ONLY valid JSON, nothing else
+# - No explanations, no markdown, no text before or after the JSON
+# - For mandatory attributes, choose the value(s) from the provided list that best match
+# - If a mandatory attribute cannot be determined from the product text, use "Not Specified"
+# - Prefer exact matches from the allowed values list over generic synonyms
+# - If multiple values are plausible, you MAY return more than one
+# {f"- For additional attributes: ONLY include attributes where you found actual values in the product text. DO NOT include attributes with 'Not Specified', 'None', 'N/A', or empty values. If you cannot find a value for an attribute, simply don't include that attribute." if extract_additional else ""}
+# - Be precise and only extract information that is explicitly stated or clearly implied
+
+# Required Output Format:
+# {json.dumps(output_format, indent=2)}
+#         """
+
+#         payload = {
+#             "model": model,
+#             "messages": [
+#                 {
+#                     "role": "system",
+#                     "content": f"You are a precise attribute extraction model. Return ONLY valid JSON with {'mandatory and additional' if extract_additional else 'mandatory'} sections. No explanations, no markdown, no other text."
+#                 },
+#                 {"role": "user", "content": prompt}
+#             ],
+#             "temperature": 0.0,
+#             "max_tokens": 1500
+#         }
+
+#         headers = {
+#             "Authorization": f"Bearer {settings.GROQ_API_KEY}",
+#             "Content-Type": "application/json",
+#         }
+
+#         try:
+#             response = requests.post(
+#                 settings.GROQ_API_URL,
+#                 headers=headers,
+#                 json=payload,
+#                 timeout=30
+#             )
+#             response.raise_for_status()
+#             result_text = response.json()["choices"][0]["message"]["content"].strip()
+
+#             # Clean the response
+#             result_text = ProductAttributeService._clean_json_response(result_text)
+
+#             # Parse JSON
+#             parsed = json.loads(result_text)
+
+#             # Validate and restructure with source tracking
+#             parsed = ProductAttributeService._validate_response_structure(
+#                 parsed, mandatory_attrs, extract_additional, source_map
+#             )
+
+#             # Clean up and add source tracking to additional attributes in array format
+#             if extract_additional and "additional" in parsed:
+#                 cleaned_additional = {}
+#                 for k, v in parsed["additional"].items():
+#                     if v and v not in ["Not Specified", "None", "N/A", "", "not specified", "none", "n/a"]:
+#                         if not (isinstance(v, str) and v.lower() in ["not specified", "none", "n/a", ""]):
+#                             # Convert to array format if not already
+#                             if isinstance(v, list):
+#                                 cleaned_additional[k] = []
+#                                 for item in v:
+#                                     if isinstance(item, dict) and "value" in item:
+#                                         if "source" not in item:
+#                                             item["source"] = ProductAttributeService.find_value_source(
+#                                                 item["value"], source_map
+#                                             )
+#                                         cleaned_additional[k].append(item)
+#                                     else:
+#                                         source = ProductAttributeService.find_value_source(str(item), source_map)
+#                                         cleaned_additional[k].append({"value": str(item), "source": source})
+#                             else:
+#                                 source = ProductAttributeService.find_value_source(str(v), source_map)
+#                                 cleaned_additional[k] = [{"value": str(v), "source": source}]
+#                 parsed["additional"] = cleaned_additional
+
+#             # Calculate attribute relationships if using dynamic thresholds
+#             relationships = {}
+#             if use_dynamic_thresholds:
+#                 relationships = ProductAttributeService.calculate_attribute_relationships(
+#                     mandatory_attrs, product_text
+#                 )
+
+#             # Process attributes in order, allowing earlier ones to influence later ones
+#             extracted_so_far = {}
+#             for attr in mandatory_attrs.keys():
+#                 allow_multiple = attr in multiple
+
+#                 result = ProductAttributeService.normalize_against_product_text(
+#                     product_text=product_text,
+#                     mandatory_attrs={attr: mandatory_attrs[attr]},
+#                     source_map=source_map,
+#                     threshold_abs=threshold_abs,
+#                     margin=margin,
+#                     allow_multiple=allow_multiple,
+#                     extracted_attrs=extracted_so_far,
+#                     relationships=relationships,
+#                     use_dynamic_thresholds=use_dynamic_thresholds,
+#                     use_adaptive_margin=use_adaptive_margin,
+#                     use_semantic_clustering=use_semantic_clustering
+#                 )
+
+#                 # Result is already in array format from normalize_against_product_text
+#                 parsed["mandatory"][attr] = result[attr]
+#                 extracted_so_far[attr] = result[attr]
+
+#             return parsed
+
+#         except requests.exceptions.RequestException as e:
+#             return ProductAttributeService._create_error_response(
+#                 str(e), mandatory_attrs, extract_additional
+#             )
+#         except json.JSONDecodeError as e:
+#             return ProductAttributeService._create_error_response(
+#                 f"Invalid JSON: {str(e)}", mandatory_attrs, extract_additional, result_text
+#             )
+#         except Exception as e:
+#             return ProductAttributeService._create_error_response(
+#                 str(e), mandatory_attrs, extract_additional
+#             )
+
+#     @staticmethod
+#     def extract_attributes_batch(
+#         products: List[Dict],
+#         mandatory_attrs: Dict[str, List[str]],
+#         model: str = None,
+#         extract_additional: bool = True,
+#         process_image: bool = True,
+#         max_workers: int = 5,
+#         multiple: Optional[List[str]] = None,
+#         threshold_abs: float = 0.65,
+#         margin: float = 0.15,
+#         use_dynamic_thresholds: bool = True,
+#         use_adaptive_margin: bool = True,
+#         use_semantic_clustering: bool = True
+#     ) -> Dict:
+#         """Extract attributes for multiple products in parallel with enhanced multi-value selection and source tracking."""
+#         results = []
+#         successful = 0
+#         failed = 0
+        
+#         ocr_service = OCRService()
+
+#         if multiple is None:
+#             multiple = []
+
+#         def process_product(product_data):
+#             """Process a single product."""
+#             product_id = product_data.get('product_id', f"product_{len(results)}")
+            
+#             try:
+#                 # Process image if URL is provided
+#                 ocr_results = None
+#                 ocr_text = None
+                
+#                 if process_image and product_data.get('image_url'):
+#                     ocr_results = ocr_service.process_image(product_data['image_url'])
+                    
+#                     # Extract attributes from OCR
+#                     if ocr_results and ocr_results.get('detected_text'):
+#                         ocr_attrs = ProductAttributeService.extract_attributes_from_ocr(
+#                             ocr_results, model
+#                         )
+#                         ocr_results['extracted_attributes'] = ocr_attrs
+                        
+#                         # Format OCR text for combining with product text
+#                         ocr_text = "\n".join([
+#                             f"{item['text']} (confidence: {item['confidence']:.2f})"
+#                             for item in ocr_results['detected_text']
+#                         ])
+                
+#                 # Combine all product information with source tracking
+#                 product_text, source_map = ProductAttributeService.combine_product_text(
+#                     title=product_data.get('title'),
+#                     short_desc=product_data.get('short_desc'),
+#                     long_desc=product_data.get('long_desc'),
+#                     ocr_text=ocr_text
+#                 )
+                
+#                 # Extract attributes from combined text with enhanced features
+#                 result = ProductAttributeService.extract_attributes(
+#                     product_text=product_text,
+#                     mandatory_attrs=mandatory_attrs,
+#                     source_map=source_map,
+#                     model=model,
+#                     extract_additional=extract_additional,
+#                     multiple=multiple,
+#                     threshold_abs=threshold_abs,
+#                     margin=margin,
+#                     use_dynamic_thresholds=use_dynamic_thresholds,
+#                     use_adaptive_margin=use_adaptive_margin,
+#                     use_semantic_clustering=use_semantic_clustering
+#                 )
+                
+#                 result['product_id'] = product_id
+                
+#                 # Add OCR results if available
+#                 if ocr_results:
+#                     result['ocr_results'] = ocr_results
+                
+#                 # Check if extraction was successful
+#                 if 'error' not in result:
+#                     return result, True
+#                 else:
+#                     return result, False
+                    
+#             except Exception as e:
+#                 return {
+#                     'product_id': product_id,
+#                     'mandatory': {attr: [{"value": "Not Specified", "source": "error"}] for attr in mandatory_attrs.keys()},
+#                     'additional': {} if extract_additional else None,
+#                     'error': f"Processing error: {str(e)}"
+#                 }, False
+
+#         # Process products in parallel
+#         with ThreadPoolExecutor(max_workers=max_workers) as executor:
+#             future_to_product = {
+#                 executor.submit(process_product, product): product 
+#                 for product in products
+#             }
+            
+#             for future in as_completed(future_to_product):
+#                 try:
+#                     result, success = future.result()
+#                     results.append(result)
+#                     if success:
+#                         successful += 1
+#                     else:
+#                         failed += 1
+#                 except Exception as e:
+#                     failed += 1
+#                     results.append({
+#                         'product_id': 'unknown',
+#                         'mandatory': {attr: [{"value": "Not Specified", "source": "error"}] for attr in mandatory_attrs.keys()},
+#                         'additional': {} if extract_additional else None,
+#                         'error': f"Unexpected error: {str(e)}"
+#                     })
+
+#         return {
+#             'results': results,
+#             'total_products': len(products),
+#             'successful': successful,
+#             'failed': failed
+#         }
+
+#     @staticmethod
+#     def _clean_json_response(text: str) -> str:
+#         """Clean LLM response to extract valid JSON."""
+#         start_idx = text.find('{')
+#         end_idx = text.rfind('}')
+
+#         if start_idx != -1 and end_idx != -1:
+#             text = text[start_idx:end_idx + 1]
+
+#         if "```json" in text:
+#             text = text.split("```json")[1].split("```")[0].strip()
+#         elif "```" in text:
+#             text = text.split("```")[1].split("```")[0].strip()
+#             if text.startswith("json"):
+#                 text = text[4:].strip()
+
+#         return text
+
+#     @staticmethod
+#     def _validate_response_structure(
+#         parsed: dict,
+#         mandatory_attrs: Dict[str, List[str]],
+#         extract_additional: bool,
+#         source_map: Dict[str, str] = None
+#     ) -> dict:
+#         """Validate and fix the response structure, ensuring array format with source tracking."""
+#         if source_map is None:
+#             source_map = {}
+        
+#         expected_sections = ["mandatory"]
+#         if extract_additional:
+#             expected_sections.append("additional")
+
+#         if not all(section in parsed for section in expected_sections):
+#             if isinstance(parsed, dict):
+#                 mandatory_keys = set(mandatory_attrs.keys())
+#                 mandatory = {k: v for k, v in parsed.items() if k in mandatory_keys}
+#                 additional = {k: v for k, v in parsed.items() if k not in mandatory_keys}
+
+#                 result = {"mandatory": mandatory}
+#                 if extract_additional:
+#                     result["additional"] = additional
+#                 parsed = result
+#             else:
+#                 return ProductAttributeService._create_error_response(
+#                     "Invalid response structure",
+#                     mandatory_attrs,
+#                     extract_additional,
+#                     str(parsed)
+#                 )
+
+#         # Convert mandatory attributes to array format with source tracking
+#         if "mandatory" in parsed:
+#             converted_mandatory = {}
+#             for attr, value in parsed["mandatory"].items():
+#                 if isinstance(value, list):
+#                     # Already in array format, ensure each item has source
+#                     converted_mandatory[attr] = []
+#                     for item in value:
+#                         if isinstance(item, dict) and "value" in item:
+#                             # Already has proper structure
+#                             if "source" not in item:
+#                                 item["source"] = ProductAttributeService.find_value_source(
+#                                     item["value"], source_map
+#                                 )
+#                             converted_mandatory[attr].append(item)
+#                         else:
+#                             # Convert string to proper format
+#                             source = ProductAttributeService.find_value_source(str(item), source_map)
+#                             converted_mandatory[attr].append({"value": str(item), "source": source})
+#                 else:
+#                     # Single value - convert to array format
+#                     source = ProductAttributeService.find_value_source(str(value), source_map)
+#                     converted_mandatory[attr] = [{"value": str(value), "source": source}]
+            
+#             parsed["mandatory"] = converted_mandatory
+
+#         return parsed
+
+#     @staticmethod
+#     def _create_error_response(
+#         error: str,
+#         mandatory_attrs: Dict[str, List[str]],
+#         extract_additional: bool,
+#         raw_output: Optional[str] = None
+#     ) -> dict:
+#         """Create a standardized error response in array format."""
+#         response = {
+#             "mandatory": {attr: [{"value": "Not Specified", "source": "error"}] for attr in mandatory_attrs.keys()},
+#             "error": error
+#         }
+#         if extract_additional:
+#             response["additional"] = {}
+#         if raw_output:
+#             response["raw_output"] = raw_output
+#         return response
+
+
+
+
+
+
+
+
+
 # ==================== services.py ====================
 import requests
 import json
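Reviewer note: the commented-out block above preserves the previous revision of services.py; the live code that follows keeps the same source-tracking contract. Below is a minimal standalone sketch of that contract: simplified re-implementations of combine_product_text and find_value_source that omit the token-overlap fallback, runnable without Django or the embedding model, with invented product values.

# ==================== example: source tracking (sketch) ====================
from typing import Dict, Optional, Tuple

def combine_product_text(title: Optional[str] = None,
                         short_desc: Optional[str] = None,
                         long_desc: Optional[str] = None,
                         ocr_text: Optional[str] = None) -> Tuple[str, Dict[str, str]]:
    # Same field order and labels as the service method above.
    parts, source_map = [], {}
    for key, label, text in [("title", "Title", title),
                             ("short_desc", "Description", short_desc),
                             ("long_desc", "Details", long_desc),
                             ("ocr_text", "OCR Text", ocr_text)]:
        if text:
            text = str(text).strip()
            parts.append(f"{label}: {text}")
            source_map[key] = text
    combined = "\n".join(parts).strip()
    return (combined or "No product information available"), source_map

def find_value_source(value: str, source_map: Dict[str, str]) -> str:
    # Exact-phrase match only; the real method also scores token overlap
    # and breaks ties by the same title > short_desc > long_desc > ocr_text priority.
    value_lower = value.lower()
    for key in ("title", "short_desc", "long_desc", "ocr_text"):
        if key in source_map and value_lower in source_map[key].lower():
            return key
    return "Not found"

text, sources = combine_product_text(title="Levi's 511 Slim Fit Jeans",
                                     short_desc="Dark wash stretch denim")
print(find_value_source("slim fit", sources))   # -> "title"
print(find_value_source("dark wash", sources))  # -> "short_desc"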
@@ -1044,6 +1900,33 @@ class ProductAttributeService:
         
         return "Not found"
 
+    @staticmethod
+    def format_visual_attributes(visual_attributes: Dict) -> Dict:
+        """
+        Convert visual attributes to array format with source tracking.
+        Source is always 'image' for visual attributes.
+        """
+        formatted = {}
+        
+        for key, value in visual_attributes.items():
+            if isinstance(value, list):
+                # Already a list (like color_palette)
+                formatted[key] = [{"value": str(item), "source": "image"} for item in value]
+            elif isinstance(value, dict):
+                # Nested dictionary - format recursively
+                nested_formatted = {}
+                for nested_key, nested_value in value.items():
+                    if isinstance(nested_value, list):
+                        nested_formatted[nested_key] = [{"value": str(item), "source": "image"} for item in nested_value]
+                    else:
+                        nested_formatted[nested_key] = [{"value": str(nested_value), "source": "image"}]
+                formatted[key] = nested_formatted
+            else:
+                # Single value
+                formatted[key] = [{"value": str(value), "source": "image"}]
+        
+        return formatted
+
     @staticmethod
     def extract_attributes_from_ocr(ocr_results: Dict, model: str = None) -> Dict:
         """Extract structured attributes from OCR text using LLM."""
@@ -1108,7 +1991,26 @@ If an attribute is not present, do not include it in the response.
             result_text = ProductAttributeService._clean_json_response(result_text)
             parsed = json.loads(result_text)
             
-            return parsed
+            # Convert to array format with source tracking
+            formatted_attributes = {}
+            for key, value in parsed.items():
+                if key == "error":
+                    continue
+                
+                # Handle nested dictionaries (like size)
+                if isinstance(value, dict):
+                    nested_formatted = {}
+                    for nested_key, nested_value in value.items():
+                        nested_formatted[nested_key] = [{"value": str(nested_value), "source": "image"}]
+                    formatted_attributes[key] = nested_formatted
+                elif isinstance(value, list):
+                    # Already a list, convert each item
+                    formatted_attributes[key] = [{"value": str(item), "source": "image"} for item in value]
+                else:
+                    # Single value
+                    formatted_attributes[key] = [{"value": str(value), "source": "image"}]
+            
+            return formatted_attributes
         except Exception as e:
             return {"error": f"Failed to extract attributes from OCR: {str(e)}"}
 
@@ -1650,7 +2552,7 @@ Required Output Format:
                 
                 result['product_id'] = product_id
                 
-                # Add OCR results if available
+                # Add OCR results if available (already in correct format)
                 if ocr_results:
                     result['ocr_results'] = ocr_results
                 

+ 272 - 18
attr_extraction/views.py

@@ -132,6 +132,174 @@ class ExtractProductAttributesView(APIView):
         return Response(result, status=status.HTTP_200_OK)
 
 
+# class BatchExtractProductAttributesView(APIView):
+#     """
+#     API endpoint to extract product attributes for multiple products in batch.
+#     Uses item-specific mandatory_attrs with source tracking.
+#     Returns attributes in array format: [{"value": "...", "source": "..."}]
+#     Includes OCR and Visual Processing results.
+#     """
+
+#     def post(self, request):
+#         serializer = BatchProductRequestSerializer(data=request.data)
+#         if not serializer.is_valid():
+#             return Response({"error": serializer.errors}, status=status.HTTP_400_BAD_REQUEST)
+
+#         validated_data = serializer.validated_data
+        
+#         # DEBUG: Print what we received
+#         print("\n" + "="*80)
+#         print("BATCH REQUEST - RECEIVED DATA")
+#         print("="*80)
+#         print(f"Raw request data keys: {request.data.keys()}")
+#         print(f"Multiple field in request: {request.data.get('multiple')}")
+#         print(f"Validated multiple field: {validated_data.get('multiple')}")
+#         print("="*80 + "\n")
+        
+#         # Get batch-level settings
+#         product_list = validated_data.get("products", [])
+#         model = validated_data.get("model")
+#         extract_additional = validated_data.get("extract_additional", True)
+#         process_image = validated_data.get("process_image", True)
+#         multiple = validated_data.get("multiple", [])
+#         threshold_abs = validated_data.get("threshold_abs", 0.65)
+#         margin = validated_data.get("margin", 0.15)
+#         use_dynamic_thresholds = validated_data.get("use_dynamic_thresholds", True)
+#         use_adaptive_margin = validated_data.get("use_adaptive_margin", True)
+#         use_semantic_clustering = validated_data.get("use_semantic_clustering", True)
+        
+#         # DEBUG: Print extracted settings
+#         print(f"Extracted multiple parameter: {multiple}")
+#         print(f"Type: {type(multiple)}")
+        
+#         # Extract all item_ids to query the database efficiently
+#         item_ids = [p['item_id'] for p in product_list] 
+        
+#         # Fetch all products in one query
+#         products_queryset = Product.objects.filter(item_id__in=item_ids)
+        
+#         # Create a dictionary for easy lookup: item_id -> Product object
+#         product_map = {product.item_id: product for product in products_queryset}
+#         found_ids = set(product_map.keys())
+        
+#         results = []
+#         successful = 0
+#         failed = 0
+
+#         for product_entry in product_list:
+#             item_id = product_entry['item_id']
+#             # Get item-specific mandatory attributes
+#             mandatory_attrs = product_entry['mandatory_attrs'] 
+
+#             if item_id not in found_ids:
+#                 failed += 1
+#                 results.append({
+#                     "product_id": item_id,
+#                     "error": "Product not found in database"
+#                 })
+#                 continue
+
+#             product = product_map[item_id]
+            
+#             try: 
+#                 title = product.product_name
+#                 short_desc = product.product_short_description
+#                 long_desc = product.product_long_description
+#                 image_url = product.image_path
+#                 # image_url = "https://images.unsplash.com/photo-1595777457583-95e059d581b8"
+#                 ocr_results = None
+#                 ocr_text = None
+#                 visual_results = None
+
+#                 # Image Processing Logic
+#                 if process_image and image_url:
+#                     # OCR Processing
+#                     ocr_service = OCRService()
+#                     ocr_results = ocr_service.process_image(image_url)
+#                     print(f"OCR results for {item_id}: {ocr_results}")
+                    
+#                     if ocr_results and ocr_results.get("detected_text"):
+#                         ocr_attrs = ProductAttributeService.extract_attributes_from_ocr(
+#                             ocr_results, model
+#                         )
+#                         ocr_results["extracted_attributes"] = ocr_attrs
+#                         ocr_text = "\n".join([
+#                             f"{item['text']} (confidence: {item['confidence']:.2f})"
+#                             for item in ocr_results["detected_text"]
+#                         ])
+                    
+#                     # Visual Processing
+#                     visual_service = VisualProcessingService()
+#                     product_type_hint = product.product_type if hasattr(product, 'product_type') else None
+#                     visual_results = visual_service.process_image(image_url, product_type_hint)
+#                     print(f"Visual results for {item_id}: {visual_results.get('visual_attributes', {})}")
+
+#                 # Combine product text with source tracking
+#                 product_text, source_map = ProductAttributeService.combine_product_text(
+#                     title=title,
+#                     short_desc=short_desc,
+#                     long_desc=long_desc,
+#                     ocr_text=ocr_text
+#                 )
+
+#                 # DEBUG: Print before extraction
+#                 print(f"\n>>> Extracting for product {item_id}")
+#                 print(f"    Passing multiple: {multiple}")
+
+#                 # Attribute Extraction with source tracking (returns array format)
+#                 extracted = ProductAttributeService.extract_attributes(
+#                     product_text=product_text,
+#                     mandatory_attrs=mandatory_attrs,
+#                     source_map=source_map,
+#                     model=model,
+#                     extract_additional=extract_additional,
+#                     multiple=multiple,
+#                     threshold_abs=threshold_abs,
+#                     margin=margin,
+#                     use_dynamic_thresholds=use_dynamic_thresholds,
+#                     use_adaptive_margin=use_adaptive_margin,
+#                     use_semantic_clustering=use_semantic_clustering
+#                 )
+
+#                 result = {
+#                     "product_id": product.item_id,
+#                     "mandatory": extracted.get("mandatory", {}),
+#                     "additional": extracted.get("additional", {}),
+#                 }
+
+#                 # Attach OCR results if available
+#                 if ocr_results:
+#                     result["ocr_results"] = ocr_results
+                
+#                 # Attach Visual Processing results if available
+#                 if visual_results:
+#                     result["visual_results"] = visual_results
+
+#                 results.append(result)
+#                 successful += 1
+
+#             except Exception as e:
+#                 failed += 1
+#                 results.append({
+#                     "product_id": item_id,
+#                     "error": str(e)
+#                 })
+
+#         batch_result = {
+#             "results": results,
+#             "total_products": len(product_list),
+#             "successful": successful,
+#             "failed": failed
+#         }
+
+#         response_serializer = BatchProductResponseSerializer(data=batch_result)
+#         if response_serializer.is_valid():
+#             return Response(response_serializer.data, status=status.HTTP_200_OK)
+
+#         return Response(batch_result, status=status.HTTP_200_OK)
+
+
+
 class BatchExtractProductAttributesView(APIView):
     """
     API endpoint to extract product attributes for multiple products in batch.
@@ -233,6 +401,12 @@ class BatchExtractProductAttributesView(APIView):
                     product_type_hint = product.product_type if hasattr(product, 'product_type') else None
                     visual_results = visual_service.process_image(image_url, product_type_hint)
                     print(f"Visual results for {item_id}: {visual_results.get('visual_attributes', {})}")
+                    
+                    # Format visual attributes to array format with source tracking
+                    if visual_results and visual_results.get('visual_attributes'):
+                        visual_results['visual_attributes'] = ProductAttributeService.format_visual_attributes(
+                            visual_results['visual_attributes']
+                        )
 
                 # Combine product text with source tracking
                 product_text, source_map = ProductAttributeService.combine_product_text(
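For context, a hypothetical request body for BatchExtractProductAttributesView, covering the fields the view reads from the validated serializer data (the item ID echoes a media file in this commit; the attribute lists are invented):

# ==================== example: batch request body (sketch) ====================
batch_request = {
    "products": [
        {"item_id": "3217373735",
         "mandatory_attrs": {"room_type": ["Living Room", "Bedroom", "Kitchen"],
                             "material": ["Cotton", "Polyester", "Wool"]}},
    ],
    "extract_additional": True,
    "process_image": True,
    "multiple": ["room_type"],   # attributes allowed to return several values
    "threshold_abs": 0.65,
    "margin": 0.15,
    "use_dynamic_thresholds": True,
    "use_adaptive_margin": True,
    "use_semantic_clustering": True,
}
# Each mandatory attribute in the response is an array:
# [{"value": "...", "source": "title" | "short_desc" | "long_desc" | "ocr_text"}]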
@@ -301,7 +475,6 @@ class BatchExtractProductAttributesView(APIView):
 
 
 
-
 # class ExtractProductAttributesView(APIView):
 #     """
 #     API endpoint to extract product attributes for a single product by item_id.
@@ -554,9 +727,79 @@ class ProductListView(APIView):
         return Response(serializer.data, status=status.HTTP_200_OK)
 
 
+# class ProductUploadExcelView(APIView):
+#     """
+#     POST API to upload an Excel file and add data to Product model (skip duplicates)
+#     """
+#     parser_classes = (MultiPartParser, FormParser)
+
+#     def post(self, request, *args, **kwargs):
+#         file_obj = request.FILES.get('file')
+#         if not file_obj:
+#             return Response({'error': 'No file provided'}, status=status.HTTP_400_BAD_REQUEST)
+
+#         try:
+#             df = pd.read_excel(file_obj)
+#             df.columns = [c.strip().lower().replace(' ', '_') for c in df.columns]
+
+#             expected_cols = {
+#                 'item_id',
+#                 'product_name',
+#                 'product_long_description',
+#                 'product_short_description',
+#                 'product_type',
+#                 'image_path'
+#             }
+
+#             if not expected_cols.issubset(df.columns):
+#                 return Response({
+#                     'error': 'Missing required columns',
+#                     'required_columns': list(expected_cols)
+#                 }, status=status.HTTP_400_BAD_REQUEST)
+
+#             created_count = 0
+#             skipped_count = 0
+
+#             for _, row in df.iterrows():
+#                 item_id = row.get('item_id', '')
+
+#                 # Check if this item already exists
+#                 if Product.objects.filter(item_id=item_id).exists():
+#                     skipped_count += 1
+#                     continue
+
+#                 Product.objects.create(
+#                     item_id=item_id,
+#                     product_name=row.get('product_name', ''),
+#                     product_long_description=row.get('product_long_description', ''),
+#                     product_short_description=row.get('product_short_description', ''),
+#                     product_type=row.get('product_type', ''),
+#                     image_path=row.get('image_path', ''),
+#                 )
+#                 created_count += 1
+
+#             return Response({
+#                 'message': f'Successfully uploaded {created_count} products.',
+#                 'skipped': f'Skipped {skipped_count} duplicates.'
+#             }, status=status.HTTP_201_CREATED)
+
+#         except Exception as e:
+#             return Response({'error': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
+
+
+from rest_framework.views import APIView
+from rest_framework.response import Response
+from rest_framework import status
+from rest_framework.parsers import MultiPartParser, FormParser
+import pandas as pd
+from .models import Product
+
+
 class ProductUploadExcelView(APIView):
     """
-    POST API to upload an Excel file and add data to Product model (skip duplicates)
+    POST API to upload an Excel file and add/update data in Product model.
+    - Creates new records if they don't exist.
+    - Updates existing ones (e.g., when image_path or other fields change).
     """
     parser_classes = (MultiPartParser, FormParser)
 
@@ -566,6 +809,7 @@ class ProductUploadExcelView(APIView):
             return Response({'error': 'No file provided'}, status=status.HTTP_400_BAD_REQUEST)
 
         try:
+            # Read Excel into DataFrame
             df = pd.read_excel(file_obj)
             df.columns = [c.strip().lower().replace(' ', '_') for c in df.columns]
 
@@ -578,6 +822,7 @@ class ProductUploadExcelView(APIView):
                 'image_path'
             }
 
+            # Check required columns
             if not expected_cols.issubset(df.columns):
                 return Response({
                     'error': 'Missing required columns',
@@ -585,35 +830,44 @@ class ProductUploadExcelView(APIView):
                 }, status=status.HTTP_400_BAD_REQUEST)
 
             created_count = 0
-            skipped_count = 0
+            updated_count = 0
 
+            # Loop through rows and update or create
             for _, row in df.iterrows():
-                item_id = row.get('item_id', '')
-
-                # Check if this item already exists
-                if Product.objects.filter(item_id=item_id).exists():
-                    skipped_count += 1
-                    continue
+                raw_id = row.get('item_id', '')
+                item_id = '' if pd.isna(raw_id) else str(raw_id).strip()
+                if not item_id:
+                    continue  # Skip rows without an item_id (empty or NaN cell)
+
+                defaults = {
+                    'product_name': row.get('product_name', ''),
+                    'product_long_description': row.get('product_long_description', ''),
+                    'product_short_description': row.get('product_short_description', ''),
+                    'product_type': row.get('product_type', ''),
+                    'image_path': row.get('image_path', ''),
+                }
 
-                Product.objects.create(
+                obj, created = Product.objects.update_or_create(
                     item_id=item_id,
-                    product_name=row.get('product_name', ''),
-                    product_long_description=row.get('product_long_description', ''),
-                    product_short_description=row.get('product_short_description', ''),
-                    product_type=row.get('product_type', ''),
-                    image_path=row.get('image_path', ''),
+                    defaults=defaults
                 )
-                created_count += 1
+
+                if created:
+                    created_count += 1
+                else:
+                    updated_count += 1
 
             return Response({
-                'message': f'Successfully uploaded {created_count} products.',
-                'skipped': f'Skipped {skipped_count} duplicates.'
+                'message': 'Upload successful.',
+                'created': f'{created_count} new records added.',
+                'updated': f'{updated_count} existing records updated.'
             }, status=status.HTTP_201_CREATED)
 
         except Exception as e:
             return Response({'error': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
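The rewritten upload view above switches from skip-duplicates to upsert semantics. A short sketch of the update_or_create call it relies on; this assumes the project's Django environment, and the ID and path are illustrative:

# ==================== example: upsert semantics ====================
obj, created = Product.objects.update_or_create(
    item_id="3217373735",  # lookup key; matched against existing rows
    defaults={"image_path": "media/products/3217373735.jpg"},
)
# created is True when a new row was inserted, False when an existing
# row matched on item_id and its fields were updated from `defaults`.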
 
 
+
+
 class ProductAttributesUploadView(APIView):
     """
     POST API to upload an Excel file and add mandatory/additional attributes

BIN
db.sqlite3


BIN
media/products/3217373735.jpg


BIN
media/products/55inchtv.jpg


BIN
media/products/applewatch.jpg


BIN
media/products/camera.webp


BIN
media/products/casiogshock.avif


BIN
media/products/chair.avif


BIN
media/products/dewalt.webp


BIN
media/products/dyson.jpg


BIN
media/products/image_QJg7FgP.jpg


BIN
media/products/intra.jpg


BIN
media/products/jacket-2821961_960_720.jpg


BIN
media/products/led.avif


BIN
media/products/levi_test_ocr2_RHKlXGJ.jpg


BIN
media/products/mixer.jpg


BIN
media/products/socket.jpg


BIN
media/products/sonyheadphone.jpg


BIN
media/products/sonyheadphone_oLfN2fy.jpg


BIN
media/products/staple.jpg