"""Product attribute extraction service (whitelisted values + semantic recovery).

Pipeline:

1. ``ProductAttributeService.combine_product_text`` merges the product fields
   into one text block and records which field each piece came from.
2. ``ProductAttributeService.extract_attributes`` asks an LLM (the endpoint in
   ``settings.GROQ_API_URL``) to pick, for every mandatory attribute, a value
   from a caller-supplied whitelist.
3. The LLM answer is validated against the whitelist. When the LLM answers
   "Not Specified" (or omits an attribute), a lexical and then a
   sentence-embedding recovery pass try to find an allowed value supported by
   the product text.

Extraction results and embeddings are kept in small in-process caches whose
switches and sizes come from ``cache_config``.
"""
import copy
import hashlib
import json
import logging
import os
import time
import warnings
from functools import wraps
from typing import Callable, Dict, List, Optional, Tuple

import requests
from django.conf import settings
from sentence_transformers import SentenceTransformer, util

# --------------------------------------------------------------------------- #
# CACHE CONFIG
# --------------------------------------------------------------------------- #
from .cache_config import (
    is_caching_enabled,
    ENABLE_ATTRIBUTE_EXTRACTION_CACHE,
    ENABLE_EMBEDDING_CACHE,
    ATTRIBUTE_CACHE_MAX_SIZE,
    EMBEDDING_CACHE_MAX_SIZE,
)

logger = logging.getLogger(__name__)

# Sentinel value used throughout for "no allowed value applies".
_NOT_SPECIFIED = "Not Specified"

# --------------------------------------------------------------------------- #
# ONE-TIME MODEL LOAD
# --------------------------------------------------------------------------- #
# Set before the model (and its tokenizer) is created so the setting is already
# in effect for every tokenizer call made by this process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

logger.info("Loading sentence transformer model (semantic recovery)...")
model_embedder = SentenceTransformer("all-MiniLM-L6-v2")
logger.info("Model loaded")


# --------------------------------------------------------------------------- #
# CACHES
# --------------------------------------------------------------------------- #
def _evict_oldest(cache: Dict, max_size: int, fraction: float) -> Dict:
    """Return a new dict without the oldest ``fraction`` of ``max_size`` entries.

    Relies on dict insertion order, so the first items are the oldest. At least
    one entry is always dropped, which keeps the cache bounded even when
    ``max_size * fraction`` rounds down to zero (small ``max_size``).
    """
    drop = max(1, int(max_size * fraction))
    return dict(list(cache.items())[drop:])


class SimpleCache:
    """Process-local, size-bounded cache for attribute-extraction results.

    When full, the oldest 20% of entries (by insertion order) are evicted.
    Values are deep-copied on the way in and out so that callers mutating a
    returned result cannot corrupt the cached copy. A falsy ``_max_size``
    disables storing.
    """

    _cache: Dict[str, Dict] = {}
    _max_size = ATTRIBUTE_CACHE_MAX_SIZE
    _evict_fraction = 0.2

    @classmethod
    def get(cls, key: str) -> Optional[Dict]:
        """Return a private copy of the cached value for ``key``, or ``None``."""
        if not ENABLE_ATTRIBUTE_EXTRACTION_CACHE:
            return None
        value = cls._cache.get(key)
        return copy.deepcopy(value) if value is not None else None

    @classmethod
    def set(cls, key: str, value: Dict):
        """Store a copy of ``value`` under ``key``, evicting old entries if full."""
        if not ENABLE_ATTRIBUTE_EXTRACTION_CACHE or not cls._max_size:
            return
        # Overwriting an existing key does not grow the cache, so no eviction.
        if key not in cls._cache and len(cls._cache) >= cls._max_size:
            cls._cache = _evict_oldest(cls._cache, cls._max_size, cls._evict_fraction)
        cls._cache[key] = copy.deepcopy(value)

    @classmethod
    def clear(cls):
        """Drop every cached result."""
        cls._cache.clear()

    @classmethod
    def get_stats(cls) -> Dict:
        """Return enablement, size and fill percentage of the result cache."""
        return {
            "enabled": ENABLE_ATTRIBUTE_EXTRACTION_CACHE,
            "size": len(cls._cache),
            "max_size": cls._max_size,
            "usage_percent": round(len(cls._cache) / cls._max_size * 100, 2) if cls._max_size else 0,
        }


class EmbeddingCache:
    """Process-local, size-bounded cache of sentence embeddings keyed by text.

    The key is the text only, not the model: this module only ever passes
    ``model_embedder``. When full, the oldest 30% of entries are evicted.
    Hit/miss counters feed ``get_stats``.
    """

    _cache: Dict[str, object] = {}
    _max_size = EMBEDDING_CACHE_MAX_SIZE
    _evict_fraction = 0.3
    _hit = _miss = 0

    @staticmethod
    def _encode(text: str, model):
        """Encode ``text`` to a tensor with library warnings silenced."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return model.encode(text, convert_to_tensor=True, show_progress_bar=False)

    @classmethod
    def get_embedding(cls, text: str, model):
        """Return the embedding for ``text``, computing and caching it on a miss."""
        if not ENABLE_EMBEDDING_CACHE:
            return cls._encode(text, model)
        if text in cls._cache:
            cls._hit += 1
            return cls._cache[text]
        cls._miss += 1
        emb = cls._encode(text, model)
        if cls._max_size:
            if len(cls._cache) >= cls._max_size:
                cls._cache = _evict_oldest(cls._cache, cls._max_size, cls._evict_fraction)
            cls._cache[text] = emb
        return emb

    @classmethod
    def clear(cls):
        """Drop every cached embedding and reset the hit/miss counters."""
        cls._cache.clear()
        cls._hit = cls._miss = 0

    @classmethod
    def get_stats(cls) -> Dict:
        """Return enablement, size, hit/miss counts and hit rate."""
        total = cls._hit + cls._miss
        rate = (cls._hit / total * 100) if total else 0
        return {
            "enabled": ENABLE_EMBEDDING_CACHE,
            "size": len(cls._cache),
            "hits": cls._hit,
            "misses": cls._miss,
            "hit_rate_percent": round(rate, 2),
        }


# --------------------------------------------------------------------------- #
# RETRY DECORATOR
# --------------------------------------------------------------------------- #
def retry(max_attempts=3, delay=1.0):
    """Retry the decorated call with exponential back-off.

    After the i-th failed attempt (0-based) the wrapper sleeps
    ``delay * 2**i`` seconds. Any ``Exception`` triggers a retry. Once
    ``max_attempts`` attempts have failed the last exception is re-raised.
    If ``max_attempts < 1`` the function is never called and ``RuntimeError``
    is raised.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            last_exc = None
            for attempt in range(max_attempts):
                try:
                    return func(*args, **kwargs)
                except Exception as exc:
                    last_exc = exc
                    if attempt < max_attempts - 1:
                        wait = delay * (2 ** attempt)
                        logger.warning("Retry %s/%s after %ss: %s", attempt + 1, max_attempts, wait, exc)
                        time.sleep(wait)
            raise last_exc or RuntimeError("Retry failed")
        return wrapper
    return decorator


# --------------------------------------------------------------------------- #
# MAIN SERVICE
# --------------------------------------------------------------------------- #
class ProductAttributeService:
    """LLM-backed product attribute classifier with whitelist validation."""

    # Share of a value's tokens that must appear in the product text to
    # recover that value when the LLM reported "Not Specified".
    LEXICAL_RECOVERY_THRESHOLD = 0.6
    # Cosine similarity (product text vs. allowed value) needed for the
    # embedding-based fallback recovery.
    SEMANTIC_RECOVERY_THRESHOLD = 0.75
    # Minimum lexical evidence required to accept a value the LLM did pick.
    LEXICAL_VALIDATION_THRESHOLD = 0.2

    @staticmethod
    def combine_product_text(title=None, short_desc=None, long_desc=None, ocr_text=None) -> Tuple[str, Dict[str, str]]:
        """Merge product fields into one labelled text block.

        Every field is converted to ``str`` and stripped. Fields that are
        missing or blank are skipped.

        Returns:
            ``(combined_text, source_map)``. ``source_map`` maps a source key
            (``title``, ``short_desc``, ``long_desc``, ``ocr_text``) to that
            field's stripped text. ``combined_text`` is
            ``"No product information"`` when every field is empty.
        """
        fields = (
            ("Title", "title", title),
            ("Description", "short_desc", short_desc),
            ("Details", "long_desc", long_desc),
            ("OCR Text", "ocr_text", ocr_text),
        )
        parts = []
        source_map = {}
        for label, key, raw in fields:
            if not raw:
                continue
            text = str(raw).strip()
            if not text:
                continue
            parts.append(f"{label}: {text}")
            source_map[key] = text
        combined = "\n".join(parts).strip()
        return (combined or "No product information", source_map)

    @staticmethod
    def _cache_key(
        product_text: str,
        mandatory_attrs: Dict,
        extract_additional: bool,
        multiple: List[str],
        model: Optional[str] = None,
        source_map: Optional[Dict[str, str]] = None,
    ) -> str:
        """Build a deterministic cache key from every input that shapes the result.

        ``model`` and ``source_map`` are included when given. Different LLM
        models or source maps produce different answers and sources, so they
        must not share a cache entry.
        """
        payload = {
            "text": product_text,
            "attrs": mandatory_attrs,
            "extra": extract_additional,
            "multiple": sorted(multiple),
        }
        if model is not None:
            payload["model"] = model
        if source_map is not None:
            payload["sources"] = source_map
        digest = hashlib.md5(json.dumps(payload, sort_keys=True).encode()).hexdigest()
        return f"attr_{digest}"

    @staticmethod
    def _clean_json(text: str) -> str:
        """Extract the JSON object from a raw LLM reply.

        The outermost ``{ ... }`` span is taken as-is. Markdown fences and a
        leading ``json`` tag are stripped only when no braces are present, so
        a literal "```" inside a JSON string value cannot truncate the payload.
        """
        start = text.find("{")
        end = text.rfind("}") + 1
        if start != -1 and end > start:
            return text[start:end].strip()
        if "```json" in text:
            text = text.split("```json", 1)[1].split("```", 1)[0]
        elif "```" in text:
            text = text.split("```", 1)[1].split("```", 1)[0]
        text = text.strip()
        if text.startswith("json"):
            text = text[4:]
        return text.strip()

    @staticmethod
    def _lexical_evidence(product_text: str, label: str) -> float:
        """Return the fraction (0..1) of ``label``'s tokens found in the text.

        Matching is case-insensitive substring containment, and hyphens split
        tokens. Returns 0.0 for a label with no tokens.
        """
        pt = product_text.lower()
        tokens = [t for t in label.lower().replace("-", " ").split() if t]
        if not tokens:
            return 0.0
        hits = sum(1 for t in tokens if t in pt)
        return hits / len(tokens)

    @staticmethod
    def _find_source(value: str, source_map: Dict[str, str]) -> str:
        """Return the first source key whose text contains ``value`` (case-insensitive).

        An empty value, or one found in no source, yields ``"not_found"``.
        """
        needle = value.lower()
        if not needle:
            return "not_found"
        return next(
            (key for key, text in source_map.items() if needle in str(text).lower()),
            "not_found",
        )

    @staticmethod
    def format_visual_attributes(visual_attributes: Dict) -> Dict:
        """Wrap image-derived attributes in ``[{"value": str, "source": "image"}]`` lists.

        Lists become one item per element. Dicts are formatted one level deep,
        per sub-key. Any other value becomes a single item.
        """
        def as_items(value):
            values = value if isinstance(value, list) else [value]
            return [{"value": str(v), "source": "image"} for v in values]

        formatted = {}
        for key, value in visual_attributes.items():
            if isinstance(value, dict):
                formatted[key] = {sub_key: as_items(sub_val) for sub_key, sub_val in value.items()}
            else:
                formatted[key] = as_items(value)
        return formatted

    @staticmethod
    @retry(max_attempts=3, delay=1.0)
    def _call_llm(payload: dict) -> str:
        """POST ``payload`` to the configured chat endpoint and return the reply text.

        Raises on HTTP errors or an unexpected response shape. The ``retry``
        decorator re-attempts with exponential back-off.
        """
        headers = {"Authorization": f"Bearer {settings.GROQ_API_KEY}", "Content-Type": "application/json"}
        resp = requests.post(settings.GROQ_API_URL, headers=headers, json=payload, timeout=30)
        resp.raise_for_status()
        return resp.json()["choices"][0]["message"]["content"]

    @staticmethod
    def _build_prompt(
        product_text: str,
        mandatory_attrs: Dict[str, List[str]],
        source_map: Dict[str, str],
        multiple: List[str],
    ) -> str:
        """Render the classification prompt containing the full value whitelist."""
        allowed_text = "\n".join(
            f"{attr}: {', '.join(str(v) for v in vals)}" for attr, vals in mandatory_attrs.items()
        )
        source_hint = "|".join(list(source_map.keys()) + ["not_found"])
        multiple_text = f"\nMULTIPLE ALLOWED FOR: {', '.join(multiple)}" if multiple else ""
        return f"""
You are a product-attribute classifier.
Pick **exactly one** value from the list below for each attribute (unless the attribute is listed under MULTIPLE ALLOWED FOR).
If nothing matches, return "Not Specified".

ALLOWED VALUES:
{allowed_text}
{multiple_text}

PRODUCT TEXT:
{product_text}

OUTPUT (strict JSON only):
{{
  "mandatory": {{
    "<attribute name>": [{{"value": "<allowed value>", "source": "<{source_hint}>"}}]
  }},
  "additional": {{}}
}}

RULES:
- Pick from allowed values only
- If not found: "Not Specified" + "not_found"
- Source must be: {source_hint}
- Return ONLY JSON
"""

    @staticmethod
    def _error_result(mandatory_attrs: Dict, message: str) -> dict:
        """Build the result returned when the LLM call or its JSON is unusable."""
        return {
            "mandatory": {a: [{"value": _NOT_SPECIFIED, "source": "llm_error"}] for a in mandatory_attrs},
            "additional": {},
            "error": message,
        }

    @staticmethod
    def _recover_value(
        product_text: str,
        allowed_values: List[str],
        source_map: Dict[str, str],
        product_embedding: Callable[[], object],
    ) -> Dict[str, str]:
        """Try to find an allowed value the LLM missed.

        1. Lexical: the first allowed value whose lexical evidence exceeds
           ``LEXICAL_RECOVERY_THRESHOLD``.
        2. Semantic: the allowed value whose embedding is most cosine-similar
           to the product text, if the similarity exceeds
           ``SEMANTIC_RECOVERY_THRESHOLD``.

        ``product_embedding`` is called lazily, and only when step 2 runs.
        Falls back to ``{"value": "Not Specified", "source": "not_found"}``.
        """
        svc = ProductAttributeService
        for candidate in allowed_values:
            if svc._lexical_evidence(product_text, candidate) > svc.LEXICAL_RECOVERY_THRESHOLD:
                return {"value": candidate, "source": svc._find_source(candidate, source_map)}

        # max() below would raise ValueError on an empty whitelist.
        if allowed_values:
            pt_emb = product_embedding()
            best_val, best_score = max(
                (
                    (
                        candidate,
                        float(util.cos_sim(pt_emb, EmbeddingCache.get_embedding(candidate, model_embedder)).item()),
                    )
                    for candidate in allowed_values
                ),
                key=lambda pair: pair[1],
            )
            if best_score > svc.SEMANTIC_RECOVERY_THRESHOLD:
                return {"value": best_val, "source": svc._find_source(best_val, source_map)}
        return {"value": _NOT_SPECIFIED, "source": "not_found"}

    @staticmethod
    def _finalize_choices(chosen: List[Dict[str, str]], allow_multiple: bool) -> List[Dict[str, str]]:
        """Deduplicate by value and prefer concrete values over "Not Specified".

        Attributes that do not allow multiple values keep only the first
        choice. The result is never empty.
        """
        seen = set()
        unique = []
        for choice in chosen:
            if choice["value"] not in seen:
                seen.add(choice["value"])
                unique.append(choice)
        specific = [c for c in unique if c["value"] != _NOT_SPECIFIED]
        result = specific or unique
        if not allow_multiple:
            result = result[:1]
        return result or [{"value": _NOT_SPECIFIED, "source": "not_found"}]

    @staticmethod
    def _sanitize_mandatory(
        section,
        mandatory_attrs: Dict[str, List[str]],
        product_text: str,
        source_map: Dict[str, str],
        multiple: List[str],
        product_embedding: Callable[[], object],
    ) -> Dict[str, List[Dict[str, str]]]:
        """Validate the LLM's ``mandatory`` section against the whitelist.

        Every attribute in ``mandatory_attrs`` appears in the output, in that
        order. Attributes the LLM omitted, left empty, or marked
        "Not Specified" go through ``_recover_value``. Picked values are
        matched case-insensitively to the caller's spelling. They are dropped
        if not whitelisted or if lexical evidence is below
        ``LEXICAL_VALIDATION_THRESHOLD``. Unknown sources become
        ``"not_found"``.
        """
        svc = ProductAttributeService
        if not isinstance(section, dict):
            section = {}
        multiple_attrs = set(multiple)
        sanitized = {}
        for attr, allowed in mandatory_attrs.items():
            # Map lower-cased spellings back to the caller's canonical value.
            canonical = {str(v).lower(): v for v in allowed}
            items = section.get(attr)
            if not items:
                # Omitted/empty is treated exactly like an explicit "Not Specified".
                items = [{"value": _NOT_SPECIFIED, "source": "not_found"}]
            elif not isinstance(items, list):
                items = [items]

            chosen = []
            for item in items:
                if not isinstance(item, dict):
                    item = {"value": str(item), "source": "not_found"}
                val = str(item.get("value", "")).strip()
                src = str(item.get("source", "not_found")).lower()

                if val.lower() == _NOT_SPECIFIED.lower():
                    chosen.append(svc._recover_value(product_text, allowed, source_map, product_embedding))
                    continue

                match = canonical.get(val.lower())
                if match is None:
                    continue
                if svc._lexical_evidence(product_text, match) < svc.LEXICAL_VALIDATION_THRESHOLD:
                    continue
                if src not in source_map and src != "not_found":
                    src = "not_found"
                chosen.append({"value": match, "source": src})

            sanitized[attr] = svc._finalize_choices(chosen, attr in multiple_attrs)
        return sanitized

    @staticmethod
    def _sanitize_additional(section, source_map: Dict[str, str]) -> Dict[str, List[Dict[str, str]]]:
        """Normalise the LLM's free-form ``additional`` section.

        Non-dict items are wrapped, empty values are dropped, unknown sources
        become ``"not_found"``, and attributes left with no values are omitted.
        """
        if not isinstance(section, dict):
            return {}
        additional = {}
        for attr, items in section.items():
            good = []
            for item in items if isinstance(items, list) else [items]:
                if not isinstance(item, dict):
                    item = {"value": str(item), "source": "not_found"}
                val = str(item.get("value", "")).strip()
                src = str(item.get("source", "not_found")).lower()
                if src not in source_map and src != "not_found":
                    src = "not_found"
                if val:
                    good.append({"value": val, "source": src})
            if good:
                additional[attr] = good
        return additional

    @staticmethod
    def extract_attributes(
        product_text: str,
        mandatory_attrs: Dict[str, List[str]],
        source_map: Optional[Dict[str, str]] = None,
        model: Optional[str] = None,
        extract_additional: bool = True,
        multiple: Optional[List[str]] = None,
        use_cache: Optional[bool] = None,
    ) -> dict:
        """Classify ``product_text`` into the whitelisted attribute values.

        Args:
            product_text: Combined product text, e.g. from ``combine_product_text``.
            mandatory_attrs: Attribute name -> list of allowed values.
            source_map: Source key -> source text, used to attribute values.
            model: LLM model id; defaults to ``settings.SUPPORTED_MODELS[0]``.
            extract_additional: Keep and sanitise the LLM's ``additional`` section.
            multiple: Attribute names that may hold more than one value.
            use_cache: Override for the result cache. Caching is always off
                when ``is_caching_enabled()`` is false.

        Returns:
            ``{"mandatory": {...}, "additional": {...}}``. ``additional`` is
            absent when ``extract_additional`` is false. Other top-level keys
            the LLM returned are passed through. If the LLM call or its JSON
            fails, every attribute is ``"Not Specified"`` with source
            ``"llm_error"``, and an ``"error"`` message is included.
        """
        svc = ProductAttributeService
        if model is None:
            model = settings.SUPPORTED_MODELS[0]
        if multiple is None:
            multiple = []
        if source_map is None:
            source_map = {}
        if use_cache is None:
            use_cache = ENABLE_ATTRIBUTE_EXTRACTION_CACHE
        if not is_caching_enabled():
            use_cache = False

        cache_key = None
        if use_cache:
            cache_key = svc._cache_key(
                product_text, mandatory_attrs, extract_additional, multiple,
                model=model, source_map=source_map,
            )
            cached = SimpleCache.get(cache_key)
            if cached:
                logger.info("CACHE HIT %s...", cache_key[:16])
                return cached

        payload = {
            "model": model,
            "messages": [
                {"role": "system", "content": "You are a JSON-only extractor."},
                {"role": "user", "content": svc._build_prompt(product_text, mandatory_attrs, source_map, multiple)},
            ],
            "temperature": 0.0,
            "max_tokens": 1200,
        }

        try:
            raw = svc._call_llm(payload)
            parsed = json.loads(svc._clean_json(raw))
        except Exception as exc:  # service boundary: network, HTTP, response shape, JSON
            logger.error("LLM failed: %s", exc)
            return svc._error_result(mandatory_attrs, str(exc))
        if not isinstance(parsed, dict):
            message = f"LLM returned JSON {type(parsed).__name__}, expected an object"
            logger.error("LLM failed: %s", message)
            return svc._error_result(mandatory_attrs, message)

        # The product-text embedding is only needed for semantic recovery;
        # compute it at most once, and only on demand.
        embedding_memo: List[object] = []

        def product_embedding():
            if not embedding_memo:
                embedding_memo.append(EmbeddingCache.get_embedding(product_text, model_embedder))
            return embedding_memo[0]

        parsed["mandatory"] = svc._sanitize_mandatory(
            parsed.get("mandatory"), mandatory_attrs, product_text, source_map, multiple, product_embedding,
        )
        if extract_additional:
            parsed["additional"] = svc._sanitize_additional(parsed.get("additional"), source_map)
        else:
            parsed.pop("additional", None)

        if use_cache and cache_key:
            SimpleCache.set(cache_key, parsed)
            logger.info("CACHE SET %s...", cache_key[:16])
        return parsed

    @staticmethod
    def get_cache_stats() -> Dict:
        """Return the global caching switch plus per-cache statistics."""
        return {
            "global_enabled": is_caching_enabled(),
            "result_cache": SimpleCache.get_stats(),
            "embedding_cache": EmbeddingCache.get_stats(),
        }

    @staticmethod
    def clear_all_caches():
        """Empty both the result cache and the embedding cache."""
        SimpleCache.clear()
        EmbeddingCache.clear()
        logger.info("All caches cleared")