services.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. # ==================== services.py (WITH USER VALUE REASONING) ====================
  2. import json
  3. import hashlib
  4. import logging
  5. import time
  6. from functools import wraps
  7. from typing import Dict, List, Optional, Tuple
  8. import requests
  9. from django.conf import settings
  10. from .cache_config import (
  11. is_caching_enabled,
  12. ENABLE_ATTRIBUTE_EXTRACTION_CACHE,
  13. ATTRIBUTE_CACHE_MAX_SIZE,
  14. )
  15. logger = logging.getLogger(__name__)
  16. # --------------------------------------------------------------------------- #
  17. # CACHES
  18. # --------------------------------------------------------------------------- #
  19. class SimpleCache:
  20. _cache = {}
  21. _max_size = ATTRIBUTE_CACHE_MAX_SIZE
  22. @classmethod
  23. def get(cls, key: str) -> Optional[Dict]:
  24. if not ENABLE_ATTRIBUTE_EXTRACTION_CACHE: return None
  25. return cls._cache.get(key)
  26. @classmethod
  27. def set(cls, key: str, value: Dict):
  28. if not ENABLE_ATTRIBUTE_EXTRACTION_CACHE: return
  29. if len(cls._cache) >= cls._max_size:
  30. items = list(cls._cache.items())
  31. cls._cache = dict(items[int(cls._max_size * 0.2):])
  32. cls._cache[key] = value
  33. @classmethod
  34. def clear(cls): cls._cache.clear()
  35. @classmethod
  36. def get_stats(cls) -> Dict:
  37. return {
  38. "enabled": ENABLE_ATTRIBUTE_EXTRACTION_CACHE,
  39. "size": len(cls._cache),
  40. "max_size": cls._max_size,
  41. "usage_percent": round(len(cls._cache)/cls._max_size*100, 2) if cls._max_size else 0
  42. }
  43. # --------------------------------------------------------------------------- #
  44. # RETRY DECORATOR
  45. # --------------------------------------------------------------------------- #
  46. def retry(max_attempts=3, delay=1.0):
  47. def decorator(f):
  48. @wraps(f)
  49. def wrapper(*args, **kwargs):
  50. last_exc = None
  51. for i in range(max_attempts):
  52. try:
  53. return f(*args, **kwargs)
  54. except Exception as e:
  55. last_exc = e
  56. if i < max_attempts - 1:
  57. wait = delay * (2 ** i)
  58. logger.warning(f"Retry {i+1}/{max_attempts} after {wait}s: {e}")
  59. time.sleep(wait)
  60. raise last_exc or RuntimeError("Retry failed")
  61. return wrapper
  62. return decorator
  63. # --------------------------------------------------------------------------- #
  64. # MAIN SERVICE
  65. # --------------------------------------------------------------------------- #
  66. class ProductAttributeService:
  67. @staticmethod
  68. def combine_product_text(title=None, short_desc=None, long_desc=None, ocr_text=None) -> Tuple[str, Dict[str, str]]:
  69. parts = []
  70. source_map = {}
  71. if title:
  72. t = str(title).strip()
  73. parts.append(f"Title: {t}")
  74. source_map["title"] = t
  75. if short_desc:
  76. s = str(short_desc).strip()
  77. parts.append(f"Description: {s}")
  78. source_map["short_desc"] = s
  79. if long_desc:
  80. l = str(long_desc).strip()
  81. parts.append(f"Details: {l}")
  82. source_map["long_desc"] = l
  83. if ocr_text:
  84. parts.append(f"OCR Text: {ocr_text}")
  85. source_map["ocr_text"] = ocr_text
  86. combined = "\n".join(parts).strip()
  87. return (combined or "No product information", source_map)
  88. @staticmethod
  89. def _cache_key(product_text: str, mandatory_attrs: Dict, extract_additional: bool, multiple: List[str], user_values: Dict = None) -> str:
  90. payload = {
  91. "text": product_text,
  92. "attrs": mandatory_attrs,
  93. "extra": extract_additional,
  94. "multiple": sorted(multiple),
  95. "user_values": user_values or {}
  96. }
  97. return f"attr_{hashlib.md5(json.dumps(payload, sort_keys=True).encode()).hexdigest()}"
  98. @staticmethod
  99. def _clean_json(text: str) -> str:
  100. start = text.find("{")
  101. end = text.rfind("}") + 1
  102. if start != -1 and end > start:
  103. text = text[start:end]
  104. if "```json" in text:
  105. text = text.split("```json", 1)[1].split("```", 1)[0]
  106. elif "```" in text:
  107. text = text.split("```", 1)[1].split("```", 1)[0]
  108. if text.lstrip().startswith("json"): text = text[4:]
  109. return text.strip()
  110. @staticmethod
  111. def format_visual_attributes(visual_attributes: Dict) -> Dict:
  112. formatted = {}
  113. for key, value in visual_attributes.items():
  114. if isinstance(value, list):
  115. formatted[key] = [{"value": str(item), "source": "image"} for item in value]
  116. elif isinstance(value, dict):
  117. nested = {}
  118. for sub_key, sub_val in value.items():
  119. if isinstance(sub_val, list):
  120. nested[sub_key] = [{"value": str(v), "source": "image"} for v in sub_val]
  121. else:
  122. nested[sub_key] = [{"value": str(sub_val), "source": "image"}]
  123. formatted[key] = nested
  124. else:
  125. formatted[key] = [{"value": str(value), "source": "image"}]
  126. return formatted
  127. @staticmethod
  128. @retry(max_attempts=3, delay=1.0)
  129. def _call_llm(payload: dict) -> str:
  130. headers = {"Authorization": f"Bearer {settings.GROQ_API_KEY}", "Content-Type": "application/json"}
  131. resp = requests.post(settings.GROQ_API_URL, headers=headers, json=payload, timeout=30)
  132. resp.raise_for_status()
  133. return resp.json()["choices"][0]["message"]["content"]
  134. @staticmethod
  135. def extract_attributes(
  136. product_text: str,
  137. mandatory_attrs: Dict[str, List[str]],
  138. source_map: Dict[str, str] = None,
  139. model: str = None,
  140. extract_additional: bool = True,
  141. multiple: Optional[List[str]] = None,
  142. use_cache: Optional[bool] = None,
  143. user_entered_values: Optional[Dict[str, str]] = None, # NEW PARAMETER
  144. ) -> dict:
  145. if model is None: model = settings.SUPPORTED_MODELS[0]
  146. if multiple is None: multiple = []
  147. if source_map is None: source_map = {}
  148. if user_entered_values is None: user_entered_values = {}
  149. if use_cache is None: use_cache = ENABLE_ATTRIBUTE_EXTRACTION_CACHE
  150. if not is_caching_enabled(): use_cache = False
  151. cache_key = None
  152. if use_cache:
  153. cache_key = ProductAttributeService._cache_key(
  154. product_text, mandatory_attrs, extract_additional, multiple, user_entered_values
  155. )
  156. cached = SimpleCache.get(cache_key)
  157. if cached:
  158. logger.info(f"CACHE HIT {cache_key[:16]}...")
  159. return cached
  160. # --------------------------- BUILD USER VALUES SECTION ---------------------------
  161. user_values_section = ""
  162. if user_entered_values:
  163. user_lines = []
  164. for attr, value in user_entered_values.items():
  165. user_lines.append(f" - {attr}: {value}")
  166. user_values_section = f"""
  167. USER MANUALLY ENTERED VALUES:
  168. {chr(10).join(user_lines)}
  169. IMPORTANT INSTRUCTIONS FOR USER VALUES:
  170. 1. Compare the user-entered value with what you find in the product text
  171. 2. Evaluate if the user value is correct, partially correct, or incorrect for this product
  172. 3. Choose the BEST value (could be user's value, or from allowed list, or inferred)
  173. 4. Always provide a "reason" field explaining your decision
  174. 5. DO NOT hallucinate - be honest if user's value seems wrong based on product evidence
  175. 6. If user's value is not in the allowed list but seems correct, chose the most nearest value from the allowed list with proper reasoning.
  176. """
  177. # --------------------------- PROMPT ---------------------------
  178. allowed_lines = [f"{attr}: {', '.join(vals)}" for attr, vals in mandatory_attrs.items()]
  179. allowed_text = "\n".join(allowed_lines)
  180. allowed_sources = list(source_map.keys()) + ["title", "description", "inferred"]
  181. source_hint = "|".join(allowed_sources)
  182. multiple_text = f"\nMULTIPLE ALLOWED FOR: {', '.join(multiple)}" if multiple else ""
  183. print("Multiple text for attr: ")
  184. print(multiple_text)
  185. additional_instructions = """
  186. For the 'additional' section, identify any other important product attributes and their values (e.g., 'Color', 'Material', 'Weight' etc) that are present in the PRODUCT TEXT but not in the Mandatory Attribute list.
  187. For each additional attribute, use the best available value from the PRODUCT TEXT and specify the 'source'.
  188. """ if extract_additional else ""
  189. prompt = f"""
  190. You are a product-attribute classifier and validator.
  191. Understand the product text very deeply. If the same product is available somewhere online, use that knowledge to predict accurate attribute values.
  192. Do not depend only on word-by-word matching from the product text - interpret the meaning and suggest attributes intelligently.
  193. Pick the *closest meaning* value from the allowed list, even if not an exact word match.
  194. I want values for all mandatory attributes.
  195. If a value is not found anywhere, the source should be "inferred".
  196. Note: Source means from where you have concluded the result. Choose one of these value <{source_hint}>
  197. ALLOWED VALUES (MANDATORY):
  198. {allowed_text}
  199. Note: Always return multiple values for these attributes: {multiple_text}. These values must be most possible values from the list and should be max 2 values.
  200. {user_values_section}
  201. {additional_instructions}
  202. PRODUCT TEXT:
  203. {product_text}
  204. OUTPUT (strict JSON only):
  205. {{
  206. "mandatory": {{
  207. "<attr>": [{{
  208. "value": "<chosen_value>",
  209. "source": "<{source_hint}>",
  210. "reason": "Explanation of why this value was chosen. If user provided a value, explain why you agreed/disagreed with it.",
  211. "original_value": "<user_entered_value_if_provided>",
  212. "decision": "accepted|rejected"
  213. }}]
  214. }},
  215. "additional": {{
  216. "Additional_Attr_1": [{{
  217. "value": "Value 1",
  218. "source": "<{source_hint}>",
  219. "reason": "Why this attribute and value were identified"
  220. }}]
  221. }}
  222. }}
  223. RULES:
  224. - For each mandatory attribute with a user-entered value, include "original_value" and "decision" fields
  225. - "decision" values: "accepted" (used user's value), "rejected" (used different value), "not_provided" (no user value given)
  226. - "reason" must explain your choice, especially when rejecting user input
  227. - For 'additional' attributes: Strictly Extract other key attributes other than mandatory attributes from the text.
  228. - For 'multiple' attributes, always give multiple value for those attribues, choose wisely and max 2 multiple attribute that are very close.
  229. - Source must be one of: {source_hint}
  230. - Be honest and specific in your reasoning.
  231. - Return ONLY valid JSON
  232. """
  233. payload = {
  234. "model": model,
  235. "messages": [
  236. {"role": "system", "content": "You are a JSON-only extractor and validator. Always provide clear reasoning for your decisions."},
  237. {"role": "user", "content": prompt},
  238. ],
  239. "temperature": 0.3,
  240. "max_tokens": 2000, # Increased for reasoning
  241. }
  242. try:
  243. raw = ProductAttributeService._call_llm(payload)
  244. logger.info("Raw LLM response received")
  245. print(raw)
  246. cleaned = ProductAttributeService._clean_json(raw)
  247. parsed = json.loads(cleaned)
  248. except Exception as exc:
  249. logger.error(f"LLM failed: {exc}")
  250. return {
  251. "mandatory": {
  252. a: [{
  253. "value": "Not Specified",
  254. "source": "llm_error",
  255. "reason": f"LLM processing failed: {str(exc)}"
  256. }] for a in mandatory_attrs
  257. },
  258. "additional": {} if not extract_additional else {},
  259. "error": str(exc)
  260. }
  261. if use_cache and cache_key:
  262. SimpleCache.set(cache_key, parsed)
  263. logger.info(f"CACHE SET {cache_key[:16]}...")
  264. return parsed
  265. @staticmethod
  266. def get_cache_stats() -> Dict:
  267. return {
  268. "global_enabled": is_caching_enabled(),
  269. "result_cache": SimpleCache.get_stats(),
  270. }
  271. @staticmethod
  272. def clear_all_caches():
  273. SimpleCache.clear()
  274. logger.info("All caches cleared")