# ==================== services.py (PERFORMANCE OPTIMIZED) ====================
import os
import requests
import json
import re
import hashlib
import logging
from typing import Dict, List, Optional, Tuple
from django.conf import settings
from concurrent.futures import ThreadPoolExecutor, as_completed
from sentence_transformers import SentenceTransformer, util
import numpy as np

logger = logging.getLogger(__name__)

# Silence the tokenizers parallelism warning; the "Batches: 100%" progress spam
# is suppressed separately via show_progress_bar=False in encode() calls below.
# This must be set before the model is loaded to take effect.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# ⚡ CRITICAL FIX: Initialize embedding model ONCE at module level
print("Loading sentence transformer model (one-time initialization)...")
model_embedder = SentenceTransformer("all-MiniLM-L6-v2")
print("✓ Model loaded successfully")
# ==================== CACHING CLASSES ====================
class SimpleCache:
    """In-memory cache for attribute extraction results."""
    _cache = {}
    _max_size = 1000

    @classmethod
    def get(cls, key: str) -> Optional[Dict]:
        return cls._cache.get(key)

    @classmethod
    def set(cls, key: str, value: Dict):
        if len(cls._cache) >= cls._max_size:
            # Evict the oldest 20% of entries (dicts preserve insertion order).
            items = list(cls._cache.items())
            cls._cache = dict(items[int(cls._max_size * 0.2):])
        cls._cache[key] = value

    @classmethod
    def clear(cls):
        cls._cache.clear()

    @classmethod
    def get_stats(cls) -> Dict:
        return {
            "size": len(cls._cache),
            "max_size": cls._max_size,
            "usage_percent": round(len(cls._cache) / cls._max_size * 100, 2)
        }
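
# Illustrative usage (a sketch; the key/value shapes mirror how
# ProductAttributeService.extract_attributes() uses this cache below):
#   SimpleCache.set("attr_<md5>", {"mandatory": {...}, "additional": {...}})
#   cached = SimpleCache.get("attr_<md5>")   # -> dict or None
#   SimpleCache.get_stats()                  # -> {"size": ..., "max_size": ..., "usage_percent": ...}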
class EmbeddingCache:
    """Cache for sentence transformer embeddings."""
    _cache = {}
    _max_size = 500
    _hit_count = 0
    _miss_count = 0

    @classmethod
    def get_embedding(cls, text: str, model):
        """Get or compute embedding with caching."""
        if text in cls._cache:
            cls._hit_count += 1
            return cls._cache[text]
        cls._miss_count += 1
        if len(cls._cache) >= cls._max_size:
            items = list(cls._cache.items())
            cls._cache = dict(items[int(cls._max_size * 0.3):])
        # ⚡ CRITICAL: Disable verbose output
        import warnings
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            embedding = model.encode(text, convert_to_tensor=True, show_progress_bar=False)
        cls._cache[text] = embedding
        return embedding

    @classmethod
    def clear(cls):
        cls._cache.clear()
        cls._hit_count = 0
        cls._miss_count = 0

    @classmethod
    def get_stats(cls) -> Dict:
        total = cls._hit_count + cls._miss_count
        hit_rate = (cls._hit_count / total * 100) if total > 0 else 0
        return {
            "size": len(cls._cache),
            "hits": cls._hit_count,
            "misses": cls._miss_count,
            "hit_rate_percent": round(hit_rate, 2)
        }
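
# Illustrative usage (a sketch; the example text is an assumption, the call
# pattern is how the service methods below use this cache with `model_embedder`):
#   emb = EmbeddingCache.get_embedding("cotton throw pillow", model_embedder)
#   EmbeddingCache.get_stats()  # -> {"size": ..., "hits": ..., "misses": ..., "hit_rate_percent": ...}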
# ==================== MAIN SERVICE CLASS ====================
class ProductAttributeService:
    """Service class for extracting product attributes using Groq LLM."""

    @staticmethod
    def _generate_cache_key(product_text: str, mandatory_attrs: Dict) -> str:
        """Generate cache key from product text and attributes."""
        attrs_str = json.dumps(mandatory_attrs, sort_keys=True)
        content = f"{product_text}:{attrs_str}"
        return f"attr_{hashlib.md5(content.encode()).hexdigest()}"
    @staticmethod
    def normalize_dimension_text(text: str) -> str:
        """Normalize dimension text to format like '16x20'."""
        if not text:
            return ""
        text = text.lower()
        text = re.sub(r'\s*(inches|inch|in|cm|centimeters|mm|millimeters)\s*', '', text, flags=re.IGNORECASE)
        numbers = re.findall(r'\d+\.?\d*', text)
        if not numbers:
            return ""
        float_numbers = []
        for num in numbers:
            try:
                float_numbers.append(float(num))
            except ValueError:
                continue
        if len(float_numbers) < 2:
            return ""
        if len(float_numbers) == 3:
            # Three numbers: keep the first and third values.
            float_numbers = [float_numbers[0], float_numbers[2]]
        elif len(float_numbers) > 3:
            float_numbers = sorted(float_numbers)[-2:]
        else:
            float_numbers = float_numbers[:2]
        formatted_numbers = []
        for num in float_numbers:
            if num.is_integer():
                formatted_numbers.append(str(int(num)))
            else:
                formatted_numbers.append(f"{num:.1f}")
        formatted_numbers.sort(key=lambda x: float(x))
        return f"{formatted_numbers[0]}x{formatted_numbers[1]}"
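
    # Illustrative behavior (example inputs are assumptions, not from the source):
    #   normalize_dimension_text("16 x 20 inches") -> "16x20"
    #   normalize_dimension_text("20in x 16in")    -> "16x20"  (numbers sorted ascending)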
    @staticmethod
    def normalize_value_for_matching(value: str, attr_name: str = "") -> str:
        """Normalize a value based on its attribute type."""
        dimension_keywords = ['dimension', 'size', 'measurement']
        if any(keyword in attr_name.lower() for keyword in dimension_keywords):
            normalized = ProductAttributeService.normalize_dimension_text(value)
            if normalized:
                return normalized
        return value.strip()

    @staticmethod
    def combine_product_text(
        title: Optional[str] = None,
        short_desc: Optional[str] = None,
        long_desc: Optional[str] = None,
        ocr_text: Optional[str] = None
    ) -> Tuple[str, Dict[str, str]]:
        """Combine product metadata into a single text block."""
        parts = []
        source_map = {}
        if title:
            title_str = str(title).strip()
            parts.append(f"Title: {title_str}")
            source_map['title'] = title_str
        if short_desc:
            short_str = str(short_desc).strip()
            parts.append(f"Description: {short_str}")
            source_map['short_desc'] = short_str
        if long_desc:
            long_str = str(long_desc).strip()
            parts.append(f"Details: {long_str}")
            source_map['long_desc'] = long_str
        if ocr_text:
            parts.append(f"OCR Text: {ocr_text}")
            source_map['ocr_text'] = ocr_text
        combined = "\n".join(parts).strip()
        if not combined:
            return "No product information available", {}
        return combined, source_map
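
    # Example return shape (illustrative values, assumed for this sketch):
    #   combine_product_text(title="Canvas Wall Art", short_desc="16 x 20 inches")
    #   -> ("Title: Canvas Wall Art\nDescription: 16 x 20 inches",
    #       {"title": "Canvas Wall Art", "short_desc": "16 x 20 inches"})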
    @staticmethod
    def find_value_source(value: str, source_map: Dict[str, str], attr_name: str = "") -> str:
        """Find which source(s) contain the given value."""
        value_lower = value.lower()
        value_tokens = set(value_lower.replace("-", " ").replace("x", " ").split())
        is_dimension_attr = any(keyword in attr_name.lower() for keyword in ['dimension', 'size', 'measurement'])
        sources_found = []
        source_scores = {}
        for source_name, source_text in source_map.items():
            source_lower = source_text.lower()
            if value_lower in source_lower:
                source_scores[source_name] = 1.0
                continue
            if is_dimension_attr:
                normalized_value = ProductAttributeService.normalize_dimension_text(value)
                if not normalized_value:
                    normalized_value = value.replace("x", " ").strip()
                normalized_source = ProductAttributeService.normalize_dimension_text(source_text)
                if normalized_value == normalized_source:
                    source_scores[source_name] = 0.95
                    continue
                dim_parts = normalized_value.split("x") if "x" in normalized_value else []
                if len(dim_parts) == 2:
                    if all(part in source_text for part in dim_parts):
                        source_scores[source_name] = 0.85
                        continue
            token_matches = sum(1 for token in value_tokens if token and token in source_lower)
            if token_matches > 0 and len(value_tokens) > 0:
                source_scores[source_name] = token_matches / len(value_tokens)
        if source_scores:
            max_score = max(source_scores.values())
            sources_found = [s for s, score in source_scores.items() if score == max_score]
            priority = ['title', 'short_desc', 'long_desc', 'ocr_text']
            for p in priority:
                if p in sources_found:
                    return p
            return sources_found[0] if sources_found else "Not found"
        return "Not found"
    @staticmethod
    def format_visual_attributes(visual_attributes: Dict) -> Dict:
        """Convert visual attributes to array format with source tracking."""
        formatted = {}
        for key, value in visual_attributes.items():
            if isinstance(value, list):
                formatted[key] = [{"value": str(item), "source": "image"} for item in value]
            elif isinstance(value, dict):
                nested_formatted = {}
                for nested_key, nested_value in value.items():
                    if isinstance(nested_value, list):
                        nested_formatted[nested_key] = [{"value": str(item), "source": "image"} for item in nested_value]
                    else:
                        nested_formatted[nested_key] = [{"value": str(nested_value), "source": "image"}]
                formatted[key] = nested_formatted
            else:
                formatted[key] = [{"value": str(value), "source": "image"}]
        return formatted
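
    # Illustrative conversion (input values are assumptions for this sketch):
    #   format_visual_attributes({"color": ["red", "blue"], "shape": "round"})
    #   -> {"color": [{"value": "red", "source": "image"}, {"value": "blue", "source": "image"}],
    #       "shape": [{"value": "round", "source": "image"}]}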
    @staticmethod
    def extract_attributes_from_ocr(ocr_results: Dict, model: str = None) -> Dict:
        """Extract structured attributes from OCR text using LLM."""
        if model is None:
            model = settings.SUPPORTED_MODELS[0]
        detected_text = ocr_results.get('detected_text', [])
        if not detected_text:
            return {}
        ocr_text = "\n".join([f"Text: {item['text']}, Confidence: {item['confidence']:.2f}"
                              for item in detected_text])
        prompt = f"""
You are an AI model that extracts structured attributes from OCR text detected on product images.
Given the OCR detections below, infer the possible product attributes and return them as a clean JSON object.
OCR Text:
{ocr_text}
Extract relevant attributes like:
- brand
- model_number
- size (waist_size, length, etc.)
- collection
- any other relevant product information
Return a JSON object with only the attributes you can confidently identify.
If an attribute is not present, do not include it in the response.
"""
        payload = {
            "model": model,
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful AI that extracts structured data from OCR output. Return only valid JSON."
                },
                {"role": "user", "content": prompt}
            ],
            "temperature": 0.2,
            "max_tokens": 500
        }
        headers = {
            "Authorization": f"Bearer {settings.GROQ_API_KEY}",
            "Content-Type": "application/json",
        }
        try:
            response = requests.post(
                settings.GROQ_API_URL,
                headers=headers,
                json=payload,
                timeout=30
            )
            response.raise_for_status()
            result_text = response.json()["choices"][0]["message"]["content"].strip()
            result_text = ProductAttributeService._clean_json_response(result_text)
            parsed = json.loads(result_text)
            formatted_attributes = {}
            for key, value in parsed.items():
                if key == "error":
                    continue
                if isinstance(value, dict):
                    nested_formatted = {}
                    for nested_key, nested_value in value.items():
                        nested_formatted[nested_key] = [{"value": str(nested_value), "source": "image"}]
                    formatted_attributes[key] = nested_formatted
                elif isinstance(value, list):
                    formatted_attributes[key] = [{"value": str(item), "source": "image"} for item in value]
                else:
                    formatted_attributes[key] = [{"value": str(value), "source": "image"}]
            return formatted_attributes
        except Exception as e:
            logger.error(f"OCR attribute extraction failed: {str(e)}")
            return {"error": f"Failed to extract attributes from OCR: {str(e)}"}
    @staticmethod
    def calculate_attribute_relationships(
        mandatory_attrs: Dict[str, List[str]],
        product_text: str
    ) -> Dict[str, float]:
        """Calculate semantic relationships between attribute values."""
        pt_emb = EmbeddingCache.get_embedding(product_text, model_embedder)
        attr_scores = {}
        for attr, values in mandatory_attrs.items():
            attr_scores[attr] = {}
            for val in values:
                contexts = [val, f"for {val}", f"use in {val}", f"suitable for {val}"]
                ctx_embs = [EmbeddingCache.get_embedding(c, model_embedder) for c in contexts]
                sem_sim = max(float(util.cos_sim(pt_emb, ce).item()) for ce in ctx_embs)
                attr_scores[attr][val] = sem_sim
        relationships = {}
        attr_list = list(mandatory_attrs.keys())
        for i, attr1 in enumerate(attr_list):
            for attr2 in attr_list[i + 1:]:
                for val1 in mandatory_attrs[attr1]:
                    for val2 in mandatory_attrs[attr2]:
                        emb1 = EmbeddingCache.get_embedding(val1, model_embedder)
                        emb2 = EmbeddingCache.get_embedding(val2, model_embedder)
                        sim = float(util.cos_sim(emb1, emb2).item())
                        key1 = f"{attr1}:{val1}->{attr2}:{val2}"
                        key2 = f"{attr2}:{val2}->{attr1}:{val1}"
                        relationships[key1] = sim
                        relationships[key2] = sim
        return relationships
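
    # Key format produced above (attribute/value names here are hypothetical): a
    # symmetric pairwise similarity, e.g. relationships["room:kitchen->style:farmhouse"]
    # equals relationships["style:farmhouse->room:kitchen"].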
    @staticmethod
    def calculate_value_clusters(
        values: List[str],
        scores: List[Tuple[str, float]],
        cluster_threshold: float = 0.4
    ) -> List[List[str]]:
        """Group values into semantic clusters."""
        if len(values) <= 1:
            return [[val] for val, _ in scores]
        embeddings = [EmbeddingCache.get_embedding(val, model_embedder) for val in values]
        similarity_matrix = np.zeros((len(values), len(values)))
        for i in range(len(values)):
            for j in range(i + 1, len(values)):
                sim = float(util.cos_sim(embeddings[i], embeddings[j]).item())
                similarity_matrix[i][j] = sim
                similarity_matrix[j][i] = sim
        clusters = []
        visited = set()
        # Walk values in score order (highest first), but index the similarity
        # matrix by each value's position in `values`, since `scores` is sorted
        # and its order does not match `values`.
        for val, score in scores:
            i = values.index(val)
            if i in visited:
                continue
            cluster = [val]
            visited.add(i)
            for j in range(len(values)):
                if j not in visited and similarity_matrix[i][j] >= cluster_threshold:
                    cluster.append(values[j])
                    visited.add(j)
            clusters.append(cluster)
        return clusters
    @staticmethod
    def get_dynamic_threshold(
        attr: str,
        val: str,
        base_score: float,
        extracted_attrs: Dict[str, List[Dict[str, str]]],
        relationships: Dict[str, float],
        mandatory_attrs: Dict[str, List[str]],
        base_threshold: float = 0.65,
        boost_factor: float = 0.15
    ) -> float:
        """Calculate dynamic threshold based on relationships."""
        threshold = base_threshold
        max_relationship = 0.0
        for other_attr, other_values_list in extracted_attrs.items():
            if other_attr == attr:
                continue
            for other_val_dict in other_values_list:
                other_val = other_val_dict['value']
                key = f"{attr}:{val}->{other_attr}:{other_val}"
                if key in relationships:
                    max_relationship = max(max_relationship, relationships[key])
        if max_relationship > 0.6:
            threshold = base_threshold - (boost_factor * max_relationship)
        return max(0.3, threshold)
    @staticmethod
    def get_adaptive_margin(
        scores: List[Tuple[str, float]],
        base_margin: float = 0.15,
        max_margin: float = 0.22
    ) -> float:
        """Calculate adaptive margin based on score distribution."""
        if len(scores) < 2:
            return base_margin
        score_values = [s for _, s in scores]
        best_score = score_values[0]
        if best_score < 0.5:
            top_scores = score_values[:min(4, len(score_values))]
            score_range = max(top_scores) - min(top_scores)
            if score_range < 0.30:
                score_factor = (0.5 - best_score) * 0.35
                adaptive = base_margin + score_factor + (0.30 - score_range) * 0.2
                return min(adaptive, max_margin)
        return base_margin
    @staticmethod
    def _lexical_evidence(product_text: str, label: str) -> float:
        """Calculate lexical overlap between product text and label."""
        pt = product_text.lower()
        tokens = [t for t in label.lower().replace("-", " ").split() if t]
        if not tokens:
            return 0.0
        hits = sum(1 for t in tokens if t in pt)
        return hits / len(tokens)
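
    # Illustrative scoring (the example strings are assumptions for this sketch):
    #   _lexical_evidence("canvas wall art print", "wall decor") -> 0.5  (1 of 2 tokens found)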
    @staticmethod
    def normalize_against_product_text(
        product_text: str,
        mandatory_attrs: Dict[str, List[str]],
        source_map: Dict[str, str],
        threshold_abs: float = 0.65,
        margin: float = 0.15,
        allow_multiple: bool = False,
        sem_weight: float = 0.8,
        lex_weight: float = 0.2,
        extracted_attrs: Optional[Dict[str, List[Dict[str, str]]]] = None,
        relationships: Optional[Dict[str, float]] = None,
        use_dynamic_thresholds: bool = True,
        use_adaptive_margin: bool = True,
        use_semantic_clustering: bool = True
    ) -> dict:
        """Score each allowed value against the product_text."""
        if extracted_attrs is None:
            extracted_attrs = {}
        if relationships is None:
            relationships = {}
        pt_emb = EmbeddingCache.get_embedding(product_text, model_embedder)
        extracted = {}
        for attr, allowed_values in mandatory_attrs.items():
            scores: List[Tuple[str, float]] = []
            is_dimension_attr = any(keyword in attr.lower() for keyword in ['dimension', 'size', 'measurement'])
            normalized_product_text = ProductAttributeService.normalize_dimension_text(product_text) if is_dimension_attr else ""
            for val in allowed_values:
                if is_dimension_attr:
                    normalized_val = ProductAttributeService.normalize_dimension_text(val)
                    if normalized_val and normalized_product_text and normalized_val == normalized_product_text:
                        scores.append((val, 1.0))
                        continue
                    if normalized_val:
                        val_numbers = normalized_val.split('x')
                        text_lower = product_text.lower()
                        if all(num in text_lower for num in val_numbers):
                            idx1 = text_lower.find(val_numbers[0])
                            idx2 = text_lower.find(val_numbers[1])
                            if idx1 != -1 and idx2 != -1:
                                distance = abs(idx2 - idx1)
                                if distance < 20:
                                    scores.append((val, 0.95))
                                    continue
                contexts = [val, f"for {val}", f"use in {val}", f"suitable for {val}", f"{val} room"]
                ctx_embs = [EmbeddingCache.get_embedding(c, model_embedder) for c in contexts]
                sem_sim = max(float(util.cos_sim(pt_emb, ce).item()) for ce in ctx_embs)
                lex_score = ProductAttributeService._lexical_evidence(product_text, val)
                final_score = sem_weight * sem_sim + lex_weight * lex_score
                scores.append((val, final_score))
            if not scores:
                # No allowed values were supplied for this attribute.
                extracted[attr] = []
                continue
            scores.sort(key=lambda x: x[1], reverse=True)
            best_val, best_score = scores[0]
            effective_margin = margin
            if allow_multiple and use_adaptive_margin:
                effective_margin = ProductAttributeService.get_adaptive_margin(scores, margin)
            if is_dimension_attr and best_score >= 0.90:
                source = ProductAttributeService.find_value_source(best_val, source_map, attr)
                extracted[attr] = [{"value": best_val, "source": source}]
                continue
            if not allow_multiple:
                source = ProductAttributeService.find_value_source(best_val, source_map, attr)
                extracted[attr] = [{"value": best_val, "source": source}]
            else:
                candidates = [best_val]
                use_base_threshold = best_score >= threshold_abs
                clusters = []
                if use_semantic_clustering:
                    clusters = ProductAttributeService.calculate_value_clusters(
                        allowed_values, scores, cluster_threshold=0.4
                    )
                    best_cluster = next((c for c in clusters if best_val in c), [best_val])
                for val, sc in scores[1:]:
                    min_score = 0.4 if is_dimension_attr else 0.3
                    if sc < min_score:
                        continue
                    if use_dynamic_thresholds and extracted_attrs:
                        dynamic_thresh = ProductAttributeService.get_dynamic_threshold(
                            attr, val, sc, extracted_attrs, relationships,
                            mandatory_attrs, threshold_abs
                        )
                    else:
                        dynamic_thresh = threshold_abs
                    within_margin = (best_score - sc) <= effective_margin
                    above_threshold = sc >= dynamic_thresh
                    in_cluster = False
                    if use_semantic_clustering and clusters:
                        in_cluster = any(best_val in c and val in c for c in clusters)
                    if use_base_threshold:
                        if above_threshold and within_margin:
                            candidates.append(val)
                        elif in_cluster and within_margin:
                            candidates.append(val)
                    else:
                        if within_margin:
                            candidates.append(val)
                        elif in_cluster and (best_score - sc) <= effective_margin * 2.0:
                            candidates.append(val)
                extracted[attr] = []
                for candidate in candidates:
                    source = ProductAttributeService.find_value_source(candidate, source_map, attr)
                    extracted[attr].append({"value": candidate, "source": source})
        return extracted
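
    # Sketch of a call (attribute names and values below are hypothetical):
    #   ProductAttributeService.normalize_against_product_text(
    #       product_text="Farmhouse canvas wall art for the kitchen",
    #       mandatory_attrs={"room": ["Kitchen", "Bedroom", "Bathroom"]},
    #       source_map={"title": "Farmhouse canvas wall art for the kitchen"},
    #   )
    #   # expected result shape: {"room": [{"value": "...", "source": "title"}]}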
    @staticmethod
    def extract_attributes(
        product_text: str,
        mandatory_attrs: Dict[str, List[str]],
        source_map: Dict[str, str] = None,
        model: str = None,
        extract_additional: bool = True,
        multiple: Optional[List[str]] = None,
        threshold_abs: float = 0.65,
        margin: float = 0.15,
        use_dynamic_thresholds: bool = True,
        use_adaptive_margin: bool = True,
        use_semantic_clustering: bool = True,
        use_cache: bool = True
    ) -> dict:
        """Extract attributes from product text using Groq LLM."""
        if model is None:
            model = settings.SUPPORTED_MODELS[0]
        if multiple is None:
            multiple = []
        if source_map is None:
            source_map = {}
        if not product_text or product_text == "No product information available":
            return ProductAttributeService._create_error_response(
                "No product information provided",
                mandatory_attrs,
                extract_additional
            )
        # ⚡ CHECK CACHE FIRST
        if use_cache:
            cache_key = ProductAttributeService._generate_cache_key(product_text, mandatory_attrs)
            cached_result = SimpleCache.get(cache_key)
            if cached_result:
                logger.info("✓ Cache hit")
                return cached_result
        mandatory_attr_list = []
        for attr_name, allowed_values in mandatory_attrs.items():
            mandatory_attr_list.append(f"{attr_name}: {', '.join(allowed_values)}")
        mandatory_attr_text = "\n".join(mandatory_attr_list)
        additional_instruction = ""
        if extract_additional:
            additional_instruction = """
2. Extract ADDITIONAL attributes: Identify any other relevant attributes from the product text
that are NOT in the mandatory list. Only include attributes where you can find actual values
in the product text. Do NOT include attributes with "Not Specified" or empty values.
Examples of attributes to look for (only if present): Brand, Material, Size, Color, Dimensions,
Weight, Features, Style, Theme, Pattern, Finish, Care Instructions, etc."""
        output_format = {
            "mandatory": {attr: "value or list of values" for attr in mandatory_attrs.keys()},
        }
        if extract_additional:
            output_format["additional"] = {
                "example_attribute_1": "actual value found",
                "example_attribute_2": "actual value found"
            }
            output_format["additional"]["_note"] = "Only include attributes with actual values found in text"
        prompt = f"""
You are an intelligent product attribute extractor that works with ANY product type.
TASK:
1. Extract MANDATORY attributes: For each mandatory attribute, select the most appropriate value(s)
from the provided list. Choose the value(s) that best match the product description.
{additional_instruction}
Product Text:
{product_text}
Mandatory Attribute Lists (MUST select from these allowed values):
{mandatory_attr_text}
CRITICAL INSTRUCTIONS:
- Return ONLY valid JSON, nothing else
- No explanations, no markdown, no text before or after the JSON
- For mandatory attributes, choose the value(s) from the provided list that best match
- If a mandatory attribute cannot be determined from the product text, use "Not Specified"
- Prefer exact matches from the allowed values list over generic synonyms
- If multiple values are plausible, you MAY return more than one
{f"- For additional attributes: ONLY include attributes where you found actual values in the product text. DO NOT include attributes with 'Not Specified', 'None', 'N/A', or empty values. If you cannot find a value for an attribute, simply don't include that attribute." if extract_additional else ""}
- Be precise and only extract information that is explicitly stated or clearly implied
Required Output Format:
{json.dumps(output_format, indent=2)}
"""
        payload = {
            "model": model,
            "messages": [
                {
                    "role": "system",
                    "content": f"You are a precise attribute extraction model. Return ONLY valid JSON with {'mandatory and additional' if extract_additional else 'mandatory'} sections. No explanations, no markdown, no other text."
                },
                {"role": "user", "content": prompt}
            ],
            "temperature": 0.0,
            "max_tokens": 1500
        }
        headers = {
            "Authorization": f"Bearer {settings.GROQ_API_KEY}",
            "Content-Type": "application/json",
        }
        result_text = ""  # defined up front so error handlers can reference it safely
        try:
            response = requests.post(
                settings.GROQ_API_URL,
                headers=headers,
                json=payload,
                timeout=30
            )
            response.raise_for_status()
            result_text = response.json()["choices"][0]["message"]["content"].strip()
            result_text = ProductAttributeService._clean_json_response(result_text)
            parsed = json.loads(result_text)
            parsed = ProductAttributeService._validate_response_structure(
                parsed, mandatory_attrs, extract_additional, source_map
            )
            if extract_additional and "additional" in parsed:
                cleaned_additional = {}
                for k, v in parsed["additional"].items():
                    if v and v not in ["Not Specified", "None", "N/A", "", "not specified", "none", "n/a"]:
                        if not (isinstance(v, str) and v.lower() in ["not specified", "none", "n/a", ""]):
                            if isinstance(v, list):
                                cleaned_additional[k] = []
                                for item in v:
                                    if isinstance(item, dict) and "value" in item:
                                        if "source" not in item:
                                            item["source"] = ProductAttributeService.find_value_source(
                                                item["value"], source_map, k
                                            )
                                        cleaned_additional[k].append(item)
                                    else:
                                        source = ProductAttributeService.find_value_source(str(item), source_map, k)
                                        cleaned_additional[k].append({"value": str(item), "source": source})
                            else:
                                source = ProductAttributeService.find_value_source(str(v), source_map, k)
                                cleaned_additional[k] = [{"value": str(v), "source": source}]
                parsed["additional"] = cleaned_additional
            relationships = {}
            if use_dynamic_thresholds:
                relationships = ProductAttributeService.calculate_attribute_relationships(
                    mandatory_attrs, product_text
                )
            extracted_so_far = {}
            for attr in mandatory_attrs.keys():
                allow_multiple = attr in multiple
                result = ProductAttributeService.normalize_against_product_text(
                    product_text=product_text,
                    mandatory_attrs={attr: mandatory_attrs[attr]},
                    source_map=source_map,
                    threshold_abs=threshold_abs,
                    margin=margin,
                    allow_multiple=allow_multiple,
                    extracted_attrs=extracted_so_far,
                    relationships=relationships,
                    use_dynamic_thresholds=use_dynamic_thresholds,
                    use_adaptive_margin=use_adaptive_margin,
                    use_semantic_clustering=use_semantic_clustering
                )
                parsed["mandatory"][attr] = result[attr]
                extracted_so_far[attr] = result[attr]
            # ⚡ CACHE THE RESULT
            if use_cache:
                SimpleCache.set(cache_key, parsed)
            return parsed
        except requests.exceptions.RequestException as e:
            logger.error(f"Request exception: {str(e)}")
            return ProductAttributeService._create_error_response(
                str(e), mandatory_attrs, extract_additional
            )
        except json.JSONDecodeError as e:
            logger.error(f"JSON decode error: {str(e)}")
            return ProductAttributeService._create_error_response(
                f"Invalid JSON: {str(e)}", mandatory_attrs, extract_additional, result_text
            )
        except Exception as e:
            logger.error(f"Unexpected error: {str(e)}")
            return ProductAttributeService._create_error_response(
                str(e), mandatory_attrs, extract_additional
            )
    @staticmethod
    def _clean_json_response(text: str) -> str:
        """Clean LLM response to extract valid JSON."""
        start_idx = text.find('{')
        end_idx = text.rfind('}')
        if start_idx != -1 and end_idx != -1:
            text = text[start_idx:end_idx + 1]
        if "```json" in text:
            text = text.split("```json")[1].split("```")[0].strip()
        elif "```" in text:
            text = text.split("```")[1].split("```")[0].strip()
        if text.startswith("json"):
            text = text[4:].strip()
        return text
    @staticmethod
    def _validate_response_structure(
        parsed: dict,
        mandatory_attrs: Dict[str, List[str]],
        extract_additional: bool,
        source_map: Dict[str, str] = None
    ) -> dict:
        """Validate and fix the response structure."""
        if source_map is None:
            source_map = {}
        expected_sections = ["mandatory"]
        if extract_additional:
            expected_sections.append("additional")
        if not all(section in parsed for section in expected_sections):
            if isinstance(parsed, dict):
                mandatory_keys = set(mandatory_attrs.keys())
                mandatory = {k: v for k, v in parsed.items() if k in mandatory_keys}
                additional = {k: v for k, v in parsed.items() if k not in mandatory_keys}
                result = {"mandatory": mandatory}
                if extract_additional:
                    result["additional"] = additional
                parsed = result
            else:
                return ProductAttributeService._create_error_response(
                    "Invalid response structure",
                    mandatory_attrs,
                    extract_additional,
                    str(parsed)
                )
        if "mandatory" in parsed:
            converted_mandatory = {}
            for attr, value in parsed["mandatory"].items():
                if isinstance(value, list):
                    converted_mandatory[attr] = []
                    for item in value:
                        if isinstance(item, dict) and "value" in item:
                            if "source" not in item:
                                item["source"] = ProductAttributeService.find_value_source(
                                    item["value"], source_map, attr
                                )
                            converted_mandatory[attr].append(item)
                        else:
                            source = ProductAttributeService.find_value_source(str(item), source_map, attr)
                            converted_mandatory[attr].append({"value": str(item), "source": source})
                else:
                    source = ProductAttributeService.find_value_source(str(value), source_map, attr)
                    converted_mandatory[attr] = [{"value": str(value), "source": source}]
            parsed["mandatory"] = converted_mandatory
        return parsed
    @staticmethod
    def _create_error_response(
        error: str,
        mandatory_attrs: Dict[str, List[str]],
        extract_additional: bool,
        raw_output: Optional[str] = None
    ) -> dict:
        """Create a standardized error response."""
        response = {
            "mandatory": {attr: [{"value": "Not Specified", "source": "error"}] for attr in mandatory_attrs.keys()},
            "error": error
        }
        if extract_additional:
            response["additional"] = {}
        if raw_output:
            response["raw_output"] = raw_output
        return response
    @staticmethod
    def get_cache_stats() -> Dict:
        """Get statistics for both caches."""
        return {
            "simple_cache": SimpleCache.get_stats(),
            "embedding_cache": EmbeddingCache.get_stats()
        }

    @staticmethod
    def clear_all_caches():
        """Clear both caches."""
        SimpleCache.clear()
        EmbeddingCache.clear()
        logger.info("All caches cleared")
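
# End-to-end usage sketch (names and values below are assumptions, not from this file;
# settings.GROQ_API_KEY, GROQ_API_URL and SUPPORTED_MODELS must be configured in Django settings):
#   text, sources = ProductAttributeService.combine_product_text(
#       title="Farmhouse Canvas Wall Art", short_desc="16 x 20 inches, cotton canvas"
#   )
#   result = ProductAttributeService.extract_attributes(
#       product_text=text,
#       mandatory_attrs={"size": ["16x20", "24x36"], "material": ["Canvas", "Metal", "Wood"]},
#       source_map=sources,
#       multiple=["material"],
#   )
#   # result -> {"mandatory": {...}, "additional": {...}}; cache state via get_cache_stats()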