@@ -1,645 +1,216 @@
-# ==================== services.py (WITH CACHE CONTROL) ====================
-import requests
+# ==================== services.py (FINAL PERFECT + FULL WHITELIST + SEMANTIC RECOVERY) ====================
 import json
-import re
 import hashlib
 import logging
+import warnings
+import time
+from functools import wraps
 from typing import Dict, List, Optional, Tuple
+import os
+import requests
 from django.conf import settings
-from concurrent.futures import ThreadPoolExecutor, as_completed
 from sentence_transformers import SentenceTransformer, util
-import numpy as np
 
-# ⚡ IMPORT CACHE CONFIGURATION
+# --------------------------------------------------------------------------- #
+# CACHE CONFIG
+# --------------------------------------------------------------------------- #
 from .cache_config import (
     is_caching_enabled,
     ENABLE_ATTRIBUTE_EXTRACTION_CACHE,
    ENABLE_EMBEDDING_CACHE,
    ATTRIBUTE_CACHE_MAX_SIZE,
-    EMBEDDING_CACHE_MAX_SIZE
+    EMBEDDING_CACHE_MAX_SIZE,
 )
 
 logger = logging.getLogger(__name__)
 
-# ⚡ CRITICAL FIX: Initialize embedding model ONCE at module level
-print("Loading sentence transformer model (one-time initialization)...")
+# --------------------------------------------------------------------------- #
+# ONE-TIME MODEL LOAD
+# --------------------------------------------------------------------------- #
+print("Loading sentence transformer model (semantic recovery)...")
 model_embedder = SentenceTransformer("all-MiniLM-L6-v2")
-# Disable progress bars to prevent "Batches: 100%" spam
-import os
-os.environ['TOKENIZERS_PARALLELISM'] = 'false'
-print("✓ Model loaded successfully")
-
-
-# ==================== CACHING CLASSES ====================
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+print("Model loaded")
 
 
+# --------------------------------------------------------------------------- #
+# CACHES
+# --------------------------------------------------------------------------- #
 class SimpleCache:
-    """In-memory cache for attribute extraction results."""
     _cache = {}
     _max_size = ATTRIBUTE_CACHE_MAX_SIZE
-
+
     @classmethod
     def get(cls, key: str) -> Optional[Dict]:
-        """Get value from cache. Returns None if caching is disabled."""
-        if not ENABLE_ATTRIBUTE_EXTRACTION_CACHE:
-            return None
+        if not ENABLE_ATTRIBUTE_EXTRACTION_CACHE: return None
         return cls._cache.get(key)
-
+
     @classmethod
     def set(cls, key: str, value: Dict):
-        """Set value in cache. Does nothing if caching is disabled."""
-        if not ENABLE_ATTRIBUTE_EXTRACTION_CACHE:
-            return
-
+        if not ENABLE_ATTRIBUTE_EXTRACTION_CACHE: return
         if len(cls._cache) >= cls._max_size:
             items = list(cls._cache.items())
             cls._cache = dict(items[int(cls._max_size * 0.2):])
         cls._cache[key] = value
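+        # Eviction note: dicts preserve insertion order (Python 3.7+), so the
+        # slice above drops the oldest ~20% of entries -- effectively FIFO.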
-
+
     @classmethod
-    def clear(cls):
-        """Clear the cache."""
-        cls._cache.clear()
-
+    def clear(cls): cls._cache.clear()
+
     @classmethod
     def get_stats(cls) -> Dict:
-        """Get cache statistics."""
         return {
             "enabled": ENABLE_ATTRIBUTE_EXTRACTION_CACHE,
             "size": len(cls._cache),
             "max_size": cls._max_size,
-            "usage_percent": round(len(cls._cache) / cls._max_size * 100, 2) if cls._max_size > 0 else 0
+            "usage_percent": round(len(cls._cache)/cls._max_size*100, 2) if cls._max_size else 0
         }
 
-
 class EmbeddingCache:
-    """Cache for sentence transformer embeddings."""
     _cache = {}
     _max_size = EMBEDDING_CACHE_MAX_SIZE
-    _hit_count = 0
-    _miss_count = 0
-
+    _hit = _miss = 0
+
     @classmethod
     def get_embedding(cls, text: str, model):
-        """Get or compute embedding with optional caching"""
-        # If caching is disabled, always compute fresh
         if not ENABLE_EMBEDDING_CACHE:
-            import warnings
             with warnings.catch_warnings():
                 warnings.simplefilter("ignore")
-                embedding = model.encode(text, convert_to_tensor=True, show_progress_bar=False)
-            return embedding
-
-        # Caching is enabled, check cache first
+                return model.encode(text, convert_to_tensor=True, show_progress_bar=False)
         if text in cls._cache:
-            cls._hit_count += 1
+            cls._hit += 1
             return cls._cache[text]
-
-        cls._miss_count += 1
-
+        cls._miss += 1
         if len(cls._cache) >= cls._max_size:
             items = list(cls._cache.items())
             cls._cache = dict(items[int(cls._max_size * 0.3):])
-
-        # Compute embedding
-        import warnings
         with warnings.catch_warnings():
             warnings.simplefilter("ignore")
-            embedding = model.encode(text, convert_to_tensor=True, show_progress_bar=False)
-
-        cls._cache[text] = embedding
-        return embedding
-
+            emb = model.encode(text, convert_to_tensor=True, show_progress_bar=False)
+        cls._cache[text] = emb
+        return emb
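+    # The cache stores the encoded tensors keyed by the exact input string, so
+    # re-scoring a previously seen label costs a single dict lookup.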
+
     @classmethod
     def clear(cls):
-        """Clear the cache and reset statistics."""
         cls._cache.clear()
-        cls._hit_count = 0
-        cls._miss_count = 0
-
+        cls._hit = cls._miss = 0
+
     @classmethod
     def get_stats(cls) -> Dict:
-        """Get cache statistics."""
-        total = cls._hit_count + cls._miss_count
-        hit_rate = (cls._hit_count / total * 100) if total > 0 else 0
+        total = cls._hit + cls._miss
+        rate = (cls._hit / total * 100) if total else 0
         return {
             "enabled": ENABLE_EMBEDDING_CACHE,
             "size": len(cls._cache),
-            "hits": cls._hit_count,
-            "misses": cls._miss_count,
-            "hit_rate_percent": round(hit_rate, 2)
+            "hits": cls._hit,
+            "misses": cls._miss,
+            "hit_rate_percent": round(rate, 2),
         }
 
-
-# ==================== MAIN SERVICE CLASS ====================
-
+# --------------------------------------------------------------------------- #
+# RETRY DECORATOR
+# --------------------------------------------------------------------------- #
+def retry(max_attempts=3, delay=1.0):
+    def decorator(f):
+        @wraps(f)
+        def wrapper(*args, **kwargs):
+            last_exc = None
+            for i in range(max_attempts):
+                try:
+                    return f(*args, **kwargs)
+                except Exception as e:
+                    last_exc = e
+                    if i < max_attempts - 1:
+                        wait = delay * (2 ** i)
+                        logger.warning(f"Retry {i+1}/{max_attempts} after {wait}s: {e}")
+                        time.sleep(wait)
+            raise last_exc or RuntimeError("Retry failed")
+        return wrapper
+    return decorator
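+# With the defaults, failed attempts sleep 1s then 2s (delay * 2**i) before
+# retrying; after the last attempt the final exception is re-raised.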
+
+# --------------------------------------------------------------------------- #
+# MAIN SERVICE
+# --------------------------------------------------------------------------- #
 class ProductAttributeService:
-    """Service class for extracting product attributes using Groq LLM."""
-
-    @staticmethod
-    def _generate_cache_key(product_text: str, mandatory_attrs: Dict) -> str:
-        """Generate cache key from product text and attributes."""
-        attrs_str = json.dumps(mandatory_attrs, sort_keys=True)
-        content = f"{product_text}:{attrs_str}"
-        return f"attr_{hashlib.md5(content.encode()).hexdigest()}"
-
-    @staticmethod
-    def normalize_dimension_text(text: str) -> str:
-        """Normalize dimension text to format like '16x20'."""
-        if not text:
-            return ""
-
-        text = text.lower()
-        text = re.sub(r'\s*(inches|inch|in|cm|centimeters|mm|millimeters)\s*', '', text, flags=re.IGNORECASE)
-
-        numbers = re.findall(r'\d+\.?\d*', text)
-        if not numbers:
-            return ""
-
-        float_numbers = []
-        for num in numbers:
-            try:
-                float_numbers.append(float(num))
-            except:
-                continue
-
-        if len(float_numbers) < 2:
-            return ""
-
-        if len(float_numbers) == 3:
-            float_numbers = [float_numbers[0], float_numbers[2]]
-        elif len(float_numbers) > 3:
-            float_numbers = sorted(float_numbers)[-2:]
-        else:
-            float_numbers = float_numbers[:2]
-
-        formatted_numbers = []
-        for num in float_numbers:
-            if num.is_integer():
-                formatted_numbers.append(str(int(num)))
-            else:
-                formatted_numbers.append(f"{num:.1f}")
-
-        formatted_numbers.sort(key=lambda x: float(x))
-        return f"{formatted_numbers[0]}x{formatted_numbers[1]}"
-
-    @staticmethod
-    def normalize_value_for_matching(value: str, attr_name: str = "") -> str:
-        """Normalize a value based on its attribute type."""
-        dimension_keywords = ['dimension', 'size', 'measurement']
-        if any(keyword in attr_name.lower() for keyword in dimension_keywords):
-            normalized = ProductAttributeService.normalize_dimension_text(value)
-            if normalized:
-                return normalized
-        return value.strip()
-
     @staticmethod
-    def combine_product_text(
-        title: Optional[str] = None,
-        short_desc: Optional[str] = None,
-        long_desc: Optional[str] = None,
-        ocr_text: Optional[str] = None
-    ) -> Tuple[str, Dict[str, str]]:
-        """Combine product metadata into a single text block."""
+    def combine_product_text(title=None, short_desc=None, long_desc=None, ocr_text=None) -> Tuple[str, Dict[str, str]]:
         parts = []
         source_map = {}
-
         if title:
-            title_str = str(title).strip()
-            parts.append(f"Title: {title_str}")
-            source_map['title'] = title_str
+            t = str(title).strip()
+            parts.append(f"Title: {t}")
+            source_map["title"] = t
         if short_desc:
-            short_str = str(short_desc).strip()
-            parts.append(f"Description: {short_str}")
-            source_map['short_desc'] = short_str
+            s = str(short_desc).strip()
+            parts.append(f"Description: {s}")
+            source_map["short_desc"] = s
         if long_desc:
-            long_str = str(long_desc).strip()
-            parts.append(f"Details: {long_str}")
-            source_map['long_desc'] = long_str
+            l = str(long_desc).strip()
+            parts.append(f"Details: {l}")
+            source_map["long_desc"] = l
         if ocr_text:
             parts.append(f"OCR Text: {ocr_text}")
-            source_map['ocr_text'] = ocr_text
-
+            source_map["ocr_text"] = ocr_text
         combined = "\n".join(parts).strip()
-
-        if not combined:
-            return "No product information available", {}
-
-        return combined, source_map
+        return (combined or "No product information", source_map)
+
+    @staticmethod
+    def _cache_key(product_text: str, mandatory_attrs: Dict, extract_additional: bool, multiple: List[str]) -> str:
+        payload = {"text": product_text, "attrs": mandatory_attrs, "extra": extract_additional, "multiple": sorted(multiple)}
+        return f"attr_{hashlib.md5(json.dumps(payload, sort_keys=True).encode()).hexdigest()}"
+
+    @staticmethod
+    def _clean_json(text: str) -> str:
+        start = text.find("{")
+        end = text.rfind("}") + 1
+        if start != -1 and end > start:
+            text = text[start:end]
+        if "```json" in text:
+            text = text.split("```json", 1)[1].split("```", 1)[0]
+        elif "```" in text:
+            text = text.split("```", 1)[1].split("```", 1)[0]
+        text = text.strip()
+        if text.startswith("json"): text = text[4:]
+        return text.strip()
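+    # Defensive parsing: take the outermost {...} first, then strip any
+    # surviving markdown fences or a leading "json" language tag.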
+
+    @staticmethod
+    def _lexical_evidence(product_text: str, label: str) -> float:
+        pt = product_text.lower()
+        tokens = [t for t in label.lower().replace("-", " ").split() if t]
+        if not tokens: return 0.0
+        hits = sum(1 for t in tokens if t in pt)
+        return hits / len(tokens)
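+    # Example: text "solid blue cotton shirt" vs label "Blue Cotton" ->
+    # tokens ["blue", "cotton"], both present, evidence = 2/2 = 1.0.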
 
     @staticmethod
-    def find_value_source(value: str, source_map: Dict[str, str], attr_name: str = "") -> str:
-        """Find which source(s) contain the given value."""
+    def _find_source(value: str, source_map: Dict[str, str]) -> str:
         value_lower = value.lower()
-        value_tokens = set(value_lower.replace("-", " ").replace("x", " ").split())
-
-        is_dimension_attr = any(keyword in attr_name.lower() for keyword in ['dimension', 'size', 'measurement'])
-
-        sources_found = []
-        source_scores = {}
-
-        for source_name, source_text in source_map.items():
-            source_lower = source_text.lower()
-
-            if value_lower in source_lower:
-                source_scores[source_name] = 1.0
-                continue
-
-            if is_dimension_attr:
-                normalized_value = ProductAttributeService.normalize_dimension_text(value)
-                if not normalized_value:
-                    normalized_value = value.replace("x", " ").strip()
-
-                normalized_source = ProductAttributeService.normalize_dimension_text(source_text)
-
-                if normalized_value == normalized_source:
-                    source_scores[source_name] = 0.95
-                    continue
-
-                dim_parts = normalized_value.split("x") if "x" in normalized_value else []
-                if len(dim_parts) == 2:
-                    if all(part in source_text for part in dim_parts):
-                        source_scores[source_name] = 0.85
-                        continue
-
-            token_matches = sum(1 for token in value_tokens if token and token in source_lower)
-            if token_matches > 0 and len(value_tokens) > 0:
-                source_scores[source_name] = token_matches / len(value_tokens)
-
-        if source_scores:
-            max_score = max(source_scores.values())
-            sources_found = [s for s, score in source_scores.items() if score == max_score]
-
-            priority = ['title', 'short_desc', 'long_desc', 'ocr_text']
-            for p in priority:
-                if p in sources_found:
-                    return p
-
-            return sources_found[0] if sources_found else "Not found"
-
-        return "Not found"
+        for src_key, text in source_map.items():
+            if value_lower in text.lower():
+                return src_key
+        return "not_found"
 
     @staticmethod
     def format_visual_attributes(visual_attributes: Dict) -> Dict:
-        """Convert visual attributes to array format with source tracking."""
         formatted = {}
-
         for key, value in visual_attributes.items():
             if isinstance(value, list):
                 formatted[key] = [{"value": str(item), "source": "image"} for item in value]
             elif isinstance(value, dict):
-                nested_formatted = {}
-                for nested_key, nested_value in value.items():
-                    if isinstance(nested_value, list):
-                        nested_formatted[nested_key] = [{"value": str(item), "source": "image"} for item in nested_value]
+                nested = {}
+                for sub_key, sub_val in value.items():
+                    if isinstance(sub_val, list):
+                        nested[sub_key] = [{"value": str(v), "source": "image"} for v in sub_val]
                     else:
-                        nested_formatted[nested_key] = [{"value": str(nested_value), "source": "image"}]
-                formatted[key] = nested_formatted
+                        nested[sub_key] = [{"value": str(sub_val), "source": "image"}]
+                formatted[key] = nested
             else:
                 formatted[key] = [{"value": str(value), "source": "image"}]
-
         return formatted
 
     @staticmethod
-    def extract_attributes_from_ocr(ocr_results: Dict, model: str = None) -> Dict:
-        """Extract structured attributes from OCR text using LLM."""
-        if model is None:
-            model = settings.SUPPORTED_MODELS[0]
-
-        detected_text = ocr_results.get('detected_text', [])
-        if not detected_text:
-            return {}
-
-        ocr_text = "\n".join([f"Text: {item['text']}, Confidence: {item['confidence']:.2f}"
-                              for item in detected_text])
-
-        prompt = f"""
-You are an AI model that extracts structured attributes from OCR text detected on product images.
-Given the OCR detections below, infer the possible product attributes and return them as a clean JSON object.
-
-OCR Text:
-{ocr_text}
-
-Extract relevant attributes like:
-- brand
-- model_number
-- size (waist_size, length, etc.)
-- collection
-- any other relevant product information
-
-Return a JSON object with only the attributes you can confidently identify.
-If an attribute is not present, do not include it in the response.
-"""
-
-        payload = {
-            "model": model,
-            "messages": [
-                {
-                    "role": "system",
-                    "content": "You are a helpful AI that extracts structured data from OCR output. Return only valid JSON."
-                },
-                {"role": "user", "content": prompt}
-            ],
-            "temperature": 0.2,
-            "max_tokens": 500
-        }
-
-        headers = {
-            "Authorization": f"Bearer {settings.GROQ_API_KEY}",
-            "Content-Type": "application/json",
-        }
-
-        try:
-            response = requests.post(
-                settings.GROQ_API_URL,
-                headers=headers,
-                json=payload,
-                timeout=30
-            )
-            response.raise_for_status()
-            result_text = response.json()["choices"][0]["message"]["content"].strip()
-
-            result_text = ProductAttributeService._clean_json_response(result_text)
-            parsed = json.loads(result_text)
-
-            formatted_attributes = {}
-            for key, value in parsed.items():
-                if key == "error":
-                    continue
-
-                if isinstance(value, dict):
-                    nested_formatted = {}
-                    for nested_key, nested_value in value.items():
-                        nested_formatted[nested_key] = [{"value": str(nested_value), "source": "image"}]
-                    formatted_attributes[key] = nested_formatted
-                elif isinstance(value, list):
-                    formatted_attributes[key] = [{"value": str(item), "source": "image"} for item in value]
-                else:
-                    formatted_attributes[key] = [{"value": str(value), "source": "image"}]
-
-            return formatted_attributes
-        except Exception as e:
-            logger.error(f"OCR attribute extraction failed: {str(e)}")
-            return {"error": f"Failed to extract attributes from OCR: {str(e)}"}
-
-    @staticmethod
-    def calculate_attribute_relationships(
-        mandatory_attrs: Dict[str, List[str]],
-        product_text: str
-    ) -> Dict[str, float]:
-        """Calculate semantic relationships between attribute values."""
-        pt_emb = EmbeddingCache.get_embedding(product_text, model_embedder)
-
-        attr_scores = {}
-        for attr, values in mandatory_attrs.items():
-            attr_scores[attr] = {}
-            for val in values:
-                contexts = [val, f"for {val}", f"use in {val}", f"suitable for {val}"]
-                ctx_embs = [EmbeddingCache.get_embedding(c, model_embedder) for c in contexts]
-                sem_sim = max(float(util.cos_sim(pt_emb, ce).item()) for ce in ctx_embs)
-                attr_scores[attr][val] = sem_sim
-
-        relationships = {}
-        attr_list = list(mandatory_attrs.keys())
-
-        for i, attr1 in enumerate(attr_list):
-            for attr2 in attr_list[i+1:]:
-                for val1 in mandatory_attrs[attr1]:
-                    for val2 in mandatory_attrs[attr2]:
-                        emb1 = EmbeddingCache.get_embedding(val1, model_embedder)
-                        emb2 = EmbeddingCache.get_embedding(val2, model_embedder)
-                        sim = float(util.cos_sim(emb1, emb2).item())
-
-                        key1 = f"{attr1}:{val1}->{attr2}:{val2}"
-                        key2 = f"{attr2}:{val2}->{attr1}:{val1}"
-                        relationships[key1] = sim
-                        relationships[key2] = sim
-
-        return relationships
-
-    @staticmethod
-    def calculate_value_clusters(
-        values: List[str],
-        scores: List[Tuple[str, float]],
-        cluster_threshold: float = 0.4
-    ) -> List[List[str]]:
-        """Group values into semantic clusters."""
-        if len(values) <= 1:
-            return [[val] for val, _ in scores]
-
-        embeddings = [EmbeddingCache.get_embedding(val, model_embedder) for val in values]
-
-        similarity_matrix = np.zeros((len(values), len(values)))
-        for i in range(len(values)):
-            for j in range(i+1, len(values)):
-                sim = float(util.cos_sim(embeddings[i], embeddings[j]).item())
-                similarity_matrix[i][j] = sim
-                similarity_matrix[j][i] = sim
-
-        clusters = []
-        visited = set()
-
-        for i, (val, score) in enumerate(scores):
-            if i in visited:
-                continue
-
-            cluster = [val]
-            visited.add(i)
-
-            for j in range(len(values)):
-                if j not in visited and similarity_matrix[i][j] >= cluster_threshold:
-                    cluster.append(values[j])
-                    visited.add(j)
-
-            clusters.append(cluster)
-
-        return clusters
-
-    @staticmethod
-    def get_dynamic_threshold(
-        attr: str,
-        val: str,
-        base_score: float,
-        extracted_attrs: Dict[str, List[Dict[str, str]]],
-        relationships: Dict[str, float],
-        mandatory_attrs: Dict[str, List[str]],
-        base_threshold: float = 0.65,
-        boost_factor: float = 0.15
-    ) -> float:
-        """Calculate dynamic threshold based on relationships."""
-        threshold = base_threshold
-
-        max_relationship = 0.0
-        for other_attr, other_values_list in extracted_attrs.items():
-            if other_attr == attr:
-                continue
-
-            for other_val_dict in other_values_list:
-                other_val = other_val_dict['value']
-                key = f"{attr}:{val}->{other_attr}:{other_val}"
-                if key in relationships:
-                    max_relationship = max(max_relationship, relationships[key])
-
-        if max_relationship > 0.6:
-            threshold = base_threshold - (boost_factor * max_relationship)
-
-        return max(0.3, threshold)
-
-    @staticmethod
-    def get_adaptive_margin(
-        scores: List[Tuple[str, float]],
-        base_margin: float = 0.15,
-        max_margin: float = 0.22
-    ) -> float:
-        """Calculate adaptive margin based on score distribution."""
-        if len(scores) < 2:
-            return base_margin
-
-        score_values = [s for _, s in scores]
-        best_score = score_values[0]
-
-        if best_score < 0.5:
-            top_scores = score_values[:min(4, len(score_values))]
-            score_range = max(top_scores) - min(top_scores)
-
-            if score_range < 0.30:
-                score_factor = (0.5 - best_score) * 0.35
-                adaptive = base_margin + score_factor + (0.30 - score_range) * 0.2
-                return min(adaptive, max_margin)
-
-        return base_margin
-
-    @staticmethod
-    def _lexical_evidence(product_text: str, label: str) -> float:
-        """Calculate lexical overlap between product text and label."""
-        pt = product_text.lower()
-        tokens = [t for t in label.lower().replace("-", " ").split() if t]
-        if not tokens:
-            return 0.0
-        hits = sum(1 for t in tokens if t in pt)
-        return hits / len(tokens)
-
-    @staticmethod
-    def normalize_against_product_text(
-        product_text: str,
-        mandatory_attrs: Dict[str, List[str]],
-        source_map: Dict[str, str],
-        threshold_abs: float = 0.65,
-        margin: float = 0.15,
-        allow_multiple: bool = False,
-        sem_weight: float = 0.8,
-        lex_weight: float = 0.2,
-        extracted_attrs: Optional[Dict[str, List[Dict[str, str]]]] = None,
-        relationships: Optional[Dict[str, float]] = None,
-        use_dynamic_thresholds: bool = True,
-        use_adaptive_margin: bool = True,
-        use_semantic_clustering: bool = True
-    ) -> dict:
-        """Score each allowed value against the product_text."""
-        if extracted_attrs is None:
-            extracted_attrs = {}
-        if relationships is None:
-            relationships = {}
-
-        pt_emb = EmbeddingCache.get_embedding(product_text, model_embedder)
-        extracted = {}
-
-        for attr, allowed_values in mandatory_attrs.items():
-            scores: List[Tuple[str, float]] = []
-
-            is_dimension_attr = any(keyword in attr.lower() for keyword in ['dimension', 'size', 'measurement'])
-            normalized_product_text = ProductAttributeService.normalize_dimension_text(product_text) if is_dimension_attr else ""
-
-            for val in allowed_values:
-                if is_dimension_attr:
-                    normalized_val = ProductAttributeService.normalize_dimension_text(val)
-
-                    if normalized_val and normalized_product_text and normalized_val == normalized_product_text:
-                        scores.append((val, 1.0))
-                        continue
-
-                    if normalized_val:
-                        val_numbers = normalized_val.split('x')
-                        text_lower = product_text.lower()
-                        if all(num in text_lower for num in val_numbers):
-                            idx1 = text_lower.find(val_numbers[0])
-                            idx2 = text_lower.find(val_numbers[1])
-                            if idx1 != -1 and idx2 != -1:
-                                distance = abs(idx2 - idx1)
-                                if distance < 20:
-                                    scores.append((val, 0.95))
-                                    continue
-
-                contexts = [val, f"for {val}", f"use in {val}", f"suitable for {val}", f"{val} room"]
-                ctx_embs = [EmbeddingCache.get_embedding(c, model_embedder) for c in contexts]
-                sem_sim = max(float(util.cos_sim(pt_emb, ce).item()) for ce in ctx_embs)
-
-                lex_score = ProductAttributeService._lexical_evidence(product_text, val)
-                final_score = sem_weight * sem_sim + lex_weight * lex_score
-                scores.append((val, final_score))
-
-            scores.sort(key=lambda x: x[1], reverse=True)
-            best_val, best_score = scores[0]
-
-            effective_margin = margin
-            if allow_multiple and use_adaptive_margin:
-                effective_margin = ProductAttributeService.get_adaptive_margin(scores, margin)
-
-            if is_dimension_attr and best_score >= 0.90:
-                source = ProductAttributeService.find_value_source(best_val, source_map, attr)
-                extracted[attr] = [{"value": best_val, "source": source}]
-                continue
-
-            if not allow_multiple:
-                source = ProductAttributeService.find_value_source(best_val, source_map, attr)
-                extracted[attr] = [{"value": best_val, "source": source}]
-            else:
-                candidates = [best_val]
-                use_base_threshold = best_score >= threshold_abs
-
-                clusters = []
-                if use_semantic_clustering:
-                    clusters = ProductAttributeService.calculate_value_clusters(
-                        allowed_values, scores, cluster_threshold=0.4
-                    )
-                    best_cluster = next((c for c in clusters if best_val in c), [best_val])
-
-                for val, sc in scores[1:]:
-                    min_score = 0.4 if is_dimension_attr else 0.3
-                    if sc < min_score:
-                        continue
-
-                    if use_dynamic_thresholds and extracted_attrs:
-                        dynamic_thresh = ProductAttributeService.get_dynamic_threshold(
-                            attr, val, sc, extracted_attrs, relationships,
-                            mandatory_attrs, threshold_abs
-                        )
-                    else:
-                        dynamic_thresh = threshold_abs
-
-                    within_margin = (best_score - sc) <= effective_margin
-                    above_threshold = sc >= dynamic_thresh
-
-                    in_cluster = False
-                    if use_semantic_clustering and clusters:
-                        in_cluster = any(best_val in c and val in c for c in clusters)
-
-                    if use_base_threshold:
-                        if above_threshold and within_margin:
-                            candidates.append(val)
-                        elif in_cluster and within_margin:
-                            candidates.append(val)
-                    else:
-                        if within_margin:
-                            candidates.append(val)
-                        elif in_cluster and (best_score - sc) <= effective_margin * 2.0:
-                            candidates.append(val)
-
-                extracted[attr] = []
-                for candidate in candidates:
-                    source = ProductAttributeService.find_value_source(candidate, source_map, attr)
-                    extracted[attr].append({"value": candidate, "source": source})
-
-        return extracted
+    @retry(max_attempts=3, delay=1.0)
+    def _call_llm(payload: dict) -> str:
+        headers = {"Authorization": f"Bearer {settings.GROQ_API_KEY}", "Content-Type": "application/json"}
+        resp = requests.post(settings.GROQ_API_URL, headers=headers, json=payload, timeout=30)
+        resp.raise_for_status()
+        return resp.json()["choices"][0]["message"]["content"]
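+    # NOTE: in the new file _call_llm sits under the @staticmethod context line
+    # retained above the removed block, so it is a static helper; @retry
+    # re-issues the POST (1s, then 2s backoff) on any request/HTTP error.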
 
     @staticmethod
     def extract_attributes(
@@ -649,315 +220,157 @@ If an attribute is not present, do not include it in the response.
         model: str = None,
         extract_additional: bool = True,
         multiple: Optional[List[str]] = None,
-        threshold_abs: float = 0.65,
-        margin: float = 0.15,
-        use_dynamic_thresholds: bool = True,
-        use_adaptive_margin: bool = True,
-        use_semantic_clustering: bool = True,
-        use_cache: bool = None  # ⚡ NEW: Can override global setting
+        use_cache: Optional[bool] = None,
     ) -> dict:
-        """Extract attributes from product text using Groq LLM."""
-
-        if model is None:
-            model = settings.SUPPORTED_MODELS[0]
-
-        if multiple is None:
-            multiple = []
-
-        if source_map is None:
-            source_map = {}
-
-        # ⚡ CACHE CONTROL: use parameter if provided, otherwise use global setting
-        if use_cache is None:
-            use_cache = ENABLE_ATTRIBUTE_EXTRACTION_CACHE
-
-        # If caching is globally disabled, force use_cache to False
-        if not is_caching_enabled():
-            use_cache = False
-
-        if not product_text or product_text == "No product information available":
-            return ProductAttributeService._create_error_response(
-                "No product information provided",
-                mandatory_attrs,
-                extract_additional
-            )
-
-        # ⚡ CHECK CACHE FIRST (only if enabled)
-        if use_cache:
-            cache_key = ProductAttributeService._generate_cache_key(product_text, mandatory_attrs)
-            cached_result = SimpleCache.get(cache_key)
-            if cached_result:
-                logger.info(f"✓ Cache hit (caching enabled)")
-                return cached_result
-        else:
-            logger.info(f"⚠ Cache disabled - processing fresh")
-
-        mandatory_attr_list = []
-        for attr_name, allowed_values in mandatory_attrs.items():
-            mandatory_attr_list.append(f"{attr_name}: {', '.join(allowed_values)}")
-        mandatory_attr_text = "\n".join(mandatory_attr_list)
-
-        additional_instruction = ""
-        if extract_additional:
-            additional_instruction = """
-2. Extract ADDITIONAL attributes: Identify any other relevant attributes from the product text
-   that are NOT in the mandatory list. Only include attributes where you can find actual values
-   in the product text. Do NOT include attributes with "Not Specified" or empty values.
-
-   Examples of attributes to look for (only if present): Brand, Material, Size, Color, Dimensions,
-   Weight, Features, Style, Theme, Pattern, Finish, Care Instructions, etc."""
-
-        output_format = {
-            "mandatory": {attr: "value or list of values" for attr in mandatory_attrs.keys()},
-        }
+        if model is None: model = settings.SUPPORTED_MODELS[0]
+        if multiple is None: multiple = []
+        if source_map is None: source_map = {}
 
-        if extract_additional:
-            output_format["additional"] = {
-                "example_attribute_1": "actual value found",
-                "example_attribute_2": "actual value found"
-            }
-            output_format["additional"]["_note"] = "Only include attributes with actual values found in text"
+        if use_cache is None: use_cache = ENABLE_ATTRIBUTE_EXTRACTION_CACHE
+        if not is_caching_enabled(): use_cache = False
+
+        cache_key = None
+        if use_cache:
+            cache_key = ProductAttributeService._cache_key(product_text, mandatory_attrs, extract_additional, multiple)
+            cached = SimpleCache.get(cache_key)
+            if cached:
+                logger.info(f"CACHE HIT {cache_key[:16]}...")
+                return cached
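+        # A hit returns the fully sanitised result of an earlier call, so the
+        # validation / recovery pass below is skipped entirely.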
+
+        # --------------------------- PROMPT WITH FULL WHITELIST ---------------------------
+        allowed_lines = [f"{attr}: {', '.join(vals)}" for attr, vals in mandatory_attrs.items()]
+        allowed_text = "\n".join(allowed_lines)
+        allowed_sources = list(source_map.keys()) + ["not_found"]
+        source_hint = "|".join(allowed_sources)
+        multiple_text = f"\nMULTIPLE ALLOWED FOR: {', '.join(multiple)}" if multiple else ""
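+        # e.g. mandatory_attrs={"color": ["Red", "Blue"]} (hypothetical) renders
+        # the whitelist line "color: Red, Blue" in the prompt below.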
 
         prompt = f"""
-You are an intelligent product attribute extractor that works with ANY product type.
+You are a product-attribute classifier.
+Pick **exactly one** value from the list below for each attribute.
+If nothing matches, return "Not Specified".
 
-TASK:
-1. Extract MANDATORY attributes: For each mandatory attribute, select the most appropriate value(s)
-   from the provided list. Choose the value(s) that best match the product description.
-{additional_instruction}
+ALLOWED VALUES:
+{allowed_text}
+{multiple_text}
 
-Product Text:
+PRODUCT TEXT:
 {product_text}
 
-Mandatory Attribute Lists (MUST select from these allowed values):
-{mandatory_attr_text}
-
-CRITICAL INSTRUCTIONS:
-- Return ONLY valid JSON, nothing else
-- No explanations, no markdown, no text before or after the JSON
-- For mandatory attributes, choose the value(s) from the provided list that best match
-- If a mandatory attribute cannot be determined from the product text, use "Not Specified"
-- Prefer exact matches from the allowed values list over generic synonyms
-- If multiple values are plausible, you MAY return more than one
-{f"- For additional attributes: ONLY include attributes where you found actual values in the product text. DO NOT include attributes with 'Not Specified', 'None', 'N/A', or empty values. If you cannot find a value for an attribute, simply don't include that attribute." if extract_additional else ""}
-- Be precise and only extract information that is explicitly stated or clearly implied
-
-Required Output Format:
-{json.dumps(output_format, indent=2)}
-        """
+OUTPUT (strict JSON only):
+{{
+  "mandatory": {{
+    "<attr>": [{{"value":"<chosen>", "source":"<{source_hint}>"}}]
+  }},
+  "additional": {{}}
+}}
+
+RULES:
+- Pick from allowed values only
+- If not found: "Not Specified" + "not_found"
+- Source must be: {source_hint}
+- Return ONLY JSON
+"""
 
         payload = {
             "model": model,
             "messages": [
-                {
-                    "role": "system",
-                    "content": f"You are a precise attribute extraction model. Return ONLY valid JSON with {'mandatory and additional' if extract_additional else 'mandatory'} sections. No explanations, no markdown, no other text."
-                },
-                {"role": "user", "content": prompt}
+                {"role": "system", "content": "You are a JSON-only extractor."},
+                {"role": "user", "content": prompt},
             ],
             "temperature": 0.0,
-            "max_tokens": 1500
-        }
-
-        headers = {
-            "Authorization": f"Bearer {settings.GROQ_API_KEY}",
-            "Content-Type": "application/json",
+            "max_tokens": 1200,
         }
 
         try:
-            response = requests.post(
-                settings.GROQ_API_URL,
-                headers=headers,
-                json=payload,
-                timeout=30
-            )
-            response.raise_for_status()
-            result_text = response.json()["choices"][0]["message"]["content"].strip()
-
-            result_text = ProductAttributeService._clean_json_response(result_text)
-
-            parsed = json.loads(result_text)
-
-            parsed = ProductAttributeService._validate_response_structure(
-                parsed, mandatory_attrs, extract_additional, source_map
-            )
-
-            if extract_additional and "additional" in parsed:
-                cleaned_additional = {}
-                for k, v in parsed["additional"].items():
-                    if v and v not in ["Not Specified", "None", "N/A", "", "not specified", "none", "n/a"]:
-                        if not (isinstance(v, str) and v.lower() in ["not specified", "none", "n/a", ""]):
-                            if isinstance(v, list):
-                                cleaned_additional[k] = []
-                                for item in v:
-                                    if isinstance(item, dict) and "value" in item:
-                                        if "source" not in item:
-                                            item["source"] = ProductAttributeService.find_value_source(
-                                                item["value"], source_map, k
-                                            )
-                                        cleaned_additional[k].append(item)
-                                    else:
-                                        source = ProductAttributeService.find_value_source(str(item), source_map, k)
-                                        cleaned_additional[k].append({"value": str(item), "source": source})
-                            else:
-                                source = ProductAttributeService.find_value_source(str(v), source_map, k)
-                                cleaned_additional[k] = [{"value": str(v), "source": source}]
-                parsed["additional"] = cleaned_additional
-
-            relationships = {}
-            if use_dynamic_thresholds:
-                relationships = ProductAttributeService.calculate_attribute_relationships(
-                    mandatory_attrs, product_text
-                )
-
-            extracted_so_far = {}
-            for attr in mandatory_attrs.keys():
-                allow_multiple = attr in multiple
-
-                result = ProductAttributeService.normalize_against_product_text(
-                    product_text=product_text,
-                    mandatory_attrs={attr: mandatory_attrs[attr]},
-                    source_map=source_map,
-                    threshold_abs=threshold_abs,
-                    margin=margin,
-                    allow_multiple=allow_multiple,
-                    extracted_attrs=extracted_so_far,
-                    relationships=relationships,
-                    use_dynamic_thresholds=use_dynamic_thresholds,
-                    use_adaptive_margin=use_adaptive_margin,
-                    use_semantic_clustering=use_semantic_clustering
-                )
-
-                parsed["mandatory"][attr] = result[attr]
-                extracted_so_far[attr] = result[attr]
-
-            # ⚡ CACHE THE RESULT (only if enabled)
-            if use_cache:
-                SimpleCache.set(cache_key, parsed)
-                logger.info(f"✓ Result cached")
-
-            return parsed
-
-        except requests.exceptions.RequestException as e:
-            logger.error(f"Request exception: {str(e)}")
-            return ProductAttributeService._create_error_response(
-                str(e), mandatory_attrs, extract_additional
-            )
-        except json.JSONDecodeError as e:
-            logger.error(f"JSON decode error: {str(e)}")
-            return ProductAttributeService._create_error_response(
-                f"Invalid JSON: {str(e)}", mandatory_attrs, extract_additional, result_text
-            )
-        except Exception as e:
-            logger.error(f"Unexpected error: {str(e)}")
-            return ProductAttributeService._create_error_response(
-                str(e), mandatory_attrs, extract_additional
-            )
-
-    @staticmethod
-    def _clean_json_response(text: str) -> str:
-        """Clean LLM response to extract valid JSON."""
-        start_idx = text.find('{')
-        end_idx = text.rfind('}')
+            raw = ProductAttributeService._call_llm(payload)
+            cleaned = ProductAttributeService._clean_json(raw)
+            parsed = json.loads(cleaned)
+        except Exception as exc:
+            logger.error(f"LLM failed: {exc}")
+            return {
+                "mandatory": {a: [{"value": "Not Specified", "source": "llm_error"}] for a in mandatory_attrs},
+                "additional": {},
+                "error": str(exc)
+            }
 
-        if start_idx != -1 and end_idx != -1:
-            text = text[start_idx:end_idx + 1]
+        # --------------------------- VALIDATION + SMART RECOVERY ---------------------------
+        pt_emb = EmbeddingCache.get_embedding(product_text, model_embedder)
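+        # The product-text embedding is computed once here and reused for every
+        # semantic-recovery comparison below.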
 
-        if "```json" in text:
-            text = text.split("```json")[1].split("```")[0].strip()
-        elif "```" in text:
-            text = text.split("```")[1].split("```")[0].strip()
-        if text.startswith("json"):
-            text = text[4:].strip()
+        def _sanitize(section: dict, allowed: Dict):
+            sanitized = {}
+            for attr, items in section.items():
+                if attr not in allowed: continue
+                chosen = []
+                for it in (items if isinstance(items, list) else [items]):
+                    if not isinstance(it, dict): it = {"value": str(it), "source": "not_found"}
+                    val = str(it.get("value", "")).strip()
+                    src = str(it.get("source", "not_found")).lower()
+
+                    # --- LLM SAYS "Not Specified" → SMART RECOVERY ---
+                    if val == "Not Specified":
+                        # 1. Lexical recovery
+                        for av in allowed[attr]:
+                            if ProductAttributeService._lexical_evidence(product_text, av) > 0.6:
+                                src = ProductAttributeService._find_source(av, source_map)
+                                chosen.append({"value": av, "source": src})
+                                break
+                        else:
+                            # 2. Semantic recovery
+                            best_val, best_score = max(
+                                ((av, float(util.cos_sim(pt_emb, EmbeddingCache.get_embedding(av, model_embedder)).item()))
+                                 for av in allowed[attr]),
+                                key=lambda x: x[1]
+                            )
+                            if best_score > 0.75:
+                                src = ProductAttributeService._find_source(best_val, source_map)
+                                chosen.append({"value": best_val, "source": src})
+                            else:
+                                chosen.append({"value": "Not Specified", "source": "not_found"})
+                        continue
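+                        # (for/else: the semantic branch runs only when the
+                        # lexical scan above completed without a break.)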
 
-        return text
+                    # --- VALIDATE LLM CHOICE ---
+                    if val not in allowed[attr]: continue
+                    if ProductAttributeService._lexical_evidence(product_text, val) < 0.2: continue
+                    if src not in source_map and src != "not_found": src = "not_found"
+                    chosen.append({"value": val, "source": src})
+
+                sanitized[attr] = chosen or [{"value": "Not Specified", "source": "not_found"}]
+            return sanitized
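+        # A 0.2 lexical-evidence floor rejects whitelist picks with no textual
+        # support; if every pick is rejected, the attribute falls back to
+        # "Not Specified" / "not_found".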
+
+        parsed["mandatory"] = _sanitize(parsed.get("mandatory", {}), mandatory_attrs)
+
+        # --- ADDITIONAL ATTRIBUTES ---
+        if extract_additional and "additional" in parsed:
+            additional = {}
+            for attr, items in parsed["additional"].items():
+                good = []
+                for it in (items if isinstance(items, list) else [items]):
+                    if not isinstance(it, dict): it = {"value": str(it), "source": "not_found"}
+                    val = str(it.get("value", "")).strip()
+                    src = str(it.get("source", "not_found")).lower()
+                    if src not in source_map and src != "not_found": src = "not_found"
+                    if val: good.append({"value": val, "source": src})
+                if good: additional[attr] = good
+            parsed["additional"] = additional
+        else:
+            parsed.pop("additional", None)
 
-    @staticmethod
-    def _validate_response_structure(
-        parsed: dict,
-        mandatory_attrs: Dict[str, List[str]],
-        extract_additional: bool,
-        source_map: Dict[str, str] = None
-    ) -> dict:
-        """Validate and fix the response structure."""
-        if source_map is None:
-            source_map = {}
-
-        expected_sections = ["mandatory"]
-        if extract_additional:
-            expected_sections.append("additional")
-
-        if not all(section in parsed for section in expected_sections):
-            if isinstance(parsed, dict):
-                mandatory_keys = set(mandatory_attrs.keys())
-                mandatory = {k: v for k, v in parsed.items() if k in mandatory_keys}
-                additional = {k: v for k, v in parsed.items() if k not in mandatory_keys}
-
-                result = {"mandatory": mandatory}
-                if extract_additional:
-                    result["additional"] = additional
-                parsed = result
-            else:
-                return ProductAttributeService._create_error_response(
-                    "Invalid response structure",
-                    mandatory_attrs,
-                    extract_additional,
-                    str(parsed)
-                )
-
-        if "mandatory" in parsed:
-            converted_mandatory = {}
-            for attr, value in parsed["mandatory"].items():
-                if isinstance(value, list):
-                    converted_mandatory[attr] = []
-                    for item in value:
-                        if isinstance(item, dict) and "value" in item:
-                            if "source" not in item:
-                                item["source"] = ProductAttributeService.find_value_source(
-                                    item["value"], source_map, attr
-                                )
-                            converted_mandatory[attr].append(item)
-                        else:
-                            source = ProductAttributeService.find_value_source(str(item), source_map, attr)
-                            converted_mandatory[attr].append({"value": str(item), "source": source})
-                else:
-                    source = ProductAttributeService.find_value_source(str(value), source_map, attr)
-                    converted_mandatory[attr] = [{"value": str(value), "source": source}]
-
-            parsed["mandatory"] = converted_mandatory
+        if use_cache and cache_key:
+            SimpleCache.set(cache_key, parsed)
+            logger.info(f"CACHE SET {cache_key[:16]}...")
 
         return parsed
 
-    @staticmethod
-    def _create_error_response(
-        error: str,
-        mandatory_attrs: Dict[str, List[str]],
-        extract_additional: bool,
-        raw_output: Optional[str] = None
-    ) -> dict:
-        """Create a standardized error response."""
-        response = {
-            "mandatory": {attr: [{"value": "Not Specified", "source": "error"}] for attr in mandatory_attrs.keys()},
-            "error": error
-        }
-        if extract_additional:
-            response["additional"] = {}
-        if raw_output:
-            response["raw_output"] = raw_output
-        return response
-
     @staticmethod
     def get_cache_stats() -> Dict:
-        """Get statistics for all caches including global status."""
         return {
-            "global_caching_enabled": is_caching_enabled(),
-            "simple_cache": SimpleCache.get_stats(),
-            "embedding_cache": EmbeddingCache.get_stats()
+            "global_enabled": is_caching_enabled(),
+            "result_cache": SimpleCache.get_stats(),
+            "embedding_cache": EmbeddingCache.get_stats(),
         }
 
     @staticmethod
     def clear_all_caches():
-        """Clear all caches."""
         SimpleCache.clear()
         EmbeddingCache.clear()
         logger.info("All caches cleared")
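+
+# Illustrative usage (hypothetical whitelist; parameter names taken from the
+# function body, and Django settings must provide GROQ_API_URL / GROQ_API_KEY /
+# SUPPORTED_MODELS):
+#     text, sources = ProductAttributeService.combine_product_text(
+#         title="Classic Blue Cotton Shirt", short_desc="Slim fit, size M")
+#     result = ProductAttributeService.extract_attributes(
+#         product_text=text,
+#         mandatory_attrs={"color": ["Blue", "Red"], "material": ["Cotton", "Linen"]},
+#         source_map=sources)
+#     # e.g. result["mandatory"]["color"] -> [{"value": "Blue", "source": "title"}]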