# ==================== services.py ====================
import requests
import json
from typing import Dict, List, Optional, Tuple
from django.conf import settings
from concurrent.futures import ThreadPoolExecutor, as_completed
from sentence_transformers import SentenceTransformer, util
import numpy as np
from .ocr_service import OCRService

# Initialize embedding model for normalization
model_embedder = SentenceTransformer("all-MiniLM-L6-v2")


class ProductAttributeService:
    """Service class for extracting product attributes using Groq LLM."""

    @staticmethod
    def combine_product_text(
        title: Optional[str] = None,
        short_desc: Optional[str] = None,
        long_desc: Optional[str] = None,
        ocr_text: Optional[str] = None
    ) -> Tuple[str, Dict[str, str]]:
        """
        Combine product metadata into a single text block.

        Returns: (combined_text, source_map), where source_map tracks which text came from where.
        """
        parts = []
        source_map = {}
        if title:
            title_str = str(title).strip()
            parts.append(f"Title: {title_str}")
            source_map['title'] = title_str
        if short_desc:
            short_str = str(short_desc).strip()
            parts.append(f"Description: {short_str}")
            source_map['short_desc'] = short_str
        if long_desc:
            long_str = str(long_desc).strip()
            parts.append(f"Details: {long_str}")
            source_map['long_desc'] = long_str
        if ocr_text:
            parts.append(f"OCR Text: {ocr_text}")
            source_map['ocr_text'] = ocr_text
        combined = "\n".join(parts).strip()
        if not combined:
            return "No product information available", {}
        return combined, source_map
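
    # Example (hypothetical inputs): combine_product_text(title="Acme Desk Lamp",
    # short_desc="LED lamp for home offices") returns
    # ("Title: Acme Desk Lamp\nDescription: LED lamp for home offices",
    #  {"title": "Acme Desk Lamp", "short_desc": "LED lamp for home offices"}).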

    @staticmethod
    def find_value_source(value: str, source_map: Dict[str, str]) -> str:
        """
        Find which source(s) contain the given value.

        Returns the name of the source where the value appears, resolving ties
        by the priority title > short_desc > long_desc > ocr_text.
        """
        value_lower = value.lower()
        # Split the value into tokens for better partial matching
        value_tokens = set(value_lower.replace("-", " ").split())
        if not value_tokens:
            return "Not found"
        source_scores = {}
        for source_name, source_text in source_map.items():
            source_lower = source_text.lower()
            # Check for an exact phrase match first
            if value_lower in source_lower:
                source_scores[source_name] = 1.0
                continue
            # Otherwise score by the fraction of tokens that match
            token_matches = sum(1 for token in value_tokens if token in source_lower)
            if token_matches > 0:
                source_scores[source_name] = token_matches / len(value_tokens)
        # Return the source with the highest score; break ties by priority
        if source_scores:
            max_score = max(source_scores.values())
            sources_found = [s for s, score in source_scores.items() if score == max_score]
            # Prioritize: title > short_desc > long_desc > ocr_text
            priority = ['title', 'short_desc', 'long_desc', 'ocr_text']
            for p in priority:
                if p in sources_found:
                    return p
            return sources_found[0]
        return "Not found"
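
    # Example (hypothetical): find_value_source("denim", {"title": "Denim jacket",
    # "ocr_text": "DENIM CO"}) scores both sources 1.0 and returns "title",
    # because title outranks ocr_text in the priority order.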

    @staticmethod
    def extract_attributes_from_ocr(ocr_results: Dict, model: Optional[str] = None) -> Dict:
        """Extract structured attributes from OCR text using the LLM."""
        if model is None:
            model = settings.SUPPORTED_MODELS[0]
        detected_text = ocr_results.get('detected_text', [])
        if not detected_text:
            return {}
        # Format OCR text for the prompt
        ocr_text = "\n".join([f"Text: {item['text']}, Confidence: {item['confidence']:.2f}"
                              for item in detected_text])
        prompt = f"""
You are an AI model that extracts structured attributes from OCR text detected on product images.
Given the OCR detections below, infer the possible product attributes and return them as a clean JSON object.

OCR Text:
{ocr_text}

Extract relevant attributes like:
- brand
- model_number
- size (waist_size, length, etc.)
- collection
- any other relevant product information

Return a JSON object with only the attributes you can confidently identify.
If an attribute is not present, do not include it in the response.
"""
        payload = {
            "model": model,
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful AI that extracts structured data from OCR output. Return only valid JSON."
                },
                {"role": "user", "content": prompt}
            ],
            "temperature": 0.2,
            "max_tokens": 500
        }
        headers = {
            "Authorization": f"Bearer {settings.GROQ_API_KEY}",
            "Content-Type": "application/json",
        }
        try:
            response = requests.post(
                settings.GROQ_API_URL,
                headers=headers,
                json=payload,
                timeout=30
            )
            response.raise_for_status()
            result_text = response.json()["choices"][0]["message"]["content"].strip()
            # Clean and parse the JSON response
            result_text = ProductAttributeService._clean_json_response(result_text)
            return json.loads(result_text)
        except Exception as e:
            return {"error": f"Failed to extract attributes from OCR: {str(e)}"}

    @staticmethod
    def calculate_attribute_relationships(
        mandatory_attrs: Dict[str, List[str]],
        product_text: str
    ) -> Dict[str, float]:
        """
        Calculate semantic relationships between attribute values across different attributes.

        Returns a dict mapping "attr1:val1->attr2:val2" keys to cosine similarities.
        NOTE: product_text is accepted for API compatibility; per-value scoring against
        the product text happens in normalize_against_product_text.
        """
        relationships = {}
        attr_list = list(mandatory_attrs.keys())
        for i, attr1 in enumerate(attr_list):
            for attr2 in attr_list[i + 1:]:
                # Pairwise similarities between values of different attributes
                for val1 in mandatory_attrs[attr1]:
                    for val2 in mandatory_attrs[attr2]:
                        emb1 = model_embedder.encode(val1, convert_to_tensor=True)
                        emb2 = model_embedder.encode(val2, convert_to_tensor=True)
                        sim = float(util.cos_sim(emb1, emb2).item())
                        # Store the relationship in both directions
                        key1 = f"{attr1}:{val1}->{attr2}:{val2}"
                        key2 = f"{attr2}:{val2}->{attr1}:{val1}"
                        relationships[key1] = sim
                        relationships[key2] = sim
        return relationships

    @staticmethod
    def calculate_value_clusters(
        values: List[str],
        scores: List[Tuple[str, float]],
        cluster_threshold: float = 0.4
    ) -> List[List[str]]:
        """
        Group values into semantic clusters based on their similarity to each other.

        `scores` is sorted best-first, so higher-scoring values seed clusters;
        the similarity matrix is indexed by position in `values`.
        """
        if len(values) <= 1:
            return [[val] for val, _ in scores]
        # Get embeddings for all values
        embeddings = [model_embedder.encode(val, convert_to_tensor=True) for val in values]
        # Calculate pairwise similarities
        similarity_matrix = np.zeros((len(values), len(values)))
        for i in range(len(values)):
            for j in range(i + 1, len(values)):
                sim = float(util.cos_sim(embeddings[i], embeddings[j]).item())
                similarity_matrix[i][j] = sim
                similarity_matrix[j][i] = sim
        # Greedy clustering: each unvisited value absorbs all sufficiently similar values
        clusters = []
        visited = set()
        for val, _ in scores:
            # `scores` is sorted, so map back to the value's index in `values`
            i = values.index(val)
            if i in visited:
                continue
            cluster = [val]
            visited.add(i)
            for j in range(len(values)):
                if j not in visited and similarity_matrix[i][j] >= cluster_threshold:
                    cluster.append(values[j])
                    visited.add(j)
            clusters.append(cluster)
        return clusters
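
    # Example (hypothetical numbers): with values ["Bedroom", "Living Room", "Office"]
    # and a pairwise similarity of 0.55 between "Bedroom" and "Living Room", a 0.4
    # threshold clusters them together while "Office" forms its own cluster.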

    @staticmethod
    def get_dynamic_threshold(
        attr: str,
        val: str,
        base_score: float,
        extracted_attrs: Dict[str, List[Dict[str, str]]],
        relationships: Dict[str, float],
        mandatory_attrs: Dict[str, List[str]],
        base_threshold: float = 0.65,
        boost_factor: float = 0.15
    ) -> float:
        """
        Calculate a dynamic threshold based on relationships with already-extracted attributes.
        """
        threshold = base_threshold
        # Find the strongest relationship with any already-extracted value
        max_relationship = 0.0
        for other_attr, other_values_list in extracted_attrs.items():
            if other_attr == attr:
                continue
            for other_val_dict in other_values_list:
                other_val = other_val_dict['value']
                key = f"{attr}:{val}->{other_attr}:{other_val}"
                if key in relationships:
                    max_relationship = max(max_relationship, relationships[key])
        # If a strong relationship exists, lower the threshold (floored at 0.3)
        if max_relationship > 0.6:
            threshold = base_threshold - (boost_factor * max_relationship)
        return max(0.3, threshold)
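
    # Worked example: with base_threshold 0.65, boost_factor 0.15, and a 0.8
    # relationship to an already-extracted value, the threshold drops to
    # 0.65 - 0.15 * 0.8 = 0.53, making related values easier to accept.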

    @staticmethod
    def get_adaptive_margin(
        scores: List[Tuple[str, float]],
        base_margin: float = 0.15,
        max_margin: float = 0.22
    ) -> float:
        """
        Calculate an adaptive margin based on the score distribution (scores sorted best-first).
        """
        if len(scores) < 2:
            return base_margin
        score_values = [s for _, s in scores]
        best_score = score_values[0]
        # If the best score is low, widen the margin, but conservatively
        if best_score < 0.5:
            # Consider the spread of the top 3-4 scores only (more selective)
            top_scores = score_values[:min(4, len(score_values))]
            score_range = max(top_scores) - min(top_scores)
            # Tightly controlled margin increase
            if score_range < 0.30:
                score_factor = (0.5 - best_score) * 0.35
                adaptive = base_margin + score_factor + (0.30 - score_range) * 0.2
                return min(adaptive, max_margin)
        return base_margin
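
    # Worked example: a best score of 0.40 with a top-4 spread of 0.10 gives
    # 0.15 + (0.5 - 0.40) * 0.35 + (0.30 - 0.10) * 0.2 = 0.225, capped at 0.22.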

    @staticmethod
    def _lexical_evidence(product_text: str, label: str) -> float:
        """Calculate lexical overlap between the product text and a label."""
        pt = product_text.lower()
        tokens = [t for t in label.lower().replace("-", " ").split() if t]
        if not tokens:
            return 0.0
        hits = sum(1 for t in tokens if t in pt)
        return hits / len(tokens)
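
    # Example: _lexical_evidence("solid oak dining table", "oak table") finds both
    # tokens and returns 1.0; "oak chair" matches one of two tokens and returns 0.5.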

    @staticmethod
    def normalize_against_product_text(
        product_text: str,
        mandatory_attrs: Dict[str, List[str]],
        source_map: Dict[str, str],
        threshold_abs: float = 0.65,
        margin: float = 0.15,
        allow_multiple: bool = False,
        sem_weight: float = 0.8,
        lex_weight: float = 0.2,
        extracted_attrs: Optional[Dict[str, List[Dict[str, str]]]] = None,
        relationships: Optional[Dict[str, float]] = None,
        use_dynamic_thresholds: bool = True,
        use_adaptive_margin: bool = True,
        use_semantic_clustering: bool = True
    ) -> dict:
        """
        Score each allowed value against the product text with dynamic thresholds.

        Returns a dict with values in array format: [{"value": "...", "source": "..."}].
        """
        if extracted_attrs is None:
            extracted_attrs = {}
        if relationships is None:
            relationships = {}
        pt_emb = model_embedder.encode(product_text, convert_to_tensor=True)
        extracted = {}
        for attr, allowed_values in mandatory_attrs.items():
            scores: List[Tuple[str, float]] = []
            for val in allowed_values:
                # Score the value in several phrasings and keep the best semantic match
                contexts = [val, f"for {val}", f"use in {val}", f"suitable for {val}", f"{val} room"]
                ctx_embs = [model_embedder.encode(c, convert_to_tensor=True) for c in contexts]
                sem_sim = max(float(util.cos_sim(pt_emb, ce).item()) for ce in ctx_embs)
                lex_score = ProductAttributeService._lexical_evidence(product_text, val)
                final_score = sem_weight * sem_sim + lex_weight * lex_score
                scores.append((val, final_score))
            scores.sort(key=lambda x: x[1], reverse=True)
            best_val, best_score = scores[0]
            # Calculate the adaptive margin if enabled
            effective_margin = margin
            if allow_multiple and use_adaptive_margin:
                effective_margin = ProductAttributeService.get_adaptive_margin(scores, margin)
            if not allow_multiple:
                source = ProductAttributeService.find_value_source(best_val, source_map)
                extracted[attr] = [{"value": best_val, "source": source}]
            else:
                candidates = [best_val]
                use_base_threshold = best_score >= threshold_abs
                # Get semantic clusters if enabled
                clusters = []
                if use_semantic_clustering:
                    clusters = ProductAttributeService.calculate_value_clusters(
                        allowed_values, scores, cluster_threshold=0.4
                    )
                for val, sc in scores[1:]:
                    # Calculate the dynamic threshold for this value
                    if use_dynamic_thresholds and extracted_attrs:
                        dynamic_thresh = ProductAttributeService.get_dynamic_threshold(
                            attr, val, sc, extracted_attrs, relationships,
                            mandatory_attrs, threshold_abs
                        )
                    else:
                        dynamic_thresh = threshold_abs
                    within_margin = (best_score - sc) <= effective_margin
                    above_threshold = sc >= dynamic_thresh
                    # Check whether this value shares a semantic cluster with the best value
                    in_cluster = False
                    if use_semantic_clustering and clusters:
                        in_cluster = any(best_val in c and val in c for c in clusters)
                    if use_base_threshold:
                        # Best score is strong: require (threshold AND margin) or (cluster AND margin)
                        if above_threshold and within_margin:
                            candidates.append(val)
                        elif in_cluster and within_margin:
                            candidates.append(val)
                    else:
                        # Best score is weak: fall back to margin or cluster logic
                        if within_margin:
                            candidates.append(val)
                        elif in_cluster and (best_score - sc) <= effective_margin * 2.0:
                            # Extended margin for cluster members
                            candidates.append(val)
                # Map each candidate to its source and emit the array format
                extracted[attr] = []
                for candidate in candidates:
                    source = ProductAttributeService.find_value_source(candidate, source_map)
                    extracted[attr].append({"value": candidate, "source": source})
        return extracted
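
    # Worked example of the combined score: with sem_weight 0.8 and lex_weight 0.2,
    # a value with semantic similarity 0.70 and lexical overlap 0.5 scores
    # 0.8 * 0.70 + 0.2 * 0.5 = 0.66, which clears the default 0.65 threshold.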

    @staticmethod
    def extract_attributes(
        product_text: str,
        mandatory_attrs: Dict[str, List[str]],
        source_map: Optional[Dict[str, str]] = None,
        model: Optional[str] = None,
        extract_additional: bool = True,
        multiple: Optional[List[str]] = None,
        threshold_abs: float = 0.65,
        margin: float = 0.15,
        use_dynamic_thresholds: bool = True,
        use_adaptive_margin: bool = True,
        use_semantic_clustering: bool = True
    ) -> dict:
        """
        Use the Groq LLM to extract attributes from any product type with enhanced multi-value selection.

        Returns values in array format: [{"value": "...", "source": "..."}].
        """
        if model is None:
            model = settings.SUPPORTED_MODELS[0]
        if multiple is None:
            multiple = []
        if source_map is None:
            source_map = {}
        # Check whether the product text is empty or minimal
        if not product_text or product_text == "No product information available":
            return ProductAttributeService._create_error_response(
                "No product information provided",
                mandatory_attrs,
                extract_additional
            )
        # Create a structured prompt for the mandatory attributes
        mandatory_attr_list = []
        for attr_name, allowed_values in mandatory_attrs.items():
            mandatory_attr_list.append(f"{attr_name}: {', '.join(allowed_values)}")
        mandatory_attr_text = "\n".join(mandatory_attr_list)
        additional_instruction = ""
        if extract_additional:
            additional_instruction = """
2. Extract ADDITIONAL attributes: Identify any other relevant attributes from the product text
that are NOT in the mandatory list. Only include attributes where you can find actual values
in the product text. Do NOT include attributes with "Not Specified" or empty values.
Examples of attributes to look for (only if present): Brand, Material, Size, Color, Dimensions,
Weight, Features, Style, Theme, Pattern, Finish, Care Instructions, etc."""
        output_format = {
            "mandatory": {attr: "value or list of values" for attr in mandatory_attrs.keys()},
        }
        if extract_additional:
            output_format["additional"] = {
                "example_attribute_1": "actual value found",
                "example_attribute_2": "actual value found"
            }
            output_format["additional"]["_note"] = "Only include attributes with actual values found in text"
        prompt = f"""
You are an intelligent product attribute extractor that works with ANY product type.

TASK:
1. Extract MANDATORY attributes: For each mandatory attribute, select the most appropriate value(s)
from the provided list. Choose the value(s) that best match the product description.
{additional_instruction}

Product Text:
{product_text}

Mandatory Attribute Lists (MUST select from these allowed values):
{mandatory_attr_text}

CRITICAL INSTRUCTIONS:
- Return ONLY valid JSON, nothing else
- No explanations, no markdown, no text before or after the JSON
- For mandatory attributes, choose the value(s) from the provided list that best match
- If a mandatory attribute cannot be determined from the product text, use "Not Specified"
- Prefer exact matches from the allowed values list over generic synonyms
- If multiple values are plausible, you MAY return more than one
{f"- For additional attributes: ONLY include attributes where you found actual values in the product text. DO NOT include attributes with 'Not Specified', 'None', 'N/A', or empty values. If you cannot find a value for an attribute, simply don't include that attribute." if extract_additional else ""}
- Be precise and only extract information that is explicitly stated or clearly implied

Required Output Format:
{json.dumps(output_format, indent=2)}
"""
        payload = {
            "model": model,
            "messages": [
                {
                    "role": "system",
                    "content": f"You are a precise attribute extraction model. Return ONLY valid JSON with {'mandatory and additional' if extract_additional else 'mandatory'} sections. No explanations, no markdown, no other text."
                },
                {"role": "user", "content": prompt}
            ],
            "temperature": 0.0,
            "max_tokens": 1500
        }
        headers = {
            "Authorization": f"Bearer {settings.GROQ_API_KEY}",
            "Content-Type": "application/json",
        }
        result_text = ""
        try:
            response = requests.post(
                settings.GROQ_API_URL,
                headers=headers,
                json=payload,
                timeout=30
            )
            response.raise_for_status()
            result_text = response.json()["choices"][0]["message"]["content"].strip()
            # Clean and parse the response
            result_text = ProductAttributeService._clean_json_response(result_text)
            parsed = json.loads(result_text)
            # Validate and restructure if needed
            parsed = ProductAttributeService._validate_response_structure(
                parsed, mandatory_attrs, extract_additional
            )
            # Clean up additional attributes and attach source tracking in array format
            if extract_additional and "additional" in parsed:
                cleaned_additional = {}
                for k, v in parsed["additional"].items():
                    # Skip empty values and "Not Specified"/"None"/"N/A" placeholders (any casing)
                    if v and not (isinstance(v, str) and v.lower() in ["not specified", "none", "n/a", ""]):
                        source = ProductAttributeService.find_value_source(str(v), source_map)
                        cleaned_additional[k] = [{"value": str(v), "source": source}]
                parsed["additional"] = cleaned_additional
            # Calculate attribute relationships if using dynamic thresholds
            relationships = {}
            if use_dynamic_thresholds:
                relationships = ProductAttributeService.calculate_attribute_relationships(
                    mandatory_attrs, product_text
                )
            # Process attributes in order, letting earlier results influence later ones
            extracted_so_far = {}
            for attr in mandatory_attrs.keys():
                allow_multiple = attr in multiple
                result = ProductAttributeService.normalize_against_product_text(
                    product_text=product_text,
                    mandatory_attrs={attr: mandatory_attrs[attr]},
                    source_map=source_map,
                    threshold_abs=threshold_abs,
                    margin=margin,
                    allow_multiple=allow_multiple,
                    extracted_attrs=extracted_so_far,
                    relationships=relationships,
                    use_dynamic_thresholds=use_dynamic_thresholds,
                    use_adaptive_margin=use_adaptive_margin,
                    use_semantic_clustering=use_semantic_clustering
                )
                parsed["mandatory"][attr] = result[attr]
                extracted_so_far[attr] = result[attr]
            return parsed
        except requests.exceptions.RequestException as e:
            return ProductAttributeService._create_error_response(
                str(e), mandatory_attrs, extract_additional
            )
        except json.JSONDecodeError as e:
            return ProductAttributeService._create_error_response(
                f"Invalid JSON: {str(e)}", mandatory_attrs, extract_additional, result_text
            )
        except Exception as e:
            return ProductAttributeService._create_error_response(
                str(e), mandatory_attrs, extract_additional
            )

    @staticmethod
    def extract_attributes_batch(
        products: List[Dict],
        mandatory_attrs: Dict[str, List[str]],
        model: Optional[str] = None,
        extract_additional: bool = True,
        process_image: bool = True,
        max_workers: int = 5,
        multiple: Optional[List[str]] = None,
        threshold_abs: float = 0.65,
        margin: float = 0.15,
        use_dynamic_thresholds: bool = True,
        use_adaptive_margin: bool = True,
        use_semantic_clustering: bool = True
    ) -> Dict:
        """Extract attributes for multiple products in parallel, with multi-value selection and source tracking."""
        results = []
        successful = 0
        failed = 0
        ocr_service = OCRService()
        if multiple is None:
            multiple = []

        def process_product(product_data):
            """Process a single product."""
            product_id = product_data.get('product_id', f"product_{len(results)}")
            try:
                # Run OCR if an image URL is provided
                ocr_results = None
                ocr_text = None
                if process_image and product_data.get('image_url'):
                    ocr_results = ocr_service.process_image(product_data['image_url'])
                    # Extract attributes from the OCR output
                    if ocr_results and ocr_results.get('detected_text'):
                        ocr_attrs = ProductAttributeService.extract_attributes_from_ocr(
                            ocr_results, model
                        )
                        ocr_results['extracted_attributes'] = ocr_attrs
                        # Format OCR text for combining with the product text
                        ocr_text = "\n".join([
                            f"{item['text']} (confidence: {item['confidence']:.2f})"
                            for item in ocr_results['detected_text']
                        ])
                # Combine all product information with source tracking
                product_text, source_map = ProductAttributeService.combine_product_text(
                    title=product_data.get('title'),
                    short_desc=product_data.get('short_desc'),
                    long_desc=product_data.get('long_desc'),
                    ocr_text=ocr_text
                )
                # Extract attributes from the combined text
                result = ProductAttributeService.extract_attributes(
                    product_text=product_text,
                    mandatory_attrs=mandatory_attrs,
                    source_map=source_map,
                    model=model,
                    extract_additional=extract_additional,
                    multiple=multiple,
                    threshold_abs=threshold_abs,
                    margin=margin,
                    use_dynamic_thresholds=use_dynamic_thresholds,
                    use_adaptive_margin=use_adaptive_margin,
                    use_semantic_clustering=use_semantic_clustering
                )
                result['product_id'] = product_id
                # Attach OCR results if available
                if ocr_results:
                    result['ocr_results'] = ocr_results
                # Report success based on the absence of an error key
                return result, 'error' not in result
            except Exception as e:
                return {
                    'product_id': product_id,
                    'mandatory': {attr: [{"value": "Not Specified", "source": "error"}] for attr in mandatory_attrs.keys()},
                    'additional': {} if extract_additional else None,
                    'error': f"Processing error: {str(e)}"
                }, False

        # Process products in parallel
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_product = {
                executor.submit(process_product, product): product
                for product in products
            }
            for future in as_completed(future_to_product):
                try:
                    result, success = future.result()
                    results.append(result)
                    if success:
                        successful += 1
                    else:
                        failed += 1
                except Exception as e:
                    failed += 1
                    results.append({
                        'product_id': 'unknown',
                        'mandatory': {attr: [{"value": "Not Specified", "source": "error"}] for attr in mandatory_attrs.keys()},
                        'additional': {} if extract_additional else None,
                        'error': f"Unexpected error: {str(e)}"
                    })
        return {
            'results': results,
            'total_products': len(products),
            'successful': successful,
            'failed': failed
        }

    @staticmethod
    def _clean_json_response(text: str) -> str:
        """Clean an LLM response to extract valid JSON."""
        # Strip markdown code fences first, so the brace search below sees raw JSON
        if "```json" in text:
            text = text.split("```json")[1].split("```")[0].strip()
        elif "```" in text:
            text = text.split("```")[1].split("```")[0].strip()
        if text.startswith("json"):
            text = text[4:].strip()
        # Keep only the outermost JSON object
        start_idx = text.find('{')
        end_idx = text.rfind('}')
        if start_idx != -1 and end_idx != -1:
            text = text[start_idx:end_idx + 1]
        return text
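
    # Example: '```json\n{"brand": "Acme"}\n```' is cleaned to '{"brand": "Acme"}'.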

    @staticmethod
    def _validate_response_structure(
        parsed: dict,
        mandatory_attrs: Dict[str, List[str]],
        extract_additional: bool
    ) -> dict:
        """Validate and fix the response structure."""
        expected_sections = ["mandatory"]
        if extract_additional:
            expected_sections.append("additional")
        if not all(section in parsed for section in expected_sections):
            if isinstance(parsed, dict):
                mandatory_keys = set(mandatory_attrs.keys())
                mandatory = {k: v for k, v in parsed.items() if k in mandatory_keys}
                additional = {k: v for k, v in parsed.items() if k not in mandatory_keys}
                result = {"mandatory": mandatory}
                if extract_additional:
                    result["additional"] = additional
                return result
            else:
                return ProductAttributeService._create_error_response(
                    "Invalid response structure",
                    mandatory_attrs,
                    extract_additional,
                    str(parsed)
                )
        return parsed

    @staticmethod
    def _create_error_response(
        error: str,
        mandatory_attrs: Dict[str, List[str]],
        extract_additional: bool,
        raw_output: Optional[str] = None
    ) -> dict:
        """Create a standardized error response in array format."""
        response = {
            "mandatory": {attr: [{"value": "Not Specified", "source": "error"}] for attr in mandatory_attrs.keys()},
            "error": error
        }
        if extract_additional:
            response["additional"] = {}
        if raw_output:
            response["raw_output"] = raw_output
        return response
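

# A minimal usage sketch (illustrative only; the product data and attribute list
# below are hypothetical). It exercises the Django-free path: combining metadata
# with source tracking, then normalizing a mandatory attribute against its allowed
# values. The LLM paths (extract_attributes / extract_attributes_batch) additionally
# need settings.GROQ_API_KEY, settings.GROQ_API_URL, and settings.SUPPORTED_MODELS.
# Because of the relative OCRService import, run this as a module, e.g.
# `python -m yourapp.services` (the package name is an assumption).
if __name__ == "__main__":
    text, sources = ProductAttributeService.combine_product_text(
        title="Oak Dining Table",
        short_desc="Solid oak table for dining rooms and kitchens",
    )
    normalized = ProductAttributeService.normalize_against_product_text(
        product_text=text,
        mandatory_attrs={"Room Type": ["Dining Room", "Bedroom", "Bathroom"]},
        source_map=sources,
    )
    # Expected shape: {"Room Type": [{"value": "...", "source": "title" or "short_desc"}]}
    print(json.dumps(normalized, indent=2))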