# ==================== services.py (FINAL PERFECT + FULL WHITELIST + SEMANTIC RECOVERY) ====================
import json
import hashlib
import logging
import os
import time
import warnings
from functools import wraps
from typing import Dict, List, Optional, Tuple

import requests
from django.conf import settings
from sentence_transformers import SentenceTransformer, util

# --------------------------------------------------------------------------- #
# CACHE CONFIG
# --------------------------------------------------------------------------- #
from .cache_config import (
    is_caching_enabled,
    ENABLE_ATTRIBUTE_EXTRACTION_CACHE,
    ENABLE_EMBEDDING_CACHE,
    ATTRIBUTE_CACHE_MAX_SIZE,
    EMBEDDING_CACHE_MAX_SIZE,
)

logger = logging.getLogger(__name__)

# --------------------------------------------------------------------------- #
# ONE-TIME MODEL LOAD
# --------------------------------------------------------------------------- #
# Set before the embedder is first used, so the HuggingFace tokenizers fork warning stays quiet.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
print("Loading sentence transformer model (semantic recovery)...")
model_embedder = SentenceTransformer("all-MiniLM-L6-v2")
print("Model loaded")

# --------------------------------------------------------------------------- #
# CACHES
# --------------------------------------------------------------------------- #
class SimpleCache:
    _cache = {}
    _max_size = ATTRIBUTE_CACHE_MAX_SIZE

    @classmethod
    def get(cls, key: str) -> Optional[Dict]:
        if not ENABLE_ATTRIBUTE_EXTRACTION_CACHE:
            return None
        return cls._cache.get(key)

    @classmethod
    def set(cls, key: str, value: Dict):
        if not ENABLE_ATTRIBUTE_EXTRACTION_CACHE:
            return
        if len(cls._cache) >= cls._max_size:
            # Evict the oldest 20% of entries (dicts preserve insertion order).
            items = list(cls._cache.items())
            cls._cache = dict(items[int(cls._max_size * 0.2):])
        cls._cache[key] = value

    @classmethod
    def clear(cls):
        cls._cache.clear()

    @classmethod
    def get_stats(cls) -> Dict:
        return {
            "enabled": ENABLE_ATTRIBUTE_EXTRACTION_CACHE,
            "size": len(cls._cache),
            "max_size": cls._max_size,
            "usage_percent": round(len(cls._cache) / cls._max_size * 100, 2) if cls._max_size else 0,
        }

class EmbeddingCache:
    _cache = {}
    _max_size = EMBEDDING_CACHE_MAX_SIZE
    _hit = _miss = 0

    @classmethod
    def get_embedding(cls, text: str, model):
        if not ENABLE_EMBEDDING_CACHE:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                return model.encode(text, convert_to_tensor=True, show_progress_bar=False)
        if text in cls._cache:
            cls._hit += 1
            return cls._cache[text]
        cls._miss += 1
        if len(cls._cache) >= cls._max_size:
            items = list(cls._cache.items())
            cls._cache = dict(items[int(cls._max_size * 0.3):])
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            emb = model.encode(text, convert_to_tensor=True, show_progress_bar=False)
        cls._cache[text] = emb
        return emb

    @classmethod
    def clear(cls):
        cls._cache.clear()
        cls._hit = cls._miss = 0

    @classmethod
    def get_stats(cls) -> Dict:
        total = cls._hit + cls._miss
        rate = (cls._hit / total * 100) if total else 0
        return {
            "enabled": ENABLE_EMBEDDING_CACHE,
            "size": len(cls._cache),
            "hits": cls._hit,
            "misses": cls._miss,
            "hit_rate_percent": round(rate, 2),
        }
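
# For illustration: with caching enabled, two lookups of the same text count one
# miss then one hit, so get_stats() would report hits=1, misses=1, hit_rate_percent=50.0.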

# --------------------------------------------------------------------------- #
# RETRY DECORATOR
# --------------------------------------------------------------------------- #
def retry(max_attempts=3, delay=1.0):
    def decorator(f):
        @wraps(f)
        def wrapper(*args, **kwargs):
            last_exc = None
            for i in range(max_attempts):
                try:
                    return f(*args, **kwargs)
                except Exception as e:
                    last_exc = e
                    if i < max_attempts - 1:
                        wait = delay * (2 ** i)
                        logger.warning(f"Retry {i+1}/{max_attempts} after {wait}s: {e}")
                        time.sleep(wait)
            raise last_exc or RuntimeError("Retry failed")
        return wrapper
    return decorator
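
# With the defaults (max_attempts=3, delay=1.0), a persistently failing call is
# retried after waits of 1s and then 2s (delay * 2**attempt) before the last
# exception is re-raised.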

# --------------------------------------------------------------------------- #
# MAIN SERVICE
# --------------------------------------------------------------------------- #
class ProductAttributeService:

    @staticmethod
    def combine_product_text(title=None, short_desc=None, long_desc=None, ocr_text=None) -> Tuple[str, Dict[str, str]]:
        parts = []
        source_map = {}
        if title:
            t = str(title).strip()
            parts.append(f"Title: {t}")
            source_map["title"] = t
        if short_desc:
            s = str(short_desc).strip()
            parts.append(f"Description: {s}")
            source_map["short_desc"] = s
        if long_desc:
            d = str(long_desc).strip()
            parts.append(f"Details: {d}")
            source_map["long_desc"] = d
        if ocr_text:
            # Normalize OCR text the same way as the other fields.
            o = str(ocr_text).strip()
            parts.append(f"OCR Text: {o}")
            source_map["ocr_text"] = o
        combined = "\n".join(parts).strip()
        return (combined or "No product information", source_map)
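
    # Example (hypothetical inputs):
    #   combine_product_text(title="Steel Bottle", short_desc="1L flask")
    #   → ("Title: Steel Bottle\nDescription: 1L flask",
    #      {"title": "Steel Bottle", "short_desc": "1L flask"})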

    @staticmethod
    def _cache_key(product_text: str, mandatory_attrs: Dict, extract_additional: bool, multiple: List[str]) -> str:
        payload = {"text": product_text, "attrs": mandatory_attrs, "extra": extract_additional, "multiple": sorted(multiple)}
        return f"attr_{hashlib.md5(json.dumps(payload, sort_keys=True).encode()).hexdigest()}"

    @staticmethod
    def _clean_json(text: str) -> str:
        # Strip markdown fences first, then fall back to the outermost brace pair.
        if "```json" in text:
            text = text.split("```json", 1)[1].split("```", 1)[0]
        elif "```" in text:
            text = text.split("```", 1)[1].split("```", 1)[0]
        text = text.strip()
        if text.startswith("json"):
            text = text[4:]
        start = text.find("{")
        end = text.rfind("}") + 1
        if start != -1 and end > start:
            text = text[start:end]
        return text.strip()
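
    # Example: '```json\n{"mandatory": {}}\n```' → '{"mandatory": {}}'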

    @staticmethod
    def _lexical_evidence(product_text: str, label: str) -> float:
        pt = product_text.lower()
        tokens = [t for t in label.lower().replace("-", " ").split() if t]
        if not tokens:
            return 0.0
        hits = sum(1 for t in tokens if t in pt)
        return hits / len(tokens)
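
    # Example: _lexical_evidence("matte stainless steel bottle", "Stainless Steel") → 1.0;
    # "Carbon Steel" would score 0.5, since only one of its two tokens appears.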

    @staticmethod
    def _find_source(value: str, source_map: Dict[str, str]) -> str:
        value_lower = value.lower()
        for src_key, text in source_map.items():
            if value_lower in text.lower():
                return src_key
        return "not_found"

    @staticmethod
    def format_visual_attributes(visual_attributes: Dict) -> Dict:
        formatted = {}
        for key, value in visual_attributes.items():
            if isinstance(value, list):
                formatted[key] = [{"value": str(item), "source": "image"} for item in value]
            elif isinstance(value, dict):
                nested = {}
                for sub_key, sub_val in value.items():
                    if isinstance(sub_val, list):
                        nested[sub_key] = [{"value": str(v), "source": "image"} for v in sub_val]
                    else:
                        nested[sub_key] = [{"value": str(sub_val), "source": "image"}]
                formatted[key] = nested
            else:
                formatted[key] = [{"value": str(value), "source": "image"}]
        return formatted
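
    # Example: {"color": ["red", "blue"]} becomes
    # {"color": [{"value": "red", "source": "image"}, {"value": "blue", "source": "image"}]}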

    @staticmethod
    @retry(max_attempts=3, delay=1.0)
    def _call_llm(payload: dict) -> str:
        headers = {"Authorization": f"Bearer {settings.GROQ_API_KEY}", "Content-Type": "application/json"}
        resp = requests.post(settings.GROQ_API_URL, headers=headers, json=payload, timeout=30)
        resp.raise_for_status()
        return resp.json()["choices"][0]["message"]["content"]

    @staticmethod
    def extract_attributes(
        product_text: str,
        mandatory_attrs: Dict[str, List[str]],
        source_map: Optional[Dict[str, str]] = None,
        model: Optional[str] = None,
        extract_additional: bool = True,
        multiple: Optional[List[str]] = None,
        use_cache: Optional[bool] = None,
    ) -> dict:
        if model is None:
            model = settings.SUPPORTED_MODELS[0]
        if multiple is None:
            multiple = []
        if source_map is None:
            source_map = {}
        if use_cache is None:
            use_cache = ENABLE_ATTRIBUTE_EXTRACTION_CACHE
        if not is_caching_enabled():
            use_cache = False

        cache_key = None
        if use_cache:
            cache_key = ProductAttributeService._cache_key(product_text, mandatory_attrs, extract_additional, multiple)
            cached = SimpleCache.get(cache_key)
            if cached:
                logger.info(f"CACHE HIT {cache_key[:16]}...")
                return cached

        # --------------------------- PROMPT WITH FULL WHITELIST ---------------------------
        allowed_lines = [f"{attr}: {', '.join(vals)}" for attr, vals in mandatory_attrs.items()]
        allowed_text = "\n".join(allowed_lines)
        allowed_sources = list(source_map.keys()) + ["not_found"]
        source_hint = "|".join(allowed_sources)
        multiple_text = f"\nMULTIPLE ALLOWED FOR: {', '.join(multiple)}" if multiple else ""
        prompt = f"""
You are a product-attribute classifier.
Pick **exactly one** value from the list below for each attribute.
If nothing matches, return "Not Specified".
ALLOWED VALUES:
{allowed_text}
{multiple_text}
PRODUCT TEXT:
{product_text}
OUTPUT (strict JSON only):
{{
  "mandatory": {{
    "<attr>": [{{"value":"<chosen>", "source":"<{source_hint}>"}}]
  }},
  "additional": {{}}
}}
RULES:
- Pick from allowed values only
- If not found: "Not Specified" + "not_found"
- Source must be: {source_hint}
- Return ONLY JSON
"""
        payload = {
            "model": model,
            "messages": [
                {"role": "system", "content": "You are a JSON-only extractor."},
                {"role": "user", "content": prompt},
            ],
            "temperature": 0.0,
            "max_tokens": 1200,
        }
        try:
            raw = ProductAttributeService._call_llm(payload)
            cleaned = ProductAttributeService._clean_json(raw)
            parsed = json.loads(cleaned)
        except Exception as exc:
            logger.error(f"LLM failed: {exc}")
            return {
                "mandatory": {a: [{"value": "Not Specified", "source": "llm_error"}] for a in mandatory_attrs},
                "additional": {},
                "error": str(exc),
            }

        # --------------------------- VALIDATION + SMART RECOVERY ---------------------------
        pt_emb = EmbeddingCache.get_embedding(product_text, model_embedder)

        def _sanitize(section: dict, allowed: Dict):
            sanitized = {}
            for attr, items in section.items():
                if attr not in allowed:
                    continue
                chosen = []
                for it in (items if isinstance(items, list) else [items]):
                    if not isinstance(it, dict):
                        it = {"value": str(it), "source": "not_found"}
                    val = str(it.get("value", "")).strip()
                    src = str(it.get("source", "not_found")).lower()
                    # --- LLM SAYS "Not Specified" → SMART RECOVERY ---
                    if val == "Not Specified":
                        # 1. Lexical recovery
                        for av in allowed[attr]:
                            if ProductAttributeService._lexical_evidence(product_text, av) > 0.6:
                                src = ProductAttributeService._find_source(av, source_map)
                                chosen.append({"value": av, "source": src})
                                break
                        else:
                            # 2. Semantic recovery
                            best_val, best_score = max(
                                ((av, float(util.cos_sim(pt_emb, EmbeddingCache.get_embedding(av, model_embedder)).item()))
                                 for av in allowed[attr]),
                                key=lambda x: x[1],
                                default=("", 0.0),  # guard: empty whitelist for this attribute
                            )
                            if best_score > 0.75:
                                src = ProductAttributeService._find_source(best_val, source_map)
                                chosen.append({"value": best_val, "source": src})
                            else:
                                chosen.append({"value": "Not Specified", "source": "not_found"})
                        continue
                    # --- VALIDATE LLM CHOICE ---
                    if val not in allowed[attr]:
                        continue
                    if ProductAttributeService._lexical_evidence(product_text, val) < 0.2:
                        continue
                    if src not in source_map and src != "not_found":
                        src = "not_found"
                    chosen.append({"value": val, "source": src})
                sanitized[attr] = chosen or [{"value": "Not Specified", "source": "not_found"}]
            return sanitized

        parsed["mandatory"] = _sanitize(parsed.get("mandatory", {}), mandatory_attrs)

        # --- ADDITIONAL ATTRIBUTES ---
        if extract_additional and "additional" in parsed:
            additional = {}
            for attr, items in parsed["additional"].items():
                good = []
                for it in (items if isinstance(items, list) else [items]):
                    if not isinstance(it, dict):
                        it = {"value": str(it), "source": "not_found"}
                    val = str(it.get("value", "")).strip()
                    src = str(it.get("source", "not_found")).lower()
                    if src not in source_map and src != "not_found":
                        src = "not_found"
                    if val:
                        good.append({"value": val, "source": src})
                if good:
                    additional[attr] = good
            parsed["additional"] = additional
        else:
            parsed.pop("additional", None)

        if use_cache and cache_key:
            SimpleCache.set(cache_key, parsed)
            logger.info(f"CACHE SET {cache_key[:16]}...")
        return parsed

    @staticmethod
    def get_cache_stats() -> Dict:
        return {
            "global_enabled": is_caching_enabled(),
            "result_cache": SimpleCache.get_stats(),
            "embedding_cache": EmbeddingCache.get_stats(),
        }

    @staticmethod
    def clear_all_caches():
        SimpleCache.clear()
        EmbeddingCache.clear()
        logger.info("All caches cleared")
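
# --------------------------------------------------------------------------- #
# USAGE SKETCH
# --------------------------------------------------------------------------- #
# A minimal, illustrative smoke test of the pure helpers; the product data is
# made up, and extract_attributes is deliberately not called because it needs
# GROQ credentials and a configured Django settings module. Run with
# `python -m <package>.services` (the relative import above rules out direct
# script execution).
if __name__ == "__main__":
    text, sources = ProductAttributeService.combine_product_text(
        title="Stainless Steel Water Bottle",
        short_desc="1L vacuum-insulated flask",
    )
    print(text)     # combined prompt text
    print(sources)  # {"title": ..., "short_desc": ...}
    print(ProductAttributeService._lexical_evidence(text, "Stainless Steel"))  # 1.0
    print(ProductAttributeService.format_visual_attributes({"color": ["silver"]}))
    print(ProductAttributeService.get_cache_stats())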