# ==================== services.py ====================
import requests
import json
from typing import Dict, List, Optional, Tuple
from django.conf import settings
from concurrent.futures import ThreadPoolExecutor, as_completed
from sentence_transformers import SentenceTransformer, util
import numpy as np
from .ocr_service import OCRService

# Initialize embedding model for normalization
model_embedder = SentenceTransformer("all-MiniLM-L6-v2")


class ProductAttributeService:
    """Service class for extracting product attributes using Groq LLM."""
    @staticmethod
    def combine_product_text(
        title: Optional[str] = None,
        short_desc: Optional[str] = None,
        long_desc: Optional[str] = None,
        ocr_text: Optional[str] = None
    ) -> Tuple[str, Dict[str, str]]:
        """
        Combine product metadata into a single text block.
        Returns: (combined_text, source_map) where source_map tracks which text came from where.
        """
        parts = []
        source_map = {}
        if title:
            title_str = str(title).strip()
            parts.append(f"Title: {title_str}")
            source_map['title'] = title_str
        if short_desc:
            short_str = str(short_desc).strip()
            parts.append(f"Description: {short_str}")
            source_map['short_desc'] = short_str
        if long_desc:
            long_str = str(long_desc).strip()
            parts.append(f"Details: {long_str}")
            source_map['long_desc'] = long_str
        if ocr_text:
            parts.append(f"OCR Text: {ocr_text}")
            source_map['ocr_text'] = ocr_text
        combined = "\n".join(parts).strip()
        if not combined:
            return "No product information available", {}
        return combined, source_map
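
    # Usage sketch for combine_product_text (illustrative only; the field values
    # below are made up):
    #     text, sources = ProductAttributeService.combine_product_text(
    #         title="Acme Desk Lamp",
    #         short_desc="LED lamp for home office use",
    #     )
    #     # text    -> "Title: Acme Desk Lamp\nDescription: LED lamp for home office use"
    #     # sources -> {"title": "Acme Desk Lamp", "short_desc": "LED lamp for home office use"}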

    @staticmethod
    def find_value_source(value: str, source_map: Dict[str, str]) -> str:
        """
        Find which source contains the given value.
        Returns the name of the single source where the value most plausibly appears.
        """
        value_lower = value.lower()
        # Split value into tokens for better matching
        value_tokens = set(value_lower.replace("-", " ").split())
        sources_found = []
        source_scores = {}
        for source_name, source_text in source_map.items():
            source_lower = source_text.lower()
            # Check for exact phrase match first
            if value_lower in source_lower:
                source_scores[source_name] = 1.0
                continue
            # Check for token matches
            token_matches = sum(1 for token in value_tokens if token in source_lower)
            if token_matches > 0:
                source_scores[source_name] = token_matches / len(value_tokens)
        # Return the source with the highest score; break ties by priority
        if source_scores:
            max_score = max(source_scores.values())
            sources_found = [s for s, score in source_scores.items() if score == max_score]
            # Prioritize: title > short_desc > long_desc > ocr_text
            priority = ['title', 'short_desc', 'long_desc', 'ocr_text']
            for p in priority:
                if p in sources_found:
                    return p
            return sources_found[0] if sources_found else "Not found"
        return "Not found"
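
    # Tie-breaking sketch: if "Blue" matches both the title and the OCR text with
    # the same score, the priority list above makes this return "title":
    #     find_value_source("Blue", {"title": "Blue Shirt", "ocr_text": "BLUE"})  # -> "title"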

    @staticmethod
    def extract_attributes_from_ocr(ocr_results: Dict, model: str = None) -> Dict:
        """Extract structured attributes from OCR text using the LLM."""
        if model is None:
            model = settings.SUPPORTED_MODELS[0]
        detected_text = ocr_results.get('detected_text', [])
        if not detected_text:
            return {}
        # Format OCR text for the prompt
        ocr_text = "\n".join([f"Text: {item['text']}, Confidence: {item['confidence']:.2f}"
                              for item in detected_text])
        prompt = f"""
You are an AI model that extracts structured attributes from OCR text detected on product images.
Given the OCR detections below, infer the possible product attributes and return them as a clean JSON object.
OCR Text:
{ocr_text}
Extract relevant attributes like:
- brand
- model_number
- size (waist_size, length, etc.)
- collection
- any other relevant product information
Return a JSON object with only the attributes you can confidently identify.
If an attribute is not present, do not include it in the response.
"""
        payload = {
            "model": model,
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful AI that extracts structured data from OCR output. Return only valid JSON."
                },
                {"role": "user", "content": prompt}
            ],
            "temperature": 0.2,
            "max_tokens": 500
        }
        headers = {
            "Authorization": f"Bearer {settings.GROQ_API_KEY}",
            "Content-Type": "application/json",
        }
        try:
            response = requests.post(
                settings.GROQ_API_URL,
                headers=headers,
                json=payload,
                timeout=30
            )
            response.raise_for_status()
            result_text = response.json()["choices"][0]["message"]["content"].strip()
            # Clean and parse the JSON response
            result_text = ProductAttributeService._clean_json_response(result_text)
            return json.loads(result_text)
        except Exception as e:
            return {"error": f"Failed to extract attributes from OCR: {str(e)}"}
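
    # Expected input shape for extract_attributes_from_ocr (a sketch; the keys
    # follow how this file reads OCR output, the values are made up):
    #     ocr_results = {"detected_text": [{"text": "LEVI'S 501", "confidence": 0.93}]}
    #     attrs = ProductAttributeService.extract_attributes_from_ocr(ocr_results)
    #     # -> e.g. {"brand": "Levi's", "model_number": "501"} (model output varies)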

    @staticmethod
    def calculate_attribute_relationships(
        mandatory_attrs: Dict[str, List[str]],
        product_text: str
    ) -> Dict[str, float]:
        """
        Calculate semantic relationships between attribute values across different attributes.
        Returns a flat map of cross-attribute value similarities keyed as "attr:val->attr:val".
        """
        pt_emb = model_embedder.encode(product_text, convert_to_tensor=True)
        # Calculate similarities between all attribute values and the product text.
        # Note: these per-value scores are computed but not returned; only the
        # cross-attribute relationships below are.
        attr_scores = {}
        for attr, values in mandatory_attrs.items():
            attr_scores[attr] = {}
            for val in values:
                contexts = [val, f"for {val}", f"use in {val}", f"suitable for {val}"]
                ctx_embs = [model_embedder.encode(c, convert_to_tensor=True) for c in contexts]
                sem_sim = max(float(util.cos_sim(pt_emb, ce).item()) for ce in ctx_embs)
                attr_scores[attr][val] = sem_sim
        # Calculate cross-attribute value relationships
        relationships = {}
        attr_list = list(mandatory_attrs.keys())
        for i, attr1 in enumerate(attr_list):
            for attr2 in attr_list[i + 1:]:
                # Pairwise similarities between values of different attributes
                for val1 in mandatory_attrs[attr1]:
                    for val2 in mandatory_attrs[attr2]:
                        emb1 = model_embedder.encode(val1, convert_to_tensor=True)
                        emb2 = model_embedder.encode(val2, convert_to_tensor=True)
                        sim = float(util.cos_sim(emb1, emb2).item())
                        # Store the relationship in both directions
                        key1 = f"{attr1}:{val1}->{attr2}:{val2}"
                        key2 = f"{attr2}:{val2}->{attr1}:{val1}"
                        relationships[key1] = sim
                        relationships[key2] = sim
        return relationships
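
    # Key-format sketch: for mandatory_attrs like
    #     {"room": ["Bedroom", "Kitchen"], "style": ["Modern"]}
    # the returned dict contains entries such as
    #     {"room:Bedroom->style:Modern": 0.31, "style:Modern->room:Bedroom": 0.31, ...}
    # (the similarity numbers here are made up for illustration).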

    @staticmethod
    def calculate_value_clusters(
        values: List[str],
        scores: List[Tuple[str, float]],
        cluster_threshold: float = 0.4
    ) -> List[List[str]]:
        """
        Group values into semantic clusters based on their similarity to each other.
        Returns clusters of related values.
        """
        if len(values) <= 1:
            return [[val] for val, _ in scores]
        # Get embeddings for all values
        embeddings = [model_embedder.encode(val, convert_to_tensor=True) for val in values]
        # Calculate pairwise similarities
        similarity_matrix = np.zeros((len(values), len(values)))
        for i in range(len(values)):
            for j in range(i + 1, len(values)):
                sim = float(util.cos_sim(embeddings[i], embeddings[j]).item())
                similarity_matrix[i][j] = sim
                similarity_matrix[j][i] = sim
        # Simple greedy clustering: group values with high mutual similarity.
        # Note: `scores` is sorted by score while `values` keeps the original order,
        # so map each value back to its index in `values` before using the matrix.
        clusters = []
        visited = set()
        for val, score in scores:
            i = values.index(val)
            if i in visited:
                continue
            cluster = [val]
            visited.add(i)
            # Pull in values similar to this one
            for j in range(len(values)):
                if j not in visited and similarity_matrix[i][j] >= cluster_threshold:
                    cluster.append(values[j])
                    visited.add(j)
            clusters.append(cluster)
        return clusters
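
    # Clustering sketch: with values ["Bedroom", "Kids Room", "Garage"] and a 0.4
    # threshold, "Bedroom" and "Kids Room" would typically land in one cluster and
    # "Garage" in its own (actual groupings depend on the embedding model).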

    @staticmethod
    def get_dynamic_threshold(
        attr: str,
        val: str,
        base_score: float,
        extracted_attrs: Dict[str, List[Dict[str, str]]],
        relationships: Dict[str, float],
        mandatory_attrs: Dict[str, List[str]],
        base_threshold: float = 0.65,
        boost_factor: float = 0.15
    ) -> float:
        """
        Calculate a dynamic threshold based on relationships with already-extracted attributes.
        """
        threshold = base_threshold
        # Check relationships with already-extracted attributes
        max_relationship = 0.0
        for other_attr, other_values_list in extracted_attrs.items():
            if other_attr == attr:
                continue
            for other_val_dict in other_values_list:
                other_val = other_val_dict['value']
                key = f"{attr}:{val}->{other_attr}:{other_val}"
                if key in relationships:
                    max_relationship = max(max_relationship, relationships[key])
        # If a strong relationship exists, lower the threshold
        if max_relationship > 0.6:
            threshold = base_threshold - (boost_factor * max_relationship)
        return max(0.3, threshold)
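
    # Worked example: with base_threshold=0.65, boost_factor=0.15 and a strongest
    # relationship of 0.8, the threshold becomes 0.65 - 0.15 * 0.8 = 0.53
    # (floored at 0.3).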

    @staticmethod
    def get_adaptive_margin(
        scores: List[Tuple[str, float]],
        base_margin: float = 0.15,
        max_margin: float = 0.22
    ) -> float:
        """
        Calculate an adaptive margin based on the score distribution.
        Assumes `scores` is sorted in descending order.
        """
        if len(scores) < 2:
            return base_margin
        score_values = [s for _, s in scores]
        best_score = score_values[0]
        # If the best score is very low, widen the margin, but conservatively
        if best_score < 0.5:
            # Look at the spread of the top 3-4 scores only (more selective)
            top_scores = score_values[:min(4, len(score_values))]
            score_range = max(top_scores) - min(top_scores)
            # Very controlled margin increase
            if score_range < 0.30:
                # Much more conservative scaling
                score_factor = (0.5 - best_score) * 0.35
                adaptive = base_margin + score_factor + (0.30 - score_range) * 0.2
                return min(adaptive, max_margin)
        return base_margin
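
    # Worked example: best_score=0.40 with a top-4 range of 0.10 gives
    # 0.15 + (0.5 - 0.40) * 0.35 + (0.30 - 0.10) * 0.2 = 0.225, capped at 0.22.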

    @staticmethod
    def _lexical_evidence(product_text: str, label: str) -> float:
        """Return the fraction of label tokens that appear in the product text."""
        pt = product_text.lower()
        tokens = [t for t in label.lower().replace("-", " ").split() if t]
        if not tokens:
            return 0.0
        hits = sum(1 for t in tokens if t in pt)
        return hits / len(tokens)
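
    # Example: _lexical_evidence("Blue cotton t-shirt", "Cotton Shirt") scores 1.0
    # ("cotton" and "shirt" both occur as substrings of the lowered text), while
    # "Wool Shirt" scores 0.5.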

    @staticmethod
    def normalize_against_product_text(
        product_text: str,
        mandatory_attrs: Dict[str, List[str]],
        source_map: Dict[str, str],
        threshold_abs: float = 0.65,
        margin: float = 0.15,
        allow_multiple: bool = False,
        sem_weight: float = 0.8,
        lex_weight: float = 0.2,
        extracted_attrs: Optional[Dict[str, List[Dict[str, str]]]] = None,
        relationships: Optional[Dict[str, float]] = None,
        use_dynamic_thresholds: bool = True,
        use_adaptive_margin: bool = True,
        use_semantic_clustering: bool = True
    ) -> dict:
        """
        Score each allowed value against the product_text with dynamic thresholds.
        Returns a dict with values in array format: [{"value": "...", "source": "..."}]
        """
        if extracted_attrs is None:
            extracted_attrs = {}
        if relationships is None:
            relationships = {}
        pt_emb = model_embedder.encode(product_text, convert_to_tensor=True)
        extracted = {}
        for attr, allowed_values in mandatory_attrs.items():
            scores: List[Tuple[str, float]] = []
            for val in allowed_values:
                contexts = [val, f"for {val}", f"use in {val}", f"suitable for {val}", f"{val} room"]
                ctx_embs = [model_embedder.encode(c, convert_to_tensor=True) for c in contexts]
                sem_sim = max(float(util.cos_sim(pt_emb, ce).item()) for ce in ctx_embs)
                lex_score = ProductAttributeService._lexical_evidence(product_text, val)
                final_score = sem_weight * sem_sim + lex_weight * lex_score
                scores.append((val, final_score))
            scores.sort(key=lambda x: x[1], reverse=True)
            best_val, best_score = scores[0]
            # DEBUG: Print scores
            print(f"\n{'=' * 80}")
            print(f"Attribute: {attr}")
            print(f"{'=' * 80}")
            print("Top 5 Scores:")
            for i, (val, sc) in enumerate(scores[:5]):
                print(f"  {i + 1}. {val}: {sc:.4f}")
            print(f"\nBest: {best_val} (score: {best_score:.4f})")
            print(f"Base Threshold: {threshold_abs}")
            print(f"Base Margin: {margin}")
            # Calculate the adaptive margin if enabled
            effective_margin = margin
            if allow_multiple and use_adaptive_margin:
                effective_margin = ProductAttributeService.get_adaptive_margin(scores, margin)
                print(f"Adaptive Margin: {effective_margin}")
            if not allow_multiple:
                source = ProductAttributeService.find_value_source(best_val, source_map)
                extracted[attr] = [{"value": best_val, "source": source}]
                print(f"Single value mode - Selected: {best_val}")
            else:
                print("\nMultiple value mode enabled")
                candidates = [best_val]
                use_base_threshold = best_score >= threshold_abs
                print(f"Use base threshold: {use_base_threshold} (best_score >= {threshold_abs})")
                # Get semantic clusters if enabled
                clusters = []
                if use_semantic_clustering:
                    clusters = ProductAttributeService.calculate_value_clusters(
                        allowed_values, scores, cluster_threshold=0.4
                    )
                    best_cluster = next((c for c in clusters if best_val in c), [best_val])
                    print("\nSemantic Clusters:")
                    for idx, cluster in enumerate(clusters):
                        marker = " <- BEST" if best_val in cluster else ""
                        print(f"  Cluster {idx + 1}: {cluster}{marker}")
                print("\nEvaluating additional candidates:")
                for val, sc in scores[1:]:
                    # Calculate the dynamic threshold for this value
                    if use_dynamic_thresholds and extracted_attrs:
                        dynamic_thresh = ProductAttributeService.get_dynamic_threshold(
                            attr, val, sc, extracted_attrs, relationships,
                            mandatory_attrs, threshold_abs
                        )
                    else:
                        dynamic_thresh = threshold_abs
                    within_margin = (best_score - sc) <= effective_margin
                    above_threshold = sc >= dynamic_thresh
                    # Check if in the same semantic cluster as the best value
                    in_cluster = False
                    if use_semantic_clustering and clusters:
                        in_cluster = any(best_val in c and val in c for c in clusters)
                    # DEBUG: Print candidate evaluation
                    print(f"\n  Candidate: {val}")
                    print(f"    Score: {sc:.4f}")
                    print(f"    Margin diff: {best_score - sc:.4f} (within_margin: {within_margin})")
                    print(f"    Dynamic threshold: {dynamic_thresh:.4f} (above_threshold: {above_threshold})")
                    print(f"    In cluster with best: {in_cluster}")
                    # Balanced logic for multi-value extraction
                    include_candidate = False
                    reason = ""
                    # Score ratio: how close this candidate is to the best score
                    score_ratio = sc / best_score if best_score > 0 else 0
                    if use_base_threshold:
                        # Best score is good (>= threshold), so be selective
                        if above_threshold and within_margin:
                            include_candidate = True
                            reason = "above threshold AND within margin"
                        elif in_cluster and within_margin and score_ratio >= 0.75:
                            # Only include cluster members if they are close in score
                            include_candidate = True
                            reason = "in cluster AND within margin with good score ratio"
                    else:
                        # Best score is low (< threshold), so be more careful:
                        # only include candidates very close to the best score
                        if within_margin and score_ratio >= 0.80:
                            # Must be at least 80% of the best score
                            include_candidate = True
                            reason = "within margin with strong score ratio"
                        elif in_cluster and within_margin and score_ratio >= 0.85:
                            # Cluster members need an even higher ratio when the best score is low
                            include_candidate = True
                            reason = "in cluster with tight margin and high score ratio"
                    # Additional filter: never include "Not Specified" when better options exist
                    if include_candidate and val.lower() in ["not specified", "not_specified", "unspecified"]:
                        # Keep "Not Specified" only if it is the sole candidate and nearly ties the best score
                        if len(candidates) > 1 or (sc < best_score * 0.95):
                            include_candidate = False
                            reason = "excluded: 'Not Specified' with better alternatives"
                    if include_candidate:
                        candidates.append(val)
                        print(f"    ✓ INCLUDED - Reason: {reason}")
                    else:
                        print("    ✗ EXCLUDED")
                # Map each candidate to its source and emit the array format
                extracted[attr] = []
                print(f"\nFinal candidates for {attr}: {candidates}")
                for candidate in candidates:
                    source = ProductAttributeService.find_value_source(candidate, source_map)
                    extracted[attr].append({"value": candidate, "source": source})
                    print(f"  - {candidate} (source: {source})")
            print(f"{'=' * 80}\n")
        return extracted

    @staticmethod
    def extract_attributes(
        product_text: str,
        mandatory_attrs: Dict[str, List[str]],
        source_map: Dict[str, str] = None,
        model: str = None,
        extract_additional: bool = True,
        multiple: Optional[List[str]] = None,
        threshold_abs: float = 0.65,
        margin: float = 0.15,
        use_dynamic_thresholds: bool = True,
        use_adaptive_margin: bool = True,
        use_semantic_clustering: bool = True
    ) -> dict:
        """
        Use the Groq LLM to extract attributes from any product type, with enhanced
        multi-value selection. Returns values in array format:
        [{"value": "...", "source": "..."}]
        """
        if model is None:
            model = settings.SUPPORTED_MODELS[0]
        if multiple is None:
            multiple = []
        if source_map is None:
            source_map = {}
        # DEBUG: Print what we received
        print("\n" + "=" * 80)
        print("EXTRACT ATTRIBUTES - INPUT PARAMETERS")
        print("=" * 80)
        print(f"Product text length: {len(product_text)}")
        print(f"Mandatory attrs: {list(mandatory_attrs.keys())}")
        print(f"Multiple mode for: {multiple}")
        print(f"Threshold: {threshold_abs}, Margin: {margin}")
        print(f"Dynamic thresholds: {use_dynamic_thresholds}")
        print(f"Adaptive margin: {use_adaptive_margin}")
        print(f"Semantic clustering: {use_semantic_clustering}")
        print("=" * 80 + "\n")
        # Check if the product text is empty or minimal
        if not product_text or product_text == "No product information available":
            return ProductAttributeService._create_error_response(
                "No product information provided",
                mandatory_attrs,
                extract_additional
            )
        # Create a structured prompt for the mandatory attributes
        mandatory_attr_list = []
        for attr_name, allowed_values in mandatory_attrs.items():
            mandatory_attr_list.append(f"{attr_name}: {', '.join(allowed_values)}")
        mandatory_attr_text = "\n".join(mandatory_attr_list)
        additional_instruction = ""
        if extract_additional:
            additional_instruction = """
2. Extract ADDITIONAL attributes: Identify any other relevant attributes from the product text
that are NOT in the mandatory list. Only include attributes where you can find actual values
in the product text. Do NOT include attributes with "Not Specified" or empty values.
Examples of attributes to look for (only if present): Brand, Material, Size, Color, Dimensions,
Weight, Features, Style, Theme, Pattern, Finish, Care Instructions, etc."""
        output_format = {
            "mandatory": {attr: "value or list of values" for attr in mandatory_attrs.keys()},
        }
        if extract_additional:
            output_format["additional"] = {
                "example_attribute_1": "actual value found",
                "example_attribute_2": "actual value found"
            }
            output_format["additional"]["_note"] = "Only include attributes with actual values found in text"
        prompt = f"""
You are an intelligent product attribute extractor that works with ANY product type.
TASK:
1. Extract MANDATORY attributes: For each mandatory attribute, select the most appropriate value(s)
from the provided list. Choose the value(s) that best match the product description.
{additional_instruction}
Product Text:
{product_text}
Mandatory Attribute Lists (MUST select from these allowed values):
{mandatory_attr_text}
CRITICAL INSTRUCTIONS:
- Return ONLY valid JSON, nothing else
- No explanations, no markdown, no text before or after the JSON
- For mandatory attributes, choose the value(s) from the provided list that best match
- If a mandatory attribute cannot be determined from the product text, use "Not Specified"
- Prefer exact matches from the allowed values list over generic synonyms
- If multiple values are plausible, you MAY return more than one
{"- For additional attributes: ONLY include attributes where you found actual values in the product text. DO NOT include attributes with 'Not Specified', 'None', 'N/A', or empty values. If you cannot find a value for an attribute, simply don't include that attribute." if extract_additional else ""}
- Be precise and only extract information that is explicitly stated or clearly implied
Required Output Format:
{json.dumps(output_format, indent=2)}
"""
        payload = {
            "model": model,
            "messages": [
                {
                    "role": "system",
                    "content": f"You are a precise attribute extraction model. Return ONLY valid JSON with {'mandatory and additional' if extract_additional else 'mandatory'} sections. No explanations, no markdown, no other text."
                },
                {"role": "user", "content": prompt}
            ],
            "temperature": 0.0,
            "max_tokens": 1500
        }
        headers = {
            "Authorization": f"Bearer {settings.GROQ_API_KEY}",
            "Content-Type": "application/json",
        }
        try:
            response = requests.post(
                settings.GROQ_API_URL,
                headers=headers,
                json=payload,
                timeout=30
            )
            response.raise_for_status()
            result_text = response.json()["choices"][0]["message"]["content"].strip()
            # Clean the response
            result_text = ProductAttributeService._clean_json_response(result_text)
            # Parse JSON
            parsed = json.loads(result_text)
            # Validate and restructure if needed
            parsed = ProductAttributeService._validate_response_structure(
                parsed, mandatory_attrs, extract_additional
            )
            # Clean up and add source tracking to additional attributes (array format)
            if extract_additional and "additional" in parsed:
                cleaned_additional = {}
                for k, v in parsed["additional"].items():
                    # Drop empty or placeholder values (case-insensitive for strings)
                    if not v:
                        continue
                    if isinstance(v, str) and v.lower() in ["not specified", "none", "n/a", ""]:
                        continue
                    source = ProductAttributeService.find_value_source(str(v), source_map)
                    cleaned_additional[k] = [{"value": str(v), "source": source}]
                parsed["additional"] = cleaned_additional
            # Calculate attribute relationships if using dynamic thresholds
            relationships = {}
            if use_dynamic_thresholds:
                relationships = ProductAttributeService.calculate_attribute_relationships(
                    mandatory_attrs, product_text
                )
            # Process attributes in order, allowing earlier ones to influence later ones
            extracted_so_far = {}
            for attr in mandatory_attrs.keys():
                allow_multiple = attr in multiple
                # DEBUG: Print per-attribute processing
                print(f"\n>>> Processing attribute: {attr}")
                print(f"    Allow multiple: {allow_multiple}")
                print(f"    Multiple list: {multiple}")
                result = ProductAttributeService.normalize_against_product_text(
                    product_text=product_text,
                    mandatory_attrs={attr: mandatory_attrs[attr]},
                    source_map=source_map,
                    threshold_abs=threshold_abs,
                    margin=margin,
                    allow_multiple=allow_multiple,
                    extracted_attrs=extracted_so_far,
                    relationships=relationships,
                    use_dynamic_thresholds=use_dynamic_thresholds,
                    use_adaptive_margin=use_adaptive_margin,
                    use_semantic_clustering=use_semantic_clustering
                )
                parsed["mandatory"][attr] = result[attr]
                extracted_so_far[attr] = result[attr]
            return parsed
        except requests.exceptions.RequestException as e:
            return ProductAttributeService._create_error_response(
                str(e), mandatory_attrs, extract_additional
            )
        except json.JSONDecodeError as e:
            return ProductAttributeService._create_error_response(
                f"Invalid JSON: {str(e)}", mandatory_attrs, extract_additional, result_text
            )
        except Exception as e:
            return ProductAttributeService._create_error_response(
                str(e), mandatory_attrs, extract_additional
            )

    @staticmethod
    def extract_attributes_batch(
        products: List[Dict],
        mandatory_attrs: Dict[str, List[str]],
        model: str = None,
        extract_additional: bool = True,
        process_image: bool = True,
        max_workers: int = 5,
        multiple: Optional[List[str]] = None,
        threshold_abs: float = 0.65,
        margin: float = 0.15,
        use_dynamic_thresholds: bool = True,
        use_adaptive_margin: bool = True,
        use_semantic_clustering: bool = True
    ) -> Dict:
        """Extract attributes for multiple products in parallel, with enhanced multi-value selection and source tracking."""
        results = []
        successful = 0
        failed = 0
        ocr_service = OCRService()
        if multiple is None:
            multiple = []

        def process_product(product_data):
            """Process a single product."""
            product_id = product_data.get('product_id', f"product_{len(results)}")
            try:
                # Process the image if a URL is provided
                ocr_results = None
                ocr_text = None
                if process_image and product_data.get('image_url'):
                    ocr_results = ocr_service.process_image(product_data['image_url'])
                    # Extract attributes from OCR
                    if ocr_results and ocr_results.get('detected_text'):
                        ocr_attrs = ProductAttributeService.extract_attributes_from_ocr(
                            ocr_results, model
                        )
                        ocr_results['extracted_attributes'] = ocr_attrs
                        # Format OCR text for combining with the product text
                        ocr_text = "\n".join([
                            f"{item['text']} (confidence: {item['confidence']:.2f})"
                            for item in ocr_results['detected_text']
                        ])
                # Combine all product information with source tracking
                product_text, source_map = ProductAttributeService.combine_product_text(
                    title=product_data.get('title'),
                    short_desc=product_data.get('short_desc'),
                    long_desc=product_data.get('long_desc'),
                    ocr_text=ocr_text
                )
                # Extract attributes from the combined text with enhanced features
                result = ProductAttributeService.extract_attributes(
                    product_text=product_text,
                    mandatory_attrs=mandatory_attrs,
                    source_map=source_map,
                    model=model,
                    extract_additional=extract_additional,
                    multiple=multiple,
                    threshold_abs=threshold_abs,
                    margin=margin,
                    use_dynamic_thresholds=use_dynamic_thresholds,
                    use_adaptive_margin=use_adaptive_margin,
                    use_semantic_clustering=use_semantic_clustering
                )
                result['product_id'] = product_id
                # Attach OCR results if available
                if ocr_results:
                    result['ocr_results'] = ocr_results
                # Report whether extraction succeeded
                return result, 'error' not in result
            except Exception as e:
                return {
                    'product_id': product_id,
                    'mandatory': {attr: [{"value": "Not Specified", "source": "error"}] for attr in mandatory_attrs.keys()},
                    'additional': {} if extract_additional else None,
                    'error': f"Processing error: {str(e)}"
                }, False
        # Process products in parallel
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_product = {
                executor.submit(process_product, product): product
                for product in products
            }
            for future in as_completed(future_to_product):
                try:
                    result, success = future.result()
                    results.append(result)
                    if success:
                        successful += 1
                    else:
                        failed += 1
                except Exception as e:
                    failed += 1
                    # Recover the product_id from the submitted product rather than
                    # reporting a generic 'unknown'
                    product = future_to_product[future]
                    results.append({
                        'product_id': product.get('product_id', 'unknown'),
                        'mandatory': {attr: [{"value": "Not Specified", "source": "error"}] for attr in mandatory_attrs.keys()},
                        'additional': {} if extract_additional else None,
                        'error': f"Unexpected error: {str(e)}"
                    })
        return {
            'results': results,
            'total_products': len(products),
            'successful': successful,
            'failed': failed
        }
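
    # Usage sketch for extract_attributes_batch (illustrative; the product fields
    # and allowed values below are made up):
    #     summary = ProductAttributeService.extract_attributes_batch(
    #         products=[{"product_id": "p1", "title": "Blue cotton t-shirt",
    #                    "image_url": "https://example.com/p1.jpg"}],
    #         mandatory_attrs={"color": ["Blue", "Red", "Not Specified"]},
    #         multiple=["color"],
    #     )
    #     # -> e.g. {"results": [...], "total_products": 1, "successful": 1, "failed": 0}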

    @staticmethod
    def _clean_json_response(text: str) -> str:
        """Clean LLM response to extract valid JSON."""
        # Strip markdown code fences first, so the brace-slicing below sees raw JSON
        if "```json" in text:
            text = text.split("```json")[1].split("```")[0].strip()
        elif "```" in text:
            text = text.split("```")[1].split("```")[0].strip()
        if text.startswith("json"):
            text = text[4:].strip()
        # Keep only the outermost JSON object
        start_idx = text.find('{')
        end_idx = text.rfind('}')
        if start_idx != -1 and end_idx != -1:
            text = text[start_idx:end_idx + 1]
        return text
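
    # Example: _clean_json_response('```json\n{"brand": "Acme"}\n```') returns
    # '{"brand": "Acme"}', which json.loads can parse directly.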

    @staticmethod
    def _validate_response_structure(
        parsed: dict,
        mandatory_attrs: Dict[str, List[str]],
        extract_additional: bool
    ) -> dict:
        """Validate and fix the response structure."""
        expected_sections = ["mandatory"]
        if extract_additional:
            expected_sections.append("additional")
        if not all(section in parsed for section in expected_sections):
            if isinstance(parsed, dict):
                mandatory_keys = set(mandatory_attrs.keys())
                mandatory = {k: v for k, v in parsed.items() if k in mandatory_keys}
                additional = {k: v for k, v in parsed.items() if k not in mandatory_keys}
                result = {"mandatory": mandatory}
                if extract_additional:
                    result["additional"] = additional
                return result
            else:
                return ProductAttributeService._create_error_response(
                    "Invalid response structure",
                    mandatory_attrs,
                    extract_additional,
                    str(parsed)
                )
        return parsed

    @staticmethod
    def _create_error_response(
        error: str,
        mandatory_attrs: Dict[str, List[str]],
        extract_additional: bool,
        raw_output: Optional[str] = None
    ) -> dict:
        """Create a standardized error response in array format."""
        response = {
            "mandatory": {attr: [{"value": "Not Specified", "source": "error"}] for attr in mandatory_attrs.keys()},
            "error": error
        }
        if extract_additional:
            response["additional"] = {}
        if raw_output:
            response["raw_output"] = raw_output
        return response