# ==================== services.py (WITH CACHE CONTROL) ====================
import os
import requests
import json
import re
import hashlib
import logging
from typing import Dict, List, Optional, Tuple
from django.conf import settings
from concurrent.futures import ThreadPoolExecutor, as_completed
from sentence_transformers import SentenceTransformer, util
import numpy as np

# ⚡ IMPORT CACHE CONFIGURATION
from .cache_config import (
    is_caching_enabled,
    ENABLE_ATTRIBUTE_EXTRACTION_CACHE,
    ENABLE_EMBEDDING_CACHE,
    ATTRIBUTE_CACHE_MAX_SIZE,
    EMBEDDING_CACHE_MAX_SIZE
)

logger = logging.getLogger(__name__)

# Silence the HuggingFace tokenizers parallelism warning (set before the model loads);
# "Batches: 100%" progress-bar spam is suppressed separately via show_progress_bar=False
# on every encode() call below.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# ⚡ CRITICAL FIX: Initialize embedding model ONCE at module level
print("Loading sentence transformer model (one-time initialization)...")
model_embedder = SentenceTransformer("all-MiniLM-L6-v2")
print("✓ Model loaded successfully")

# ==================== CACHING CLASSES ====================
class SimpleCache:
    """In-memory cache for attribute extraction results."""
    _cache = {}
    _max_size = ATTRIBUTE_CACHE_MAX_SIZE

    @classmethod
    def get(cls, key: str) -> Optional[Dict]:
        """Get value from cache. Returns None if caching is disabled."""
        if not ENABLE_ATTRIBUTE_EXTRACTION_CACHE:
            return None
        return cls._cache.get(key)

    @classmethod
    def set(cls, key: str, value: Dict):
        """Set value in cache. Does nothing if caching is disabled."""
        if not ENABLE_ATTRIBUTE_EXTRACTION_CACHE:
            return
        if len(cls._cache) >= cls._max_size:
            items = list(cls._cache.items())
            cls._cache = dict(items[int(cls._max_size * 0.2):])
        cls._cache[key] = value

    @classmethod
    def clear(cls):
        """Clear the cache."""
        cls._cache.clear()

    @classmethod
    def get_stats(cls) -> Dict:
        """Get cache statistics."""
        return {
            "enabled": ENABLE_ATTRIBUTE_EXTRACTION_CACHE,
            "size": len(cls._cache),
            "max_size": cls._max_size,
            "usage_percent": round(len(cls._cache) / cls._max_size * 100, 2) if cls._max_size > 0 else 0
        }

class EmbeddingCache:
    """Cache for sentence transformer embeddings."""
    _cache = {}
    _max_size = EMBEDDING_CACHE_MAX_SIZE
    _hit_count = 0
    _miss_count = 0

    @classmethod
    def get_embedding(cls, text: str, model):
        """Get or compute embedding with optional caching."""
        # If caching is disabled, always compute fresh
        if not ENABLE_EMBEDDING_CACHE:
            import warnings
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                embedding = model.encode(text, convert_to_tensor=True, show_progress_bar=False)
            return embedding
        # Caching is enabled, check cache first
        if text in cls._cache:
            cls._hit_count += 1
            return cls._cache[text]
        cls._miss_count += 1
        if len(cls._cache) >= cls._max_size:
            items = list(cls._cache.items())
            cls._cache = dict(items[int(cls._max_size * 0.3):])
        # Compute embedding
        import warnings
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            embedding = model.encode(text, convert_to_tensor=True, show_progress_bar=False)
        cls._cache[text] = embedding
        return embedding

    @classmethod
    def clear(cls):
        """Clear the cache and reset statistics."""
        cls._cache.clear()
        cls._hit_count = 0
        cls._miss_count = 0

    @classmethod
    def get_stats(cls) -> Dict:
        """Get cache statistics."""
        total = cls._hit_count + cls._miss_count
        hit_rate = (cls._hit_count / total * 100) if total > 0 else 0
        return {
            "enabled": ENABLE_EMBEDDING_CACHE,
            "size": len(cls._cache),
            "hits": cls._hit_count,
            "misses": cls._miss_count,
            "hit_rate_percent": round(hit_rate, 2)
        }
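
# Illustrative usage sketch (not part of the service flow; the example texts are
# made up): embeddings returned by EmbeddingCache.get_embedding are torch tensors,
# so they can be compared directly with sentence_transformers.util.cos_sim.
#
#     emb_a = EmbeddingCache.get_embedding("canvas wall art", model_embedder)
#     emb_b = EmbeddingCache.get_embedding("framed painting", model_embedder)
#     similarity = float(util.cos_sim(emb_a, emb_b).item())
#     print(EmbeddingCache.get_stats())  # hits/misses reflect the two calls above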

# ==================== MAIN SERVICE CLASS ====================
class ProductAttributeService:
    """Service class for extracting product attributes using Groq LLM."""

    @staticmethod
    def _generate_cache_key(product_text: str, mandatory_attrs: Dict) -> str:
        """Generate cache key from product text and attributes."""
        attrs_str = json.dumps(mandatory_attrs, sort_keys=True)
        content = f"{product_text}:{attrs_str}"
        return f"attr_{hashlib.md5(content.encode()).hexdigest()}"

    @staticmethod
    def normalize_dimension_text(text: str) -> str:
        """Normalize dimension text to format like '16x20'."""
        if not text:
            return ""
        text = text.lower()
        text = re.sub(r'\s*(inches|inch|in|cm|centimeters|mm|millimeters)\s*', '', text, flags=re.IGNORECASE)
        numbers = re.findall(r'\d+\.?\d*', text)
        if not numbers:
            return ""
        float_numbers = []
        for num in numbers:
            try:
                float_numbers.append(float(num))
            except ValueError:
                continue
        if len(float_numbers) < 2:
            return ""
        if len(float_numbers) == 3:
            float_numbers = [float_numbers[0], float_numbers[2]]
        elif len(float_numbers) > 3:
            float_numbers = sorted(float_numbers)[-2:]
        else:
            float_numbers = float_numbers[:2]
        formatted_numbers = []
        for num in float_numbers:
            if num.is_integer():
                formatted_numbers.append(str(int(num)))
            else:
                formatted_numbers.append(f"{num:.1f}")
        formatted_numbers.sort(key=lambda x: float(x))
        return f"{formatted_numbers[0]}x{formatted_numbers[1]}"
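
    # Illustrative behavior of normalize_dimension_text (hypothetical inputs):
    #   "16 x 20 inches"        -> "16x20"
    #   "20in H x 16in W"       -> "16x20"   (numbers sorted ascending)
    #   "16 x 1.5 x 20 inches"  -> "16x20"   (middle value treated as depth and dropped)
    #   "one size"              -> ""        (fewer than two numbers found)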

    @staticmethod
    def normalize_value_for_matching(value: str, attr_name: str = "") -> str:
        """Normalize a value based on its attribute type."""
        dimension_keywords = ['dimension', 'size', 'measurement']
        if any(keyword in attr_name.lower() for keyword in dimension_keywords):
            normalized = ProductAttributeService.normalize_dimension_text(value)
            if normalized:
                return normalized
        return value.strip()

    @staticmethod
    def combine_product_text(
        title: Optional[str] = None,
        short_desc: Optional[str] = None,
        long_desc: Optional[str] = None,
        ocr_text: Optional[str] = None
    ) -> Tuple[str, Dict[str, str]]:
        """Combine product metadata into a single text block."""
        parts = []
        source_map = {}
        if title:
            title_str = str(title).strip()
            parts.append(f"Title: {title_str}")
            source_map['title'] = title_str
        if short_desc:
            short_str = str(short_desc).strip()
            parts.append(f"Description: {short_str}")
            source_map['short_desc'] = short_str
        if long_desc:
            long_str = str(long_desc).strip()
            parts.append(f"Details: {long_str}")
            source_map['long_desc'] = long_str
        if ocr_text:
            parts.append(f"OCR Text: {ocr_text}")
            source_map['ocr_text'] = ocr_text
        combined = "\n".join(parts).strip()
        if not combined:
            return "No product information available", {}
        return combined, source_map

    @staticmethod
    def find_value_source(value: str, source_map: Dict[str, str], attr_name: str = "") -> str:
        """Find which source(s) contain the given value."""
        value_lower = value.lower()
        value_tokens = set(value_lower.replace("-", " ").replace("x", " ").split())
        is_dimension_attr = any(keyword in attr_name.lower() for keyword in ['dimension', 'size', 'measurement'])
        sources_found = []
        source_scores = {}
        for source_name, source_text in source_map.items():
            source_lower = source_text.lower()
            if value_lower in source_lower:
                source_scores[source_name] = 1.0
                continue
            if is_dimension_attr:
                normalized_value = ProductAttributeService.normalize_dimension_text(value)
                if not normalized_value:
                    normalized_value = value.replace("x", " ").strip()
                normalized_source = ProductAttributeService.normalize_dimension_text(source_text)
                if normalized_value == normalized_source:
                    source_scores[source_name] = 0.95
                    continue
                dim_parts = normalized_value.split("x") if "x" in normalized_value else []
                if len(dim_parts) == 2:
                    if all(part in source_text for part in dim_parts):
                        source_scores[source_name] = 0.85
                        continue
            token_matches = sum(1 for token in value_tokens if token and token in source_lower)
            if token_matches > 0 and len(value_tokens) > 0:
                source_scores[source_name] = token_matches / len(value_tokens)
        if source_scores:
            max_score = max(source_scores.values())
            sources_found = [s for s, score in source_scores.items() if score == max_score]
            priority = ['title', 'short_desc', 'long_desc', 'ocr_text']
            for p in priority:
                if p in sources_found:
                    return p
            return sources_found[0] if sources_found else "Not found"
        return "Not found"

    @staticmethod
    def format_visual_attributes(visual_attributes: Dict) -> Dict:
        """Convert visual attributes to array format with source tracking."""
        formatted = {}
        for key, value in visual_attributes.items():
            if isinstance(value, list):
                formatted[key] = [{"value": str(item), "source": "image"} for item in value]
            elif isinstance(value, dict):
                nested_formatted = {}
                for nested_key, nested_value in value.items():
                    if isinstance(nested_value, list):
                        nested_formatted[nested_key] = [{"value": str(item), "source": "image"} for item in nested_value]
                    else:
                        nested_formatted[nested_key] = [{"value": str(nested_value), "source": "image"}]
                formatted[key] = nested_formatted
            else:
                formatted[key] = [{"value": str(value), "source": "image"}]
        return formatted

    @staticmethod
    def extract_attributes_from_ocr(ocr_results: Dict, model: str = None) -> Dict:
        """Extract structured attributes from OCR text using LLM."""
        if model is None:
            model = settings.SUPPORTED_MODELS[0]
        detected_text = ocr_results.get('detected_text', [])
        if not detected_text:
            return {}
        ocr_text = "\n".join([f"Text: {item['text']}, Confidence: {item['confidence']:.2f}"
                              for item in detected_text])
        prompt = f"""
You are an AI model that extracts structured attributes from OCR text detected on product images.
Given the OCR detections below, infer the possible product attributes and return them as a clean JSON object.
OCR Text:
{ocr_text}
Extract relevant attributes like:
- brand
- model_number
- size (waist_size, length, etc.)
- collection
- any other relevant product information
Return a JSON object with only the attributes you can confidently identify.
If an attribute is not present, do not include it in the response.
"""
        payload = {
            "model": model,
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful AI that extracts structured data from OCR output. Return only valid JSON."
                },
                {"role": "user", "content": prompt}
            ],
            "temperature": 0.2,
            "max_tokens": 500
        }
        headers = {
            "Authorization": f"Bearer {settings.GROQ_API_KEY}",
            "Content-Type": "application/json",
        }
        try:
            response = requests.post(
                settings.GROQ_API_URL,
                headers=headers,
                json=payload,
                timeout=30
            )
            response.raise_for_status()
            result_text = response.json()["choices"][0]["message"]["content"].strip()
            result_text = ProductAttributeService._clean_json_response(result_text)
            parsed = json.loads(result_text)
            formatted_attributes = {}
            for key, value in parsed.items():
                if key == "error":
                    continue
                if isinstance(value, dict):
                    nested_formatted = {}
                    for nested_key, nested_value in value.items():
                        nested_formatted[nested_key] = [{"value": str(nested_value), "source": "image"}]
                    formatted_attributes[key] = nested_formatted
                elif isinstance(value, list):
                    formatted_attributes[key] = [{"value": str(item), "source": "image"} for item in value]
                else:
                    formatted_attributes[key] = [{"value": str(value), "source": "image"}]
            return formatted_attributes
        except Exception as e:
            logger.error(f"OCR attribute extraction failed: {str(e)}")
            return {"error": f"Failed to extract attributes from OCR: {str(e)}"}

    @staticmethod
    def calculate_attribute_relationships(
        mandatory_attrs: Dict[str, List[str]],
        product_text: str
    ) -> Dict[str, float]:
        """Calculate semantic relationships between attribute values."""
        pt_emb = EmbeddingCache.get_embedding(product_text, model_embedder)
        attr_scores = {}
        for attr, values in mandatory_attrs.items():
            attr_scores[attr] = {}
            for val in values:
                contexts = [val, f"for {val}", f"use in {val}", f"suitable for {val}"]
                ctx_embs = [EmbeddingCache.get_embedding(c, model_embedder) for c in contexts]
                sem_sim = max(float(util.cos_sim(pt_emb, ce).item()) for ce in ctx_embs)
                attr_scores[attr][val] = sem_sim
        relationships = {}
        attr_list = list(mandatory_attrs.keys())
        for i, attr1 in enumerate(attr_list):
            for attr2 in attr_list[i+1:]:
                for val1 in mandatory_attrs[attr1]:
                    for val2 in mandatory_attrs[attr2]:
                        emb1 = EmbeddingCache.get_embedding(val1, model_embedder)
                        emb2 = EmbeddingCache.get_embedding(val2, model_embedder)
                        sim = float(util.cos_sim(emb1, emb2).item())
                        key1 = f"{attr1}:{val1}->{attr2}:{val2}"
                        key2 = f"{attr2}:{val2}->{attr1}:{val1}"
                        relationships[key1] = sim
                        relationships[key2] = sim
        return relationships

    @staticmethod
    def calculate_value_clusters(
        values: List[str],
        scores: List[Tuple[str, float]],
        cluster_threshold: float = 0.4
    ) -> List[List[str]]:
        """Group values into semantic clusters."""
        if len(values) <= 1:
            return [[val] for val, _ in scores]
        embeddings = [EmbeddingCache.get_embedding(val, model_embedder) for val in values]
        similarity_matrix = np.zeros((len(values), len(values)))
        for i in range(len(values)):
            for j in range(i+1, len(values)):
                sim = float(util.cos_sim(embeddings[i], embeddings[j]).item())
                similarity_matrix[i][j] = sim
                similarity_matrix[j][i] = sim
        clusters = []
        visited = set()
        for val, score in scores:
            # scores may be sorted differently than values, so look up this value's
            # index in the original list before indexing into the similarity matrix
            i = values.index(val)
            if i in visited:
                continue
            cluster = [val]
            visited.add(i)
            for j in range(len(values)):
                if j not in visited and similarity_matrix[i][j] >= cluster_threshold:
                    cluster.append(values[j])
                    visited.add(j)
            clusters.append(cluster)
        return clusters

    @staticmethod
    def get_dynamic_threshold(
        attr: str,
        val: str,
        base_score: float,
        extracted_attrs: Dict[str, List[Dict[str, str]]],
        relationships: Dict[str, float],
        mandatory_attrs: Dict[str, List[str]],
        base_threshold: float = 0.65,
        boost_factor: float = 0.15
    ) -> float:
        """Calculate dynamic threshold based on relationships."""
        threshold = base_threshold
        max_relationship = 0.0
        for other_attr, other_values_list in extracted_attrs.items():
            if other_attr == attr:
                continue
            for other_val_dict in other_values_list:
                other_val = other_val_dict['value']
                key = f"{attr}:{val}->{other_attr}:{other_val}"
                if key in relationships:
                    max_relationship = max(max_relationship, relationships[key])
        if max_relationship > 0.6:
            threshold = base_threshold - (boost_factor * max_relationship)
        return max(0.3, threshold)

    @staticmethod
    def get_adaptive_margin(
        scores: List[Tuple[str, float]],
        base_margin: float = 0.15,
        max_margin: float = 0.22
    ) -> float:
        """Calculate adaptive margin based on score distribution."""
        if len(scores) < 2:
            return base_margin
        score_values = [s for _, s in scores]
        best_score = score_values[0]
        if best_score < 0.5:
            top_scores = score_values[:min(4, len(score_values))]
            score_range = max(top_scores) - min(top_scores)
            if score_range < 0.30:
                score_factor = (0.5 - best_score) * 0.35
                adaptive = base_margin + score_factor + (0.30 - score_range) * 0.2
                return min(adaptive, max_margin)
        return base_margin

    @staticmethod
    def _lexical_evidence(product_text: str, label: str) -> float:
        """Calculate lexical overlap between product text and label."""
        pt = product_text.lower()
        tokens = [t for t in label.lower().replace("-", " ").split() if t]
        if not tokens:
            return 0.0
        hits = sum(1 for t in tokens if t in pt)
        return hits / len(tokens)

    @staticmethod
    def normalize_against_product_text(
        product_text: str,
        mandatory_attrs: Dict[str, List[str]],
        source_map: Dict[str, str],
        threshold_abs: float = 0.65,
        margin: float = 0.15,
        allow_multiple: bool = False,
        sem_weight: float = 0.8,
        lex_weight: float = 0.2,
        extracted_attrs: Optional[Dict[str, List[Dict[str, str]]]] = None,
        relationships: Optional[Dict[str, float]] = None,
        use_dynamic_thresholds: bool = True,
        use_adaptive_margin: bool = True,
        use_semantic_clustering: bool = True
    ) -> dict:
        """Score each allowed value against the product_text."""
        if extracted_attrs is None:
            extracted_attrs = {}
        if relationships is None:
            relationships = {}
        pt_emb = EmbeddingCache.get_embedding(product_text, model_embedder)
        extracted = {}
        for attr, allowed_values in mandatory_attrs.items():
            scores: List[Tuple[str, float]] = []
            is_dimension_attr = any(keyword in attr.lower() for keyword in ['dimension', 'size', 'measurement'])
            normalized_product_text = ProductAttributeService.normalize_dimension_text(product_text) if is_dimension_attr else ""
            for val in allowed_values:
                if is_dimension_attr:
                    normalized_val = ProductAttributeService.normalize_dimension_text(val)
                    if normalized_val and normalized_product_text and normalized_val == normalized_product_text:
                        scores.append((val, 1.0))
                        continue
                    if normalized_val:
                        val_numbers = normalized_val.split('x')
                        text_lower = product_text.lower()
                        if all(num in text_lower for num in val_numbers):
                            idx1 = text_lower.find(val_numbers[0])
                            idx2 = text_lower.find(val_numbers[1])
                            if idx1 != -1 and idx2 != -1:
                                distance = abs(idx2 - idx1)
                                if distance < 20:
                                    scores.append((val, 0.95))
                                    continue
                contexts = [val, f"for {val}", f"use in {val}", f"suitable for {val}", f"{val} room"]
                ctx_embs = [EmbeddingCache.get_embedding(c, model_embedder) for c in contexts]
                sem_sim = max(float(util.cos_sim(pt_emb, ce).item()) for ce in ctx_embs)
                lex_score = ProductAttributeService._lexical_evidence(product_text, val)
                final_score = sem_weight * sem_sim + lex_weight * lex_score
                scores.append((val, final_score))
            scores.sort(key=lambda x: x[1], reverse=True)
            best_val, best_score = scores[0]
            effective_margin = margin
            if allow_multiple and use_adaptive_margin:
                effective_margin = ProductAttributeService.get_adaptive_margin(scores, margin)
            if is_dimension_attr and best_score >= 0.90:
                source = ProductAttributeService.find_value_source(best_val, source_map, attr)
                extracted[attr] = [{"value": best_val, "source": source}]
                continue
            if not allow_multiple:
                source = ProductAttributeService.find_value_source(best_val, source_map, attr)
                extracted[attr] = [{"value": best_val, "source": source}]
            else:
                candidates = [best_val]
                use_base_threshold = best_score >= threshold_abs
                clusters = []
                if use_semantic_clustering:
                    clusters = ProductAttributeService.calculate_value_clusters(
                        allowed_values, scores, cluster_threshold=0.4
                    )
                    best_cluster = next((c for c in clusters if best_val in c), [best_val])
                for val, sc in scores[1:]:
                    min_score = 0.4 if is_dimension_attr else 0.3
                    if sc < min_score:
                        continue
                    if use_dynamic_thresholds and extracted_attrs:
                        dynamic_thresh = ProductAttributeService.get_dynamic_threshold(
                            attr, val, sc, extracted_attrs, relationships,
                            mandatory_attrs, threshold_abs
                        )
                    else:
                        dynamic_thresh = threshold_abs
                    within_margin = (best_score - sc) <= effective_margin
                    above_threshold = sc >= dynamic_thresh
                    in_cluster = False
                    if use_semantic_clustering and clusters:
                        in_cluster = any(best_val in c and val in c for c in clusters)
                    if use_base_threshold:
                        if above_threshold and within_margin:
                            candidates.append(val)
                        elif in_cluster and within_margin:
                            candidates.append(val)
                    else:
                        if within_margin:
                            candidates.append(val)
                        elif in_cluster and (best_score - sc) <= effective_margin * 2.0:
                            candidates.append(val)
                extracted[attr] = []
                for candidate in candidates:
                    source = ProductAttributeService.find_value_source(candidate, source_map, attr)
                    extracted[attr].append({"value": candidate, "source": source})
        return extracted
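
    # Worked example of the blended score computed above (hypothetical numbers):
    #   sem_sim = 0.70, lex_score = 0.50, sem_weight = 0.8, lex_weight = 0.2
    #   final_score = 0.8 * 0.70 + 0.2 * 0.50 = 0.66  -> clears threshold_abs = 0.65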

    @staticmethod
    def extract_attributes(
        product_text: str,
        mandatory_attrs: Dict[str, List[str]],
        source_map: Dict[str, str] = None,
        model: str = None,
        extract_additional: bool = True,
        multiple: Optional[List[str]] = None,
        threshold_abs: float = 0.65,
        margin: float = 0.15,
        use_dynamic_thresholds: bool = True,
        use_adaptive_margin: bool = True,
        use_semantic_clustering: bool = True,
        use_cache: bool = None  # ⚡ NEW: Can override global setting
    ) -> dict:
        """Extract attributes from product text using Groq LLM."""
        if model is None:
            model = settings.SUPPORTED_MODELS[0]
        if multiple is None:
            multiple = []
        if source_map is None:
            source_map = {}
        # ⚡ CACHE CONTROL: use parameter if provided, otherwise use global setting
        if use_cache is None:
            use_cache = ENABLE_ATTRIBUTE_EXTRACTION_CACHE
        # If caching is globally disabled, force use_cache to False
        if not is_caching_enabled():
            use_cache = False
        if not product_text or product_text == "No product information available":
            return ProductAttributeService._create_error_response(
                "No product information provided",
                mandatory_attrs,
                extract_additional
            )
        # ⚡ CHECK CACHE FIRST (only if enabled)
        if use_cache:
            cache_key = ProductAttributeService._generate_cache_key(product_text, mandatory_attrs)
            cached_result = SimpleCache.get(cache_key)
            if cached_result:
                logger.info("✓ Cache hit (caching enabled)")
                return cached_result
        else:
            logger.info("⚠ Cache disabled - processing fresh")
        mandatory_attr_list = []
        for attr_name, allowed_values in mandatory_attrs.items():
            mandatory_attr_list.append(f"{attr_name}: {', '.join(allowed_values)}")
        mandatory_attr_text = "\n".join(mandatory_attr_list)
        additional_instruction = ""
        if extract_additional:
            additional_instruction = """
2. Extract ADDITIONAL attributes: Identify any other relevant attributes from the product text
that are NOT in the mandatory list. Only include attributes where you can find actual values
in the product text. Do NOT include attributes with "Not Specified" or empty values.
Examples of attributes to look for (only if present): Brand, Material, Size, Color, Dimensions,
Weight, Features, Style, Theme, Pattern, Finish, Care Instructions, etc."""
        output_format = {
            "mandatory": {attr: "value or list of values" for attr in mandatory_attrs.keys()},
        }
        if extract_additional:
            output_format["additional"] = {
                "example_attribute_1": "actual value found",
                "example_attribute_2": "actual value found"
            }
            output_format["additional"]["_note"] = "Only include attributes with actual values found in text"
        prompt = f"""
You are an intelligent product attribute extractor that works with ANY product type.
TASK:
1. Extract MANDATORY attributes: For each mandatory attribute, select the most appropriate value(s)
from the provided list. Choose the value(s) that best match the product description.
{additional_instruction}
Product Text:
{product_text}
Mandatory Attribute Lists (MUST select from these allowed values):
{mandatory_attr_text}
CRITICAL INSTRUCTIONS:
- Return ONLY valid JSON, nothing else
- No explanations, no markdown, no text before or after the JSON
- For mandatory attributes, choose the value(s) from the provided list that best match
- If a mandatory attribute cannot be determined from the product text, use "Not Specified"
- Prefer exact matches from the allowed values list over generic synonyms
- If multiple values are plausible, you MAY return more than one
{f"- For additional attributes: ONLY include attributes where you found actual values in the product text. DO NOT include attributes with 'Not Specified', 'None', 'N/A', or empty values. If you cannot find a value for an attribute, simply don't include that attribute." if extract_additional else ""}
- Be precise and only extract information that is explicitly stated or clearly implied
Required Output Format:
{json.dumps(output_format, indent=2)}
"""
        payload = {
            "model": model,
            "messages": [
                {
                    "role": "system",
                    "content": f"You are a precise attribute extraction model. Return ONLY valid JSON with {'mandatory and additional' if extract_additional else 'mandatory'} sections. No explanations, no markdown, no other text."
                },
                {"role": "user", "content": prompt}
            ],
            "temperature": 0.0,
            "max_tokens": 1500
        }
        headers = {
            "Authorization": f"Bearer {settings.GROQ_API_KEY}",
            "Content-Type": "application/json",
        }
        try:
            response = requests.post(
                settings.GROQ_API_URL,
                headers=headers,
                json=payload,
                timeout=30
            )
            response.raise_for_status()
            result_text = response.json()["choices"][0]["message"]["content"].strip()
            result_text = ProductAttributeService._clean_json_response(result_text)
            parsed = json.loads(result_text)
            parsed = ProductAttributeService._validate_response_structure(
                parsed, mandatory_attrs, extract_additional, source_map
            )
            if extract_additional and "additional" in parsed:
                cleaned_additional = {}
                for k, v in parsed["additional"].items():
                    if v and v not in ["Not Specified", "None", "N/A", "", "not specified", "none", "n/a"]:
                        if not (isinstance(v, str) and v.lower() in ["not specified", "none", "n/a", ""]):
                            if isinstance(v, list):
                                cleaned_additional[k] = []
                                for item in v:
                                    if isinstance(item, dict) and "value" in item:
                                        if "source" not in item:
                                            item["source"] = ProductAttributeService.find_value_source(
                                                item["value"], source_map, k
                                            )
                                        cleaned_additional[k].append(item)
                                    else:
                                        source = ProductAttributeService.find_value_source(str(item), source_map, k)
                                        cleaned_additional[k].append({"value": str(item), "source": source})
                            else:
                                source = ProductAttributeService.find_value_source(str(v), source_map, k)
                                cleaned_additional[k] = [{"value": str(v), "source": source}]
                parsed["additional"] = cleaned_additional
            relationships = {}
            if use_dynamic_thresholds:
                relationships = ProductAttributeService.calculate_attribute_relationships(
                    mandatory_attrs, product_text
                )
            extracted_so_far = {}
            for attr in mandatory_attrs.keys():
                allow_multiple = attr in multiple
                result = ProductAttributeService.normalize_against_product_text(
                    product_text=product_text,
                    mandatory_attrs={attr: mandatory_attrs[attr]},
                    source_map=source_map,
                    threshold_abs=threshold_abs,
                    margin=margin,
                    allow_multiple=allow_multiple,
                    extracted_attrs=extracted_so_far,
                    relationships=relationships,
                    use_dynamic_thresholds=use_dynamic_thresholds,
                    use_adaptive_margin=use_adaptive_margin,
                    use_semantic_clustering=use_semantic_clustering
                )
                parsed["mandatory"][attr] = result[attr]
                extracted_so_far[attr] = result[attr]
            # ⚡ CACHE THE RESULT (only if enabled)
            if use_cache:
                SimpleCache.set(cache_key, parsed)
                logger.info("✓ Result cached")
            return parsed
        except requests.exceptions.RequestException as e:
            logger.error(f"Request exception: {str(e)}")
            return ProductAttributeService._create_error_response(
                str(e), mandatory_attrs, extract_additional
            )
        except json.JSONDecodeError as e:
            logger.error(f"JSON decode error: {str(e)}")
            return ProductAttributeService._create_error_response(
                f"Invalid JSON: {str(e)}", mandatory_attrs, extract_additional, result_text
            )
        except Exception as e:
            logger.error(f"Unexpected error: {str(e)}")
            return ProductAttributeService._create_error_response(
                str(e), mandatory_attrs, extract_additional
            )

    @staticmethod
    def _clean_json_response(text: str) -> str:
        """Clean LLM response to extract valid JSON."""
        start_idx = text.find('{')
        end_idx = text.rfind('}')
        if start_idx != -1 and end_idx != -1:
            text = text[start_idx:end_idx + 1]
        if "```json" in text:
            text = text.split("```json")[1].split("```")[0].strip()
        elif "```" in text:
            text = text.split("```")[1].split("```")[0].strip()
        if text.startswith("json"):
            text = text[4:].strip()
        return text

    @staticmethod
    def _validate_response_structure(
        parsed: dict,
        mandatory_attrs: Dict[str, List[str]],
        extract_additional: bool,
        source_map: Dict[str, str] = None
    ) -> dict:
        """Validate and fix the response structure."""
        if source_map is None:
            source_map = {}
        expected_sections = ["mandatory"]
        if extract_additional:
            expected_sections.append("additional")
        if not all(section in parsed for section in expected_sections):
            if isinstance(parsed, dict):
                mandatory_keys = set(mandatory_attrs.keys())
                mandatory = {k: v for k, v in parsed.items() if k in mandatory_keys}
                additional = {k: v for k, v in parsed.items() if k not in mandatory_keys}
                result = {"mandatory": mandatory}
                if extract_additional:
                    result["additional"] = additional
                parsed = result
            else:
                return ProductAttributeService._create_error_response(
                    "Invalid response structure",
                    mandatory_attrs,
                    extract_additional,
                    str(parsed)
                )
        if "mandatory" in parsed:
            converted_mandatory = {}
            for attr, value in parsed["mandatory"].items():
                if isinstance(value, list):
                    converted_mandatory[attr] = []
                    for item in value:
                        if isinstance(item, dict) and "value" in item:
                            if "source" not in item:
                                item["source"] = ProductAttributeService.find_value_source(
                                    item["value"], source_map, attr
                                )
                            converted_mandatory[attr].append(item)
                        else:
                            source = ProductAttributeService.find_value_source(str(item), source_map, attr)
                            converted_mandatory[attr].append({"value": str(item), "source": source})
                else:
                    source = ProductAttributeService.find_value_source(str(value), source_map, attr)
                    converted_mandatory[attr] = [{"value": str(value), "source": source}]
            parsed["mandatory"] = converted_mandatory
        return parsed

    @staticmethod
    def _create_error_response(
        error: str,
        mandatory_attrs: Dict[str, List[str]],
        extract_additional: bool,
        raw_output: Optional[str] = None
    ) -> dict:
        """Create a standardized error response."""
        response = {
            "mandatory": {attr: [{"value": "Not Specified", "source": "error"}] for attr in mandatory_attrs.keys()},
            "error": error
        }
        if extract_additional:
            response["additional"] = {}
        if raw_output:
            response["raw_output"] = raw_output
        return response

    @staticmethod
    def get_cache_stats() -> Dict:
        """Get statistics for all caches including global status."""
        return {
            "global_caching_enabled": is_caching_enabled(),
            "simple_cache": SimpleCache.get_stats(),
            "embedding_cache": EmbeddingCache.get_stats()
        }

    @staticmethod
    def clear_all_caches():
        """Clear all caches."""
        SimpleCache.clear()
        EmbeddingCache.clear()
        logger.info("All caches cleared")
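
# End-to-end usage sketch (illustrative only; assumes Django settings provide
# GROQ_API_URL, GROQ_API_KEY and SUPPORTED_MODELS, and the product data below is
# hypothetical):
#
#     combined_text, source_map = ProductAttributeService.combine_product_text(
#         title="Abstract Canvas Wall Art 16x20 Inches",
#         short_desc="Modern print for living room decor",
#     )
#     result = ProductAttributeService.extract_attributes(
#         product_text=combined_text,
#         mandatory_attrs={"Room Type": ["Living Room", "Bedroom", "Office"],
#                          "Dimensions": ["16x20", "24x36"]},
#         source_map=source_map,
#         multiple=["Room Type"],  # allow several values when scores are close
#     )
#     # result["mandatory"] maps each attribute to [{"value": ..., "source": ...}]
#     print(ProductAttributeService.get_cache_stats())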