# ==================== services.py (FULL WHITELIST + SEMANTIC RECOVERY) ====================
import json
import hashlib
import logging
import os
import time
import warnings
from functools import wraps
from typing import Dict, List, Optional, Tuple

import requests
from django.conf import settings
from sentence_transformers import SentenceTransformer, util

# --------------------------------------------------------------------------- #
# CACHE CONFIG
# --------------------------------------------------------------------------- #
from .cache_config import (
    is_caching_enabled,
    ENABLE_ATTRIBUTE_EXTRACTION_CACHE,
    ENABLE_EMBEDDING_CACHE,
    ATTRIBUTE_CACHE_MAX_SIZE,
    EMBEDDING_CACHE_MAX_SIZE,
)

logger = logging.getLogger(__name__)

# --------------------------------------------------------------------------- #
# ONE-TIME MODEL LOAD
# --------------------------------------------------------------------------- #
# Disable tokenizer parallelism *before* the model loads so the setting
# actually takes effect (it is read at tokenizer initialization).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
logger.info("Loading sentence transformer model (semantic recovery)...")
model_embedder = SentenceTransformer("all-MiniLM-L6-v2")
logger.info("Model loaded")
# --------------------------------------------------------------------------- #
# CACHES
# --------------------------------------------------------------------------- #
class SimpleCache:
    """In-process cache for extraction results, with FIFO-style eviction."""

    _cache: Dict[str, Dict] = {}
    _max_size = ATTRIBUTE_CACHE_MAX_SIZE

    @classmethod
    def get(cls, key: str) -> Optional[Dict]:
        if not ENABLE_ATTRIBUTE_EXTRACTION_CACHE:
            return None
        return cls._cache.get(key)

    @classmethod
    def set(cls, key: str, value: Dict):
        if not ENABLE_ATTRIBUTE_EXTRACTION_CACHE:
            return
        if len(cls._cache) >= cls._max_size:
            # Evict the oldest ~20% of entries (dicts preserve insertion order).
            items = list(cls._cache.items())
            cls._cache = dict(items[int(cls._max_size * 0.2):])
        cls._cache[key] = value

    @classmethod
    def clear(cls):
        cls._cache.clear()

    @classmethod
    def get_stats(cls) -> Dict:
        return {
            "enabled": ENABLE_ATTRIBUTE_EXTRACTION_CACHE,
            "size": len(cls._cache),
            "max_size": cls._max_size,
            "usage_percent": round(len(cls._cache) / cls._max_size * 100, 2) if cls._max_size else 0,
        }
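# A minimal usage sketch for SimpleCache (illustrative helper; not called
# anywhere in this module, and a no-op when the attribute cache is disabled):
def _example_simple_cache_usage() -> Dict:
    SimpleCache.set("attr_demo", {"mandatory": {}})
    SimpleCache.get("attr_demo")  # -> {"mandatory": {}} when caching is enabled
    return SimpleCache.get_stats()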
class EmbeddingCache:
    """In-process cache of sentence-transformer embeddings, keyed by text."""

    _cache = {}
    _max_size = EMBEDDING_CACHE_MAX_SIZE
    _hit = _miss = 0

    @classmethod
    def get_embedding(cls, text: str, model):
        if not ENABLE_EMBEDDING_CACHE:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                return model.encode(text, convert_to_tensor=True, show_progress_bar=False)
        if text in cls._cache:
            cls._hit += 1
            return cls._cache[text]
        cls._miss += 1
        if len(cls._cache) >= cls._max_size:
            # Evict the oldest ~30% of entries before inserting.
            items = list(cls._cache.items())
            cls._cache = dict(items[int(cls._max_size * 0.3):])
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            emb = model.encode(text, convert_to_tensor=True, show_progress_bar=False)
        cls._cache[text] = emb
        return emb

    @classmethod
    def clear(cls):
        cls._cache.clear()
        cls._hit = cls._miss = 0

    @classmethod
    def get_stats(cls) -> Dict:
        total = cls._hit + cls._miss
        rate = (cls._hit / total * 100) if total else 0
        return {
            "enabled": ENABLE_EMBEDDING_CACHE,
            "size": len(cls._cache),
            "hits": cls._hit,
            "misses": cls._miss,
            "hit_rate_percent": round(rate, 2),
        }
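# Usage sketch (illustrative helper; not called anywhere in this module):
# two lookups of the same string cost one encode, then hit the cached tensor.
def _example_embedding_cache_usage() -> Dict:
    EmbeddingCache.get_embedding("stainless steel", model_embedder)  # miss: encodes
    EmbeddingCache.get_embedding("stainless steel", model_embedder)  # hit: cached
    return EmbeddingCache.get_stats()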
# --------------------------------------------------------------------------- #
# RETRY DECORATOR
# --------------------------------------------------------------------------- #
def retry(max_attempts=3, delay=1.0):
    """Retry the wrapped callable with exponential backoff (delay, 2*delay, ...)."""
    def decorator(f):
        @wraps(f)
        def wrapper(*args, **kwargs):
            last_exc = None
            for i in range(max_attempts):
                try:
                    return f(*args, **kwargs)
                except Exception as e:
                    last_exc = e
                    if i < max_attempts - 1:
                        wait = delay * (2 ** i)
                        logger.warning(f"Retry {i + 1}/{max_attempts} after {wait}s: {e}")
                        time.sleep(wait)
            raise last_exc or RuntimeError("Retry failed")
        return wrapper
    return decorator
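# Usage sketch for @retry (illustrative helper; with these arguments a failing
# call is attempted three times, sleeping 1s then 2s between attempts):
def _example_retry_usage():
    @retry(max_attempts=3, delay=1.0)
    def fetch():
        return requests.get("https://example.com", timeout=5)

    return fetch()  # re-raises the last exception if every attempt fails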
# --------------------------------------------------------------------------- #
# MAIN SERVICE
# --------------------------------------------------------------------------- #
class ProductAttributeService:
    @staticmethod
    def combine_product_text(title=None, short_desc=None, long_desc=None, ocr_text=None) -> Tuple[str, Dict[str, str]]:
        parts = []
        source_map = {}
        if title:
            t = str(title).strip()
            parts.append(f"Title: {t}")
            source_map["title"] = t
        if short_desc:
            s = str(short_desc).strip()
            parts.append(f"Description: {s}")
            source_map["short_desc"] = s
        if long_desc:
            d = str(long_desc).strip()
            parts.append(f"Details: {d}")
            source_map["long_desc"] = d
        if ocr_text:
            o = str(ocr_text).strip()
            parts.append(f"OCR Text: {o}")
            source_map["ocr_text"] = o
        combined = "\n".join(parts).strip()
        return (combined or "No product information", source_map)
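    # Example (illustrative):
    #   combine_product_text(title="QLED TV", short_desc="4K HDR") returns
    #   ("Title: QLED TV\nDescription: 4K HDR",
    #    {"title": "QLED TV", "short_desc": "4K HDR"})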
    @staticmethod
    def _cache_key(product_text: str, mandatory_attrs: Dict, extract_additional: bool, multiple: List[str]) -> str:
        payload = {"text": product_text, "attrs": mandatory_attrs, "extra": extract_additional, "multiple": sorted(multiple)}
        return f"attr_{hashlib.md5(json.dumps(payload, sort_keys=True).encode()).hexdigest()}"
    @staticmethod
    def _clean_json(text: str) -> str:
        """Strip markdown fences and surrounding prose from an LLM JSON reply."""
        if "```json" in text:
            text = text.split("```json", 1)[1].split("```", 1)[0]
        elif "```" in text:
            text = text.split("```", 1)[1].split("```", 1)[0]
        text = text.strip()
        if text.startswith("json"):
            text = text[4:]
        # Keep only the outermost JSON object.
        start = text.find("{")
        end = text.rfind("}") + 1
        if start != -1 and end > start:
            text = text[start:end]
        return text.strip()
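    # Replies this handles (illustrative):
    #   '```json\n{"a": 1}\n```'       -> '{"a": 1}'
    #   'Here is the JSON: {"a": 1}.'  -> '{"a": 1}'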
    @staticmethod
    def _lexical_evidence(product_text: str, label: str) -> float:
        """Fraction of the label's tokens that literally appear in the text."""
        pt = product_text.lower()
        tokens = [t for t in label.lower().replace("-", " ").split() if t]
        if not tokens:
            return 0.0
        hits = sum(1 for t in tokens if t in pt)
        return hits / len(tokens)
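    # Worked example (illustrative):
    #   _lexical_evidence("brushed stainless steel finish", "Stainless Steel")
    #   -> 2 of 2 tokens found -> 1.0
    #   _lexical_evidence("brushed stainless steel finish", "Stainless-Steel Black")
    #   -> 2 of 3 tokens found -> ~0.67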
    @staticmethod
    def _find_source(value: str, source_map: Dict[str, str]) -> str:
        value_lower = value.lower()
        for src_key, text in source_map.items():
            if value_lower in text.lower():
                return src_key
        return "not_found"
    @staticmethod
    def format_visual_attributes(visual_attributes: Dict) -> Dict:
        formatted = {}
        for key, value in visual_attributes.items():
            if isinstance(value, list):
                formatted[key] = [{"value": str(item), "source": "image"} for item in value]
            elif isinstance(value, dict):
                nested = {}
                for sub_key, sub_val in value.items():
                    if isinstance(sub_val, list):
                        nested[sub_key] = [{"value": str(v), "source": "image"} for v in sub_val]
                    else:
                        nested[sub_key] = [{"value": str(sub_val), "source": "image"}]
                formatted[key] = nested
            else:
                formatted[key] = [{"value": str(value), "source": "image"}]
        return formatted
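    # Example (illustrative): {"color": ["red", "blue"]} becomes
    #   {"color": [{"value": "red", "source": "image"},
    #              {"value": "blue", "source": "image"}]}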
    @staticmethod
    @retry(max_attempts=3, delay=1.0)
    def _call_llm(payload: dict) -> str:
        headers = {"Authorization": f"Bearer {settings.GROQ_API_KEY}", "Content-Type": "application/json"}
        resp = requests.post(settings.GROQ_API_URL, headers=headers, json=payload, timeout=30)
        resp.raise_for_status()
        return resp.json()["choices"][0]["message"]["content"]
    @staticmethod
    def extract_attributes(
        product_text: str,
        mandatory_attrs: Dict[str, List[str]],
        source_map: Optional[Dict[str, str]] = None,
        model: Optional[str] = None,
        extract_additional: bool = True,
        multiple: Optional[List[str]] = None,
        use_cache: Optional[bool] = None,
    ) -> dict:
        """Extract whitelisted attributes from product text via the LLM, then
        validate and recover values locally."""
        if model is None:
            model = settings.SUPPORTED_MODELS[0]
        if multiple is None:
            multiple = []
        if source_map is None:
            source_map = {}
        if use_cache is None:
            use_cache = ENABLE_ATTRIBUTE_EXTRACTION_CACHE
        if not is_caching_enabled():
            use_cache = False

        cache_key = None
        if use_cache:
            cache_key = ProductAttributeService._cache_key(product_text, mandatory_attrs, extract_additional, multiple)
            cached = SimpleCache.get(cache_key)
            if cached:
                logger.info(f"CACHE HIT {cache_key[:16]}...")
                return cached

        # --------------------------- PROMPT WITH FULL WHITELIST ---------------------------
        allowed_lines = [f"{attr}: {', '.join(vals)}" for attr, vals in mandatory_attrs.items()]
        allowed_text = "\n".join(allowed_lines)
        allowed_sources = list(source_map.keys()) + ["not_found"]
        source_hint = "|".join(allowed_sources)
        multiple_text = f"\nMULTIPLE ALLOWED FOR: {', '.join(multiple)}" if multiple else ""
        prompt = f"""
You are a product-attribute classifier.
Pick **exactly one** value from the list below for each attribute.
If nothing matches, return "Not Specified".

ALLOWED VALUES:
{allowed_text}
{multiple_text}

PRODUCT TEXT:
{product_text}

OUTPUT (strict JSON only):
{{
  "mandatory": {{
    "<attr>": [{{"value": "<chosen>", "source": "<{source_hint}>"}}]
  }},
  "additional": {{}}
}}

RULES:
- Pick from allowed values only
- If not found: "Not Specified" + "not_found"
- Source must be: {source_hint}
- Return ONLY JSON
"""
        payload = {
            "model": model,
            "messages": [
                {"role": "system", "content": "You are a JSON-only extractor."},
                {"role": "user", "content": prompt},
            ],
            "temperature": 0.0,
            "max_tokens": 1200,
        }

        try:
            raw = ProductAttributeService._call_llm(payload)
            cleaned = ProductAttributeService._clean_json(raw)
            parsed = json.loads(cleaned)
        except Exception as exc:
            logger.error(f"LLM failed: {exc}")
            return {
                "mandatory": {a: [{"value": "Not Specified", "source": "llm_error"}] for a in mandatory_attrs},
                "additional": {},
                "error": str(exc),
            }
        # --------------------------- VALIDATION + SMART RECOVERY ---------------------------
        pt_emb = EmbeddingCache.get_embedding(product_text, model_embedder)

        def _sanitize(section: dict, allowed: Dict):
            sanitized = {}
            for attr, items in section.items():
                if attr not in allowed:
                    continue
                chosen = []
                for it in (items if isinstance(items, list) else [items]):
                    if not isinstance(it, dict):
                        it = {"value": str(it), "source": "not_found"}
                    val = str(it.get("value", "")).strip()
                    src = str(it.get("source", "not_found")).lower()

                    # --- LLM SAYS "Not Specified" -> SMART RECOVERY ---
                    if val.lower() == "not specified":
                        # 1. Lexical recovery: accept an allowed value whose
                        #    tokens mostly appear verbatim in the product text.
                        for av in allowed[attr]:
                            if ProductAttributeService._lexical_evidence(product_text, av) > 0.6:
                                src = ProductAttributeService._find_source(av, source_map)
                                chosen.append({"value": av, "source": src})
                                break
                        else:
                            # 2. Semantic recovery: fall back to cosine similarity
                            #    between the product text and each allowed value.
                            best_val, best_score = max(
                                ((av, float(util.cos_sim(pt_emb, EmbeddingCache.get_embedding(av, model_embedder)).item()))
                                 for av in allowed[attr]),
                                key=lambda x: x[1],
                            ) if allowed[attr] else ("", 0.0)
                            if best_score > 0.75:
                                src = ProductAttributeService._find_source(best_val, source_map)
                                chosen.append({"value": best_val, "source": src})
                            else:
                                chosen.append({"value": "Not Specified", "source": "not_found"})
                        continue

                    # --- VALIDATE LLM CHOICE ---
                    if val not in allowed[attr]:
                        continue
                    if ProductAttributeService._lexical_evidence(product_text, val) < 0.2:
                        continue
                    if src not in source_map and src != "not_found":
                        src = "not_found"
                    chosen.append({"value": val, "source": src})
                sanitized[attr] = chosen or [{"value": "Not Specified", "source": "not_found"}]
            return sanitized
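        # Recovery example (illustrative): if the LLM returns "Not Specified"
        # for "Material" but the text contains "stainless steel", lexical
        # evidence (1.0 > 0.6) recovers "Stainless Steel" with the source
        # section the phrase appears in; otherwise the 0.75 cosine threshold
        # must be cleared before a semantic match is accepted.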
- parsed["mandatory"] = _sanitize(parsed.get("mandatory", {}), mandatory_attrs)
- # --- ADDITIONAL ATTRIBUTES ---
- if extract_additional and "additional" in parsed:
- additional = {}
- for attr, items in parsed["additional"].items():
- good = []
- for it in (items if isinstance(items, list) else [items]):
- if not isinstance(it, dict): it = {"value": str(it), "source": "not_found"}
- val = str(it.get("value", "")).strip()
- src = str(it.get("source", "not_found")).lower()
- if src not in source_map and src != "not_found": src = "not_found"
- if val: good.append({"value": val, "source": src})
- if good: additional[attr] = good
- parsed["additional"] = additional
- else:
- parsed.pop("additional", None)
- if use_cache and cache_key:
- SimpleCache.set(cache_key, parsed)
- logger.info(f"CACHE SET {cache_key[:16]}...")
- return parsed
    @staticmethod
    def get_cache_stats() -> Dict:
        return {
            "global_enabled": is_caching_enabled(),
            "result_cache": SimpleCache.get_stats(),
            "embedding_cache": EmbeddingCache.get_stats(),
        }

    @staticmethod
    def clear_all_caches():
        SimpleCache.clear()
        EmbeddingCache.clear()
        logger.info("All caches cleared")