| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963 |
# ==================== services.py (WITH CACHE CONTROL) ====================
import hashlib
import json
import logging
import os
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Optional, Tuple

import numpy as np
import requests
from django.conf import settings
from sentence_transformers import SentenceTransformer, util

# ⚡ IMPORT CACHE CONFIGURATION
from .cache_config import (
    is_caching_enabled,
    ENABLE_ATTRIBUTE_EXTRACTION_CACHE,
    ENABLE_EMBEDDING_CACHE,
    ATTRIBUTE_CACHE_MAX_SIZE,
    EMBEDDING_CACHE_MAX_SIZE,
)

logger = logging.getLogger(__name__)

# Disable tokenizer parallelism to prevent fork warnings and "Batches: 100%"
# progress spam. This env var must be set BEFORE the model is instantiated for
# the tokenizers library to honor it (the original set it after loading,
# which defeated the purpose).
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# ⚡ CRITICAL FIX: Initialize embedding model ONCE at module level so every
# request reuses the same weights instead of reloading them.
print("Loading sentence transformer model (one-time initialization)...")
model_embedder = SentenceTransformer("all-MiniLM-L6-v2")
print("✓ Model loaded successfully")
# ==================== CACHING CLASSES ====================
class SimpleCache:
    """In-memory cache for attribute extraction results."""

    # Class-level storage shared by all callers; eviction is FIFO because
    # Python dicts preserve insertion order.
    _cache = {}
    _max_size = ATTRIBUTE_CACHE_MAX_SIZE

    @classmethod
    def get(cls, key: str) -> Optional[Dict]:
        """Get value from cache. Returns None if caching is disabled."""
        if not ENABLE_ATTRIBUTE_EXTRACTION_CACHE:
            return None
        return cls._cache.get(key)

    @classmethod
    def set(cls, key: str, value: Dict):
        """Set value in cache. Does nothing if caching is disabled."""
        if not ENABLE_ATTRIBUTE_EXTRACTION_CACHE:
            return
        if len(cls._cache) >= cls._max_size:
            # Full: drop the oldest ~20% of entries before inserting.
            entries = list(cls._cache.items())
            cls._cache = dict(entries[int(cls._max_size * 0.2):])
        cls._cache[key] = value

    @classmethod
    def clear(cls):
        """Clear the cache."""
        cls._cache.clear()

    @classmethod
    def get_stats(cls) -> Dict:
        """Get cache statistics."""
        if cls._max_size > 0:
            usage = round(len(cls._cache) / cls._max_size * 100, 2)
        else:
            usage = 0
        return {
            "enabled": ENABLE_ATTRIBUTE_EXTRACTION_CACHE,
            "size": len(cls._cache),
            "max_size": cls._max_size,
            "usage_percent": usage,
        }
class EmbeddingCache:
    """Cache for sentence transformer embeddings."""

    _cache = {}
    _max_size = EMBEDDING_CACHE_MAX_SIZE
    _hit_count = 0
    _miss_count = 0

    @classmethod
    def get_embedding(cls, text: str, model):
        """Get or compute embedding with optional caching"""
        import warnings

        if not ENABLE_EMBEDDING_CACHE:
            # Caching disabled: always compute a fresh embedding.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                return model.encode(text, convert_to_tensor=True, show_progress_bar=False)

        # Caching enabled: serve from cache when possible.
        if text in cls._cache:
            cls._hit_count += 1
            return cls._cache[text]
        cls._miss_count += 1

        if len(cls._cache) >= cls._max_size:
            # Full: drop the oldest ~30% of entries (dicts keep insertion order).
            entries = list(cls._cache.items())
            cls._cache = dict(entries[int(cls._max_size * 0.3):])

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            embedding = model.encode(text, convert_to_tensor=True, show_progress_bar=False)
        cls._cache[text] = embedding
        return embedding

    @classmethod
    def clear(cls):
        """Clear the cache and reset statistics."""
        cls._cache.clear()
        cls._hit_count = 0
        cls._miss_count = 0

    @classmethod
    def get_stats(cls) -> Dict:
        """Get cache statistics."""
        lookups = cls._hit_count + cls._miss_count
        rate = (cls._hit_count / lookups * 100) if lookups > 0 else 0
        return {
            "enabled": ENABLE_EMBEDDING_CACHE,
            "size": len(cls._cache),
            "hits": cls._hit_count,
            "misses": cls._miss_count,
            "hit_rate_percent": round(rate, 2),
        }
# ==================== MAIN SERVICE CLASS ====================
class ProductAttributeService:
    """Service class for extracting product attributes using Groq LLM."""

    @staticmethod
    def _generate_cache_key(product_text: str, mandatory_attrs: Dict) -> str:
        """Generate cache key from product text and attributes."""
        # sort_keys keeps the key stable regardless of dict insertion order.
        attrs_blob = json.dumps(mandatory_attrs, sort_keys=True)
        digest = hashlib.md5(f"{product_text}:{attrs_blob}".encode()).hexdigest()
        return f"attr_{digest}"
@staticmethod
def normalize_dimension_text(text: str) -> str:
    """Normalize dimension text to a canonical format like '16x20'.

    Strips unit words, extracts the numeric parts, reduces them to two
    dimensions (for WxDxH the middle value is dropped; for more than three
    the two largest are kept) and returns them sorted ascending as "AxB".
    Returns "" when fewer than two numbers can be found.
    """
    if not text:
        return ""

    text = text.lower()
    # \b guards keep short unit tokens ("in", "mm") from eating letters
    # inside ordinary words such as "living" or "hammer" (the original
    # pattern had no boundaries; IGNORECASE was redundant after lower()).
    text = re.sub(r'\s*\b(inches|inch|in|cm|centimeters|mm|millimeters)\b\s*', '', text)

    numbers = re.findall(r'\d+\.?\d*', text)
    if not numbers:
        return ""

    float_numbers = []
    for num in numbers:
        try:
            float_numbers.append(float(num))
        except ValueError:  # was a bare except:; only float() failures are expected
            continue

    if len(float_numbers) < 2:
        return ""

    if len(float_numbers) == 3:
        # WxDxH: keep the outer two values.
        float_numbers = [float_numbers[0], float_numbers[2]]
    elif len(float_numbers) > 3:
        float_numbers = sorted(float_numbers)[-2:]
    else:
        float_numbers = float_numbers[:2]

    # Integers render without a decimal point; others with one decimal place.
    formatted = [str(int(n)) if n.is_integer() else f"{n:.1f}" for n in float_numbers]
    formatted.sort(key=float)
    return f"{formatted[0]}x{formatted[1]}"
-
@staticmethod
def normalize_value_for_matching(value: str, attr_name: str = "") -> str:
    """Normalize a value based on its attribute type.

    Dimension-like attributes get the canonical 'AxB' form; anything else
    (or an un-normalizable dimension) is just whitespace-stripped.
    """
    looks_dimensional = any(
        kw in attr_name.lower() for kw in ('dimension', 'size', 'measurement')
    )
    if looks_dimensional:
        canonical = ProductAttributeService.normalize_dimension_text(value)
        if canonical:
            return canonical
    return value.strip()
@staticmethod
def combine_product_text(
    title: Optional[str] = None,
    short_desc: Optional[str] = None,
    long_desc: Optional[str] = None,
    ocr_text: Optional[str] = None
) -> Tuple[str, Dict[str, str]]:
    """Combine product metadata into a single text block.

    Returns (combined_text, source_map) where source_map records each
    cleaned field for later source attribution. Falls back to a sentinel
    string and an empty map when no field is provided.
    """
    parts = []
    source_map = {}

    if title:
        title_str = str(title).strip()
        parts.append(f"Title: {title_str}")
        source_map['title'] = title_str
    if short_desc:
        short_str = str(short_desc).strip()
        parts.append(f"Description: {short_str}")
        source_map['short_desc'] = short_str
    if long_desc:
        long_str = str(long_desc).strip()
        parts.append(f"Details: {long_str}")
        source_map['long_desc'] = long_str
    if ocr_text:
        # Consistency fix: coerce and strip like every other field (the
        # original appended the OCR text raw, leaving stray whitespace in
        # both the combined text and the source map).
        ocr_str = str(ocr_text).strip()
        parts.append(f"OCR Text: {ocr_str}")
        source_map['ocr_text'] = ocr_str

    combined = "\n".join(parts).strip()
    if not combined:
        return "No product information available", {}
    return combined, source_map
@staticmethod
def find_value_source(value: str, source_map: Dict[str, str], attr_name: str = "") -> str:
    """Find which source(s) contain the given value.

    Scores every source field on how well it contains `value` (exact
    substring > normalized-dimension match > both dimension components >
    shared-token ratio), then breaks score ties by field priority.
    Returns "Not found" when nothing matches.
    """
    value_lower = value.lower()
    value_tokens = set(value_lower.replace("-", " ").replace("x", " ").split())
    is_dimension_attr = any(
        kw in attr_name.lower() for kw in ['dimension', 'size', 'measurement']
    )

    source_scores = {}
    for source_name, source_text in source_map.items():
        source_lower = source_text.lower()

        # Exact substring: strongest possible evidence.
        if value_lower in source_lower:
            source_scores[source_name] = 1.0
            continue

        if is_dimension_attr:
            normalized_value = ProductAttributeService.normalize_dimension_text(value)
            if not normalized_value:
                normalized_value = value.replace("x", " ").strip()
            normalized_source = ProductAttributeService.normalize_dimension_text(source_text)

            if normalized_value == normalized_source:
                source_scores[source_name] = 0.95
                continue

            # Both dimension components present somewhere in the raw text.
            dim_parts = normalized_value.split("x") if "x" in normalized_value else []
            if len(dim_parts) == 2 and all(part in source_text for part in dim_parts):
                source_scores[source_name] = 0.85
                continue

        # Fallback: fraction of the value's tokens found in the source.
        token_matches = sum(1 for token in value_tokens if token and token in source_lower)
        if token_matches > 0 and len(value_tokens) > 0:
            source_scores[source_name] = token_matches / len(value_tokens)

    if source_scores:
        max_score = max(source_scores.values())
        tied = [name for name, score in source_scores.items() if score == max_score]
        # Prefer the more authoritative field when scores tie.
        for preferred in ('title', 'short_desc', 'long_desc', 'ocr_text'):
            if preferred in tied:
                return preferred
        return tied[0] if tied else "Not found"

    return "Not found"
@staticmethod
def format_visual_attributes(visual_attributes: Dict) -> Dict:
    """Convert visual attributes to array format with source tracking."""

    def tag(item):
        # Every visual value is attributed to the "image" source.
        return {"value": str(item), "source": "image"}

    formatted = {}
    for key, value in visual_attributes.items():
        if isinstance(value, list):
            formatted[key] = [tag(item) for item in value]
        elif isinstance(value, dict):
            nested = {}
            for sub_key, sub_value in value.items():
                if isinstance(sub_value, list):
                    nested[sub_key] = [tag(item) for item in sub_value]
                else:
                    nested[sub_key] = [tag(sub_value)]
            formatted[key] = nested
        else:
            formatted[key] = [tag(value)]
    return formatted
@staticmethod
def extract_attributes_from_ocr(ocr_results: Dict, model: str = None) -> Dict:
    """Extract structured attributes from OCR text using LLM.

    Sends the OCR detections to the Groq chat API and converts the parsed
    JSON reply to the {"value", "source": "image"} list format. Returns {}
    when no text was detected and an {"error": ...} dict on failure.
    """
    if model is None:
        model = settings.SUPPORTED_MODELS[0]

    detected_text = ocr_results.get('detected_text', [])
    if not detected_text:
        return {}

    ocr_text = "\n".join(
        f"Text: {item['text']}, Confidence: {item['confidence']:.2f}"
        for item in detected_text
    )

    prompt = f"""
You are an AI model that extracts structured attributes from OCR text detected on product images.
Given the OCR detections below, infer the possible product attributes and return them as a clean JSON object.
OCR Text:
{ocr_text}
Extract relevant attributes like:
- brand
- model_number
- size (waist_size, length, etc.)
- collection
- any other relevant product information
Return a JSON object with only the attributes you can confidently identify.
If an attribute is not present, do not include it in the response.
"""

    payload = {
        "model": model,
        "messages": [
            {
                "role": "system",
                "content": "You are a helpful AI that extracts structured data from OCR output. Return only valid JSON."
            },
            {"role": "user", "content": prompt}
        ],
        "temperature": 0.2,
        "max_tokens": 500
    }
    headers = {
        "Authorization": f"Bearer {settings.GROQ_API_KEY}",
        "Content-Type": "application/json",
    }

    try:
        response = requests.post(
            settings.GROQ_API_URL,
            headers=headers,
            json=payload,
            timeout=30
        )
        response.raise_for_status()
        raw = response.json()["choices"][0]["message"]["content"].strip()
        parsed = json.loads(ProductAttributeService._clean_json_response(raw))

        formatted_attributes = {}
        for key, value in parsed.items():
            if key == "error":
                continue
            if isinstance(value, dict):
                formatted_attributes[key] = {
                    sub_key: [{"value": str(sub_value), "source": "image"}]
                    for sub_key, sub_value in value.items()
                }
            elif isinstance(value, list):
                formatted_attributes[key] = [
                    {"value": str(item), "source": "image"} for item in value
                ]
            else:
                formatted_attributes[key] = [{"value": str(value), "source": "image"}]
        return formatted_attributes
    except Exception as e:
        # Best-effort extraction: log and surface the failure as data.
        logger.error(f"OCR attribute extraction failed: {str(e)}")
        return {"error": f"Failed to extract attributes from OCR: {str(e)}"}
@staticmethod
def calculate_attribute_relationships(
    mandatory_attrs: Dict[str, List[str]],
    product_text: str
) -> Dict[str, float]:
    """Calculate semantic relationships between attribute values.

    Returns a symmetric map "attrA:valA->attrB:valB" -> cosine similarity
    for every cross-attribute value pair.
    """
    text_emb = EmbeddingCache.get_embedding(product_text, model_embedder)

    # Per-value semantic affinity to the product text (kept for parity with
    # the extraction scoring; only the pairwise map below is returned).
    attr_scores = {}
    for attr, values in mandatory_attrs.items():
        attr_scores[attr] = {}
        for val in values:
            contexts = [val, f"for {val}", f"use in {val}", f"suitable for {val}"]
            ctx_embs = [EmbeddingCache.get_embedding(c, model_embedder) for c in contexts]
            attr_scores[attr][val] = max(
                float(util.cos_sim(text_emb, ce).item()) for ce in ctx_embs
            )

    relationships = {}
    attr_names = list(mandatory_attrs.keys())
    for i, attr1 in enumerate(attr_names):
        for attr2 in attr_names[i + 1:]:
            for val1 in mandatory_attrs[attr1]:
                for val2 in mandatory_attrs[attr2]:
                    emb1 = EmbeddingCache.get_embedding(val1, model_embedder)
                    emb2 = EmbeddingCache.get_embedding(val2, model_embedder)
                    sim = float(util.cos_sim(emb1, emb2).item())
                    # Store both directions so lookup order never matters.
                    relationships[f"{attr1}:{val1}->{attr2}:{val2}"] = sim
                    relationships[f"{attr2}:{val2}->{attr1}:{val1}"] = sim
    return relationships
@staticmethod
def calculate_value_clusters(
    values: List[str],
    scores: List[Tuple[str, float]],
    cluster_threshold: float = 0.4
) -> List[List[str]]:
    """Group values into semantic clusters.

    Greedy single-link clustering: each unvisited value (walked in `scores`
    order) seeds a cluster and absorbs every remaining value whose embedding
    similarity to the seed meets `cluster_threshold`.
    """
    if len(values) <= 1:
        return [[val] for val, _ in scores]

    embeddings = [EmbeddingCache.get_embedding(val, model_embedder) for val in values]
    n = len(values)
    similarity_matrix = np.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            sim = float(util.cos_sim(embeddings[i], embeddings[j]).item())
            similarity_matrix[i][j] = sim
            similarity_matrix[j][i] = sim

    clusters = []
    visited = set()
    for i, (val, _score) in enumerate(scores):
        if i in visited:
            continue
        cluster = [val]
        visited.add(i)
        for j in range(n):
            if j not in visited and similarity_matrix[i][j] >= cluster_threshold:
                cluster.append(values[j])
                visited.add(j)
        clusters.append(cluster)
    return clusters
@staticmethod
def get_dynamic_threshold(
    attr: str,
    val: str,
    base_score: float,
    extracted_attrs: Dict[str, List[Dict[str, str]]],
    relationships: Dict[str, float],
    mandatory_attrs: Dict[str, List[str]],
    base_threshold: float = 0.65,
    boost_factor: float = 0.15
) -> float:
    """Calculate dynamic threshold based on relationships.

    The acceptance threshold is lowered (never below 0.3) when `val` is
    strongly related (> 0.6) to a value already extracted for another
    attribute. `base_score` and `mandatory_attrs` are part of the public
    signature but unused by the current heuristic.
    """
    strongest = 0.0
    for other_attr, other_values in extracted_attrs.items():
        if other_attr == attr:
            continue
        for entry in other_values:
            key = f"{attr}:{val}->{other_attr}:{entry['value']}"
            if key in relationships:
                strongest = max(strongest, relationships[key])

    threshold = base_threshold
    if strongest > 0.6:
        threshold = base_threshold - boost_factor * strongest
    return max(0.3, threshold)
@staticmethod
def get_adaptive_margin(
    scores: List[Tuple[str, float]],
    base_margin: float = 0.15,
    max_margin: float = 0.22
) -> float:
    """Calculate adaptive margin based on score distribution.

    When the best score is weak (< 0.5) and the top scores are tightly
    bunched (spread < 0.30), the margin is widened — capped at
    `max_margin` — so near-ties are not discarded. Expects `scores`
    sorted descending by score.
    """
    if len(scores) < 2:
        return base_margin

    ranked = [score for _, score in scores]
    best = ranked[0]
    if best < 0.5:
        top = ranked[:min(4, len(ranked))]
        spread = max(top) - min(top)
        if spread < 0.30:
            widened = base_margin + (0.5 - best) * 0.35 + (0.30 - spread) * 0.2
            return min(widened, max_margin)
    return base_margin
- @staticmethod
- def _lexical_evidence(product_text: str, label: str) -> float:
- """Calculate lexical overlap between product text and label."""
- pt = product_text.lower()
- tokens = [t for t in label.lower().replace("-", " ").split() if t]
- if not tokens:
- return 0.0
- hits = sum(1 for t in tokens if t in pt)
- return hits / len(tokens)
@staticmethod
def normalize_against_product_text(
    product_text: str,
    mandatory_attrs: Dict[str, List[str]],
    source_map: Dict[str, str],
    threshold_abs: float = 0.65,
    margin: float = 0.15,
    allow_multiple: bool = False,
    sem_weight: float = 0.8,
    lex_weight: float = 0.2,
    extracted_attrs: Optional[Dict[str, List[Dict[str, str]]]] = None,
    relationships: Optional[Dict[str, float]] = None,
    use_dynamic_thresholds: bool = True,
    use_adaptive_margin: bool = True,
    use_semantic_clustering: bool = True
) -> dict:
    """Score each allowed value against the product_text.

    Dimension attributes get exact/proximity shortcuts; everything else is
    scored as a weighted blend of semantic similarity (best over several
    context phrasings) and lexical overlap. Returns a map of attribute ->
    list of {"value", "source"} picks (one pick unless `allow_multiple`).
    """
    if extracted_attrs is None:
        extracted_attrs = {}
    if relationships is None:
        relationships = {}

    pt_emb = EmbeddingCache.get_embedding(product_text, model_embedder)
    extracted = {}
    for attr, allowed_values in mandatory_attrs.items():
        # Robustness fix: an attribute with an empty allowed-value list
        # previously crashed on scores[0]; return an empty selection instead.
        if not allowed_values:
            extracted[attr] = []
            continue

        scores: List[Tuple[str, float]] = []
        is_dimension_attr = any(
            kw in attr.lower() for kw in ['dimension', 'size', 'measurement']
        )
        normalized_product_text = (
            ProductAttributeService.normalize_dimension_text(product_text)
            if is_dimension_attr else ""
        )

        for val in allowed_values:
            if is_dimension_attr:
                normalized_val = ProductAttributeService.normalize_dimension_text(val)

                # Exact normalized dimension match wins outright.
                if normalized_val and normalized_product_text and normalized_val == normalized_product_text:
                    scores.append((val, 1.0))
                    continue

                # Both components appearing close together (< 20 chars apart)
                # in the raw text is almost as strong.
                if normalized_val:
                    val_numbers = normalized_val.split('x')
                    text_lower = product_text.lower()
                    if all(num in text_lower for num in val_numbers):
                        idx1 = text_lower.find(val_numbers[0])
                        idx2 = text_lower.find(val_numbers[1])
                        if idx1 != -1 and idx2 != -1 and abs(idx2 - idx1) < 20:
                            scores.append((val, 0.95))
                            continue

            # Blend semantic similarity with lexical overlap.
            contexts = [val, f"for {val}", f"use in {val}", f"suitable for {val}", f"{val} room"]
            ctx_embs = [EmbeddingCache.get_embedding(c, model_embedder) for c in contexts]
            sem_sim = max(float(util.cos_sim(pt_emb, ce).item()) for ce in ctx_embs)
            lex_score = ProductAttributeService._lexical_evidence(product_text, val)
            scores.append((val, sem_weight * sem_sim + lex_weight * lex_score))

        scores.sort(key=lambda x: x[1], reverse=True)
        best_val, best_score = scores[0]

        effective_margin = margin
        if allow_multiple and use_adaptive_margin:
            effective_margin = ProductAttributeService.get_adaptive_margin(scores, margin)

        # A confident dimension match is always returned alone.
        if is_dimension_attr and best_score >= 0.90:
            source = ProductAttributeService.find_value_source(best_val, source_map, attr)
            extracted[attr] = [{"value": best_val, "source": source}]
            continue

        if not allow_multiple:
            source = ProductAttributeService.find_value_source(best_val, source_map, attr)
            extracted[attr] = [{"value": best_val, "source": source}]
        else:
            candidates = [best_val]
            use_base_threshold = best_score >= threshold_abs
            clusters = []
            if use_semantic_clustering:
                clusters = ProductAttributeService.calculate_value_clusters(
                    allowed_values, scores, cluster_threshold=0.4
                )
            for val, sc in scores[1:]:
                # Hard floor below which a value is never considered.
                min_score = 0.4 if is_dimension_attr else 0.3
                if sc < min_score:
                    continue
                if use_dynamic_thresholds and extracted_attrs:
                    dynamic_thresh = ProductAttributeService.get_dynamic_threshold(
                        attr, val, sc, extracted_attrs, relationships,
                        mandatory_attrs, threshold_abs
                    )
                else:
                    dynamic_thresh = threshold_abs
                within_margin = (best_score - sc) <= effective_margin
                above_threshold = sc >= dynamic_thresh
                in_cluster = False
                if use_semantic_clustering and clusters:
                    in_cluster = any(best_val in c and val in c for c in clusters)
                if use_base_threshold:
                    # Confident best: runners-up must clear the threshold or
                    # share the best value's cluster, and sit within margin.
                    if (above_threshold and within_margin) or (in_cluster and within_margin):
                        candidates.append(val)
                else:
                    # Weak best: accept within-margin values, or cluster
                    # members within a doubled margin.
                    if within_margin:
                        candidates.append(val)
                    elif in_cluster and (best_score - sc) <= effective_margin * 2.0:
                        candidates.append(val)
            extracted[attr] = [
                {"value": c, "source": ProductAttributeService.find_value_source(c, source_map, attr)}
                for c in candidates
            ]
    return extracted
@staticmethod
def extract_attributes(
    product_text: str,
    mandatory_attrs: Dict[str, List[str]],
    source_map: Dict[str, str] = None,
    model: str = None,
    extract_additional: bool = True,
    multiple: Optional[List[str]] = None,
    threshold_abs: float = 0.65,
    margin: float = 0.15,
    use_dynamic_thresholds: bool = True,
    use_adaptive_margin: bool = True,
    use_semantic_clustering: bool = True,
    use_cache: bool = None  # ⚡ NEW: Can override global setting
) -> dict:
    """Extract attributes from product text using Groq LLM.

    Calls the Groq chat API for a structured draft, validates/normalizes
    its shape, then re-scores every mandatory attribute semantically
    against the product text. Results are cached when caching is enabled.
    Always returns a dict; failures yield a standardized error response.
    """
    if model is None:
        model = settings.SUPPORTED_MODELS[0]
    if multiple is None:
        multiple = []
    if source_map is None:
        source_map = {}

    # ⚡ CACHE CONTROL: explicit parameter overrides the per-feature flag;
    # the global kill-switch still disables everything.
    if use_cache is None:
        use_cache = ENABLE_ATTRIBUTE_EXTRACTION_CACHE
    if not is_caching_enabled():
        use_cache = False

    if not product_text or product_text == "No product information available":
        return ProductAttributeService._create_error_response(
            "No product information provided",
            mandatory_attrs,
            extract_additional
        )

    # ⚡ CHECK CACHE FIRST (only if enabled)
    if use_cache:
        cache_key = ProductAttributeService._generate_cache_key(product_text, mandatory_attrs)
        cached_result = SimpleCache.get(cache_key)
        if cached_result:
            logger.info("✓ Cache hit (caching enabled)")
            return cached_result
    else:
        logger.info("⚠ Cache disabled - processing fresh")

    mandatory_attr_text = "\n".join(
        f"{attr_name}: {', '.join(allowed_values)}"
        for attr_name, allowed_values in mandatory_attrs.items()
    )

    additional_instruction = ""
    if extract_additional:
        additional_instruction = """
2. Extract ADDITIONAL attributes: Identify any other relevant attributes from the product text
that are NOT in the mandatory list. Only include attributes where you can find actual values
in the product text. Do NOT include attributes with "Not Specified" or empty values.

Examples of attributes to look for (only if present): Brand, Material, Size, Color, Dimensions,
Weight, Features, Style, Theme, Pattern, Finish, Care Instructions, etc."""

    output_format = {
        "mandatory": {attr: "value or list of values" for attr in mandatory_attrs.keys()},
    }
    if extract_additional:
        output_format["additional"] = {
            "example_attribute_1": "actual value found",
            "example_attribute_2": "actual value found",
            "_note": "Only include attributes with actual values found in text",
        }

    prompt = f"""
You are an intelligent product attribute extractor that works with ANY product type.
TASK:
1. Extract MANDATORY attributes: For each mandatory attribute, select the most appropriate value(s)
from the provided list. Choose the value(s) that best match the product description.
{additional_instruction}
Product Text:
{product_text}
Mandatory Attribute Lists (MUST select from these allowed values):
{mandatory_attr_text}
CRITICAL INSTRUCTIONS:
- Return ONLY valid JSON, nothing else
- No explanations, no markdown, no text before or after the JSON
- For mandatory attributes, choose the value(s) from the provided list that best match
- If a mandatory attribute cannot be determined from the product text, use "Not Specified"
- Prefer exact matches from the allowed values list over generic synonyms
- If multiple values are plausible, you MAY return more than one
{f"- For additional attributes: ONLY include attributes where you found actual values in the product text. DO NOT include attributes with 'Not Specified', 'None', 'N/A', or empty values. If you cannot find a value for an attribute, simply don't include that attribute." if extract_additional else ""}
- Be precise and only extract information that is explicitly stated or clearly implied
Required Output Format:
{json.dumps(output_format, indent=2)}
"""

    payload = {
        "model": model,
        "messages": [
            {
                "role": "system",
                "content": f"You are a precise attribute extraction model. Return ONLY valid JSON with {'mandatory and additional' if extract_additional else 'mandatory'} sections. No explanations, no markdown, no other text."
            },
            {"role": "user", "content": prompt}
        ],
        "temperature": 0.0,
        "max_tokens": 1500
    }
    headers = {
        "Authorization": f"Bearer {settings.GROQ_API_KEY}",
        "Content-Type": "application/json",
    }

    # Fix: defined before the try so the JSONDecodeError handler can always
    # reference it — response.json() itself can raise JSONDecodeError, in
    # which case the original code hit a NameError inside the handler.
    result_text = ""
    try:
        response = requests.post(
            settings.GROQ_API_URL,
            headers=headers,
            json=payload,
            timeout=30
        )
        response.raise_for_status()
        result_text = response.json()["choices"][0]["message"]["content"].strip()
        result_text = ProductAttributeService._clean_json_response(result_text)
        parsed = json.loads(result_text)
        parsed = ProductAttributeService._validate_response_structure(
            parsed, mandatory_attrs, extract_additional, source_map
        )

        # Drop empty / "Not Specified" additional attributes and attach sources.
        if extract_additional and "additional" in parsed:
            cleaned_additional = {}
            for k, v in parsed["additional"].items():
                if not v or v in ["Not Specified", "None", "N/A", "", "not specified", "none", "n/a"]:
                    continue
                if isinstance(v, str) and v.lower() in ["not specified", "none", "n/a", ""]:
                    continue
                if isinstance(v, list):
                    cleaned_additional[k] = []
                    for item in v:
                        if isinstance(item, dict) and "value" in item:
                            if "source" not in item:
                                item["source"] = ProductAttributeService.find_value_source(
                                    item["value"], source_map, k
                                )
                            cleaned_additional[k].append(item)
                        else:
                            source = ProductAttributeService.find_value_source(str(item), source_map, k)
                            cleaned_additional[k].append({"value": str(item), "source": source})
                else:
                    source = ProductAttributeService.find_value_source(str(v), source_map, k)
                    cleaned_additional[k] = [{"value": str(v), "source": source}]
            parsed["additional"] = cleaned_additional

        # Semantic re-normalization of mandatory attributes, done one
        # attribute at a time so earlier picks can inform later dynamic
        # thresholds.
        relationships = {}
        if use_dynamic_thresholds:
            relationships = ProductAttributeService.calculate_attribute_relationships(
                mandatory_attrs, product_text
            )
        extracted_so_far = {}
        for attr in mandatory_attrs.keys():
            allow_multiple = attr in multiple
            result = ProductAttributeService.normalize_against_product_text(
                product_text=product_text,
                mandatory_attrs={attr: mandatory_attrs[attr]},
                source_map=source_map,
                threshold_abs=threshold_abs,
                margin=margin,
                allow_multiple=allow_multiple,
                extracted_attrs=extracted_so_far,
                relationships=relationships,
                use_dynamic_thresholds=use_dynamic_thresholds,
                use_adaptive_margin=use_adaptive_margin,
                use_semantic_clustering=use_semantic_clustering
            )
            parsed["mandatory"][attr] = result[attr]
            extracted_so_far[attr] = result[attr]

        # ⚡ CACHE THE RESULT (only if enabled)
        if use_cache:
            SimpleCache.set(cache_key, parsed)
            logger.info("✓ Result cached")
        return parsed
    except requests.exceptions.RequestException as e:
        logger.error(f"Request exception: {str(e)}")
        return ProductAttributeService._create_error_response(
            str(e), mandatory_attrs, extract_additional
        )
    except json.JSONDecodeError as e:
        logger.error(f"JSON decode error: {str(e)}")
        return ProductAttributeService._create_error_response(
            f"Invalid JSON: {str(e)}", mandatory_attrs, extract_additional, result_text
        )
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}")
        return ProductAttributeService._create_error_response(
            str(e), mandatory_attrs, extract_additional
        )
- @staticmethod
- def _clean_json_response(text: str) -> str:
- """Clean LLM response to extract valid JSON."""
- start_idx = text.find('{')
- end_idx = text.rfind('}')
- if start_idx != -1 and end_idx != -1:
- text = text[start_idx:end_idx + 1]
- if "```json" in text:
- text = text.split("```json")[1].split("```")[0].strip()
- elif "```" in text:
- text = text.split("```")[1].split("```")[0].strip()
- if text.startswith("json"):
- text = text[4:].strip()
- return text
@staticmethod
def _validate_response_structure(
    parsed: dict,
    mandatory_attrs: Dict[str, List[str]],
    extract_additional: bool,
    source_map: Dict[str, str] = None
) -> dict:
    """Validate and fix the response structure.

    Guarantees the result has a "mandatory" section (plus "additional"
    when requested) and that every mandatory value is a list of
    {"value", "source"} dicts with sources backfilled from `source_map`.
    """
    if source_map is None:
        source_map = {}

    expected_sections = ["mandatory"]
    if extract_additional:
        expected_sections.append("additional")

    if not all(section in parsed for section in expected_sections):
        if isinstance(parsed, dict):
            # Flat response: split its keys into mandatory vs everything else.
            mandatory_keys = set(mandatory_attrs.keys())
            rebuilt = {
                "mandatory": {k: v for k, v in parsed.items() if k in mandatory_keys}
            }
            if extract_additional:
                rebuilt["additional"] = {
                    k: v for k, v in parsed.items() if k not in mandatory_keys
                }
            parsed = rebuilt
        else:
            return ProductAttributeService._create_error_response(
                "Invalid response structure",
                mandatory_attrs,
                extract_additional,
                str(parsed)
            )

    if "mandatory" in parsed:
        converted = {}
        for attr, value in parsed["mandatory"].items():
            if isinstance(value, list):
                converted[attr] = []
                for item in value:
                    if isinstance(item, dict) and "value" in item:
                        # Already structured; just backfill a missing source.
                        if "source" not in item:
                            item["source"] = ProductAttributeService.find_value_source(
                                item["value"], source_map, attr
                            )
                        converted[attr].append(item)
                    else:
                        src = ProductAttributeService.find_value_source(str(item), source_map, attr)
                        converted[attr].append({"value": str(item), "source": src})
            else:
                src = ProductAttributeService.find_value_source(str(value), source_map, attr)
                converted[attr] = [{"value": str(value), "source": src}]
        parsed["mandatory"] = converted
    return parsed
- @staticmethod
- def _create_error_response(
- error: str,
- mandatory_attrs: Dict[str, List[str]],
- extract_additional: bool,
- raw_output: Optional[str] = None
- ) -> dict:
- """Create a standardized error response."""
- response = {
- "mandatory": {attr: [{"value": "Not Specified", "source": "error"}] for attr in mandatory_attrs.keys()},
- "error": error
- }
- if extract_additional:
- response["additional"] = {}
- if raw_output:
- response["raw_output"] = raw_output
- return response
@staticmethod
def get_cache_stats() -> Dict:
    """Get statistics for all caches including global status."""
    # Aggregate the per-cache stats under the global kill-switch state.
    return {
        "global_caching_enabled": is_caching_enabled(),
        "simple_cache": SimpleCache.get_stats(),
        "embedding_cache": EmbeddingCache.get_stats(),
    }
@staticmethod
def clear_all_caches():
    """Clear all caches."""
    # Reset both the attribute-result cache and the embedding cache,
    # then record the action for operators.
    SimpleCache.clear()
    EmbeddingCache.clear()
    logger.info("All caches cleared")
|