# ==================== services.py (WITH CACHE CONTROL) ====================
"""Product attribute extraction service.

Mandatory attributes are chosen from caller-supplied allowed-value lists by
scoring every allowed value against the product text with sentence-transformer
embeddings plus a lexical-overlap term. A Groq-hosted LLM (called over HTTP)
is used to parse the product text and to extract optional "additional"
attributes. Extraction results and embeddings can be cached in process memory,
controlled by the flags imported from ``cache_config``.
"""
import copy
import hashlib
import json
import logging
import os
import re
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Optional, Tuple

import numpy as np
import requests
from django.conf import settings
from sentence_transformers import SentenceTransformer, util

# ⚡ IMPORT CACHE CONFIGURATION
from .cache_config import (
    is_caching_enabled,
    ENABLE_ATTRIBUTE_EXTRACTION_CACHE,
    ENABLE_EMBEDDING_CACHE,
    ATTRIBUTE_CACHE_MAX_SIZE,
    EMBEDDING_CACHE_MAX_SIZE,
)

logger = logging.getLogger(__name__)

# HuggingFace tokenizers reads this variable to decide whether to use its own
# thread pool. Set it BEFORE the model (and its tokenizer) is loaded so it is in
# effect from the first encode call. Progress bars are disabled separately via
# ``show_progress_bar=False`` on every ``encode`` call.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# ⚡ Load the embedding model ONCE at import time; every embedding computed in
# this module uses this single instance.
print("Loading sentence transformer model (one-time initialization)...")
model_embedder = SentenceTransformer("all-MiniLM-L6-v2")
print("✓ Model loaded successfully")

# Attribute names containing any of these words are treated as dimensions.
_DIMENSION_KEYWORDS = ('dimension', 'size', 'measurement')
# Unit words stripped from dimension text before numbers are extracted.
_UNIT_RE = re.compile(r'\s*(inches|inch|in|cm|centimeters|mm|millimeters)\s*', re.IGNORECASE)
_NUMBER_RE = re.compile(r'\d+\.?\d*')
# An "x" used as a separator between two numbers, e.g. "16x20" or "16 x 20".
_DIMENSION_SEPARATOR_RE = re.compile(r'(?<=\d)\s*x\s*(?=\d)')
# Lower-cased placeholder strings the LLM may emit instead of omitting a value.
_EMPTY_VALUE_MARKERS = frozenset({"not specified", "none", "n/a", ""})


def _is_dimension_attr(attr_name: str) -> bool:
    """Return True when *attr_name* looks like a dimension/size attribute."""
    lowered = attr_name.lower()
    return any(keyword in lowered for keyword in _DIMENSION_KEYWORDS)


def _evict_oldest(cache: Dict, max_size: int, fraction: float) -> None:
    """Drop the oldest entries of *cache* in place (dicts keep insertion order).

    Removes ``int(max_size * fraction)`` entries, but always at least one and
    always enough to make room for one new entry. Without that floor, small
    ``max_size`` values round the count to 0 and the cache grows without bound.
    """
    count = max(1, int(max_size * fraction), len(cache) - max_size + 1)
    for stale_key in list(cache)[:count]:
        del cache[stale_key]


# ==================== CACHING CLASSES ====================

class SimpleCache:
    """In-memory, size-bounded cache for attribute extraction results.

    Entries are evicted oldest-first once the cache reaches ``_max_size``.
    Every operation is a miss/no-op while ``ENABLE_ATTRIBUTE_EXTRACTION_CACHE``
    is false. Values are deep-copied on the way in and out, so callers that
    mutate a returned result cannot corrupt the cached copy.
    """

    _cache: Dict[str, Dict] = {}
    _max_size = ATTRIBUTE_CACHE_MAX_SIZE
    # Fraction of ``_max_size`` dropped when the cache is full.
    _evict_fraction = 0.2

    @classmethod
    def get(cls, key: str) -> Optional[Dict]:
        """Return a copy of the cached value for *key*, or None on a miss or when disabled."""
        if not ENABLE_ATTRIBUTE_EXTRACTION_CACHE:
            return None
        cached = cls._cache.get(key)
        return copy.deepcopy(cached) if cached is not None else None

    @classmethod
    def set(cls, key: str, value: Dict):
        """Store a copy of *value* under *key*.

        Does nothing if caching is disabled or ``_max_size`` is not positive.
        """
        if not ENABLE_ATTRIBUTE_EXTRACTION_CACHE or cls._max_size <= 0:
            return
        # Overwriting an existing key does not grow the cache, so no eviction is needed.
        if key not in cls._cache and len(cls._cache) >= cls._max_size:
            _evict_oldest(cls._cache, cls._max_size, cls._evict_fraction)
        cls._cache[key] = copy.deepcopy(value)

    @classmethod
    def clear(cls):
        """Remove every cached entry."""
        cls._cache.clear()

    @classmethod
    def get_stats(cls) -> Dict:
        """Return enabled flag, current size, capacity and percentage used."""
        return {
            "enabled": ENABLE_ATTRIBUTE_EXTRACTION_CACHE,
            "size": len(cls._cache),
            "max_size": cls._max_size,
            "usage_percent": round(len(cls._cache) / cls._max_size * 100, 2) if cls._max_size > 0 else 0,
        }


class EmbeddingCache:
    """Cache for sentence-transformer embeddings, keyed by input text only.

    The key ignores which model produced the embedding. That is safe here
    because this module only ever passes ``model_embedder``.
    """

    _cache: Dict[str, object] = {}
    _max_size = EMBEDDING_CACHE_MAX_SIZE
    # Fraction of ``_max_size`` dropped when the cache is full.
    _evict_fraction = 0.3
    _hit_count = 0
    _miss_count = 0

    @staticmethod
    def _encode(text: str, model):
        """Encode *text* as a tensor with warnings and the progress bar suppressed."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return model.encode(text, convert_to_tensor=True, show_progress_bar=False)

    @classmethod
    def get_embedding(cls, text: str, model):
        """Return the embedding for *text*, using the cache when it is enabled.

        When ``ENABLE_EMBEDDING_CACHE`` is false the embedding is always
        recomputed and the hit/miss counters are left untouched.
        """
        if not ENABLE_EMBEDDING_CACHE:
            return cls._encode(text, model)

        if text in cls._cache:
            cls._hit_count += 1
            return cls._cache[text]

        cls._miss_count += 1
        embedding = cls._encode(text, model)
        if cls._max_size > 0:
            if len(cls._cache) >= cls._max_size:
                _evict_oldest(cls._cache, cls._max_size, cls._evict_fraction)
            cls._cache[text] = embedding
        return embedding

    @classmethod
    def clear(cls):
        """Clear the cache and reset the hit/miss statistics."""
        cls._cache.clear()
        cls._hit_count = 0
        cls._miss_count = 0

    @classmethod
    def get_stats(cls) -> Dict:
        """Return enabled flag, size, hit/miss counts and hit rate (percent)."""
        total = cls._hit_count + cls._miss_count
        hit_rate = (cls._hit_count / total * 100) if total > 0 else 0
        return {
            "enabled": ENABLE_EMBEDDING_CACHE,
            "size": len(cls._cache),
            "hits": cls._hit_count,
            "misses": cls._miss_count,
            "hit_rate_percent": round(hit_rate, 2),
        }


# ==================== MAIN SERVICE CLASS ====================

class ProductAttributeService:
    """Service class for extracting product attributes using Groq LLM."""

    # ---------- cache keys ----------

    @staticmethod
    def _generate_cache_key(product_text: str, mandatory_attrs: Dict, extra: Optional[Dict] = None) -> str:
        """Generate a cache key from the product text, attributes and options.

        Args:
            product_text: Combined product text.
            mandatory_attrs: Attribute name -> allowed values.
            extra: Optional further inputs that influence the result (model,
                flags, thresholds, source map). When omitted, the key is the
                same as earlier versions produced.

        Returns:
            ``"attr_<md5 hex>"``.
        """
        attrs_str = json.dumps(mandatory_attrs, sort_keys=True)
        content = f"{product_text}:{attrs_str}"
        if extra:
            content += ":" + json.dumps(extra, sort_keys=True, default=str)
        return f"attr_{hashlib.md5(content.encode()).hexdigest()}"

    # ---------- text normalisation ----------

    @staticmethod
    def normalize_dimension_text(text: str) -> str:
        """Normalize dimension text to a sorted ``"AxB"`` form such as ``'16x20'``.

        Unit words are replaced by spaces rather than deleted. Deleting them
        merged neighbouring numbers ("5 in 7 in" became "57").

        Number selection:
        - two numbers: both are used;
        - three numbers: the first and the third are used;
        - more than three: the two largest are used.

        Whole numbers are printed without decimals, others with one decimal
        place.

        Returns:
            The normalized string, or ``""`` when fewer than two numbers are found.
        """
        if not text:
            return ""
        cleaned = _UNIT_RE.sub(' ', text.lower())
        numbers = [float(token) for token in _NUMBER_RE.findall(cleaned)]
        if len(numbers) < 2:
            return ""

        if len(numbers) == 3:
            numbers = [numbers[0], numbers[2]]
        elif len(numbers) > 3:
            numbers = sorted(numbers)[-2:]
        else:
            numbers = numbers[:2]

        formatted = [str(int(num)) if num.is_integer() else f"{num:.1f}" for num in numbers]
        formatted.sort(key=float)
        return f"{formatted[0]}x{formatted[1]}"

    @staticmethod
    def normalize_value_for_matching(value: str, attr_name: str = "") -> str:
        """Normalize *value* for comparison.

        Dimension-like attributes use the ``"AxB"`` form when the value can be
        parsed that way. Everything else is only stripped.
        """
        if _is_dimension_attr(attr_name):
            normalized = ProductAttributeService.normalize_dimension_text(value)
            if normalized:
                return normalized
        return value.strip()

    @staticmethod
    def combine_product_text(
        title: Optional[str] = None,
        short_desc: Optional[str] = None,
        long_desc: Optional[str] = None,
        ocr_text: Optional[str] = None
    ) -> Tuple[str, Dict[str, str]]:
        """Combine product metadata into a single labelled text block.

        Every provided field, including ``ocr_text``, is converted with
        ``str()`` and stripped. Previously ``ocr_text`` was not, so non-string
        OCR input broke ``find_value_source``.

        Returns:
            ``(combined_text, source_map)``. ``source_map`` maps ``'title'``,
            ``'short_desc'``, ``'long_desc'`` and ``'ocr_text'`` to their
            cleaned text. When nothing is provided, returns
            ``("No product information available", {})``.
        """
        labelled_fields = (
            ('title', 'Title', title),
            ('short_desc', 'Description', short_desc),
            ('long_desc', 'Details', long_desc),
            ('ocr_text', 'OCR Text', ocr_text),
        )
        parts = []
        source_map: Dict[str, str] = {}
        for key, label, raw in labelled_fields:
            if not raw:
                continue
            cleaned = str(raw).strip()
            parts.append(f"{label}: {cleaned}")
            source_map[key] = cleaned

        combined = "\n".join(parts).strip()
        if not combined:
            return "No product information available", {}
        return combined, source_map

    @staticmethod
    def find_value_source(value: str, source_map: Dict[str, str], attr_name: str = "") -> str:
        """Return the name of the source in *source_map* that best contains *value*.

        Match scores:
        - 1.0 for a case-insensitive substring match;
        - for dimension attributes, 0.95 when the normalized dimensions are
          equal, or 0.85 when both numbers appear in the source;
        - otherwise, the fraction of the value's tokens found in the source.

        Ties are broken in the order title, short_desc, long_desc, ocr_text.

        Returns:
            The winning source name, or ``"Not found"``.
        """
        value_lower = value.lower()
        # Only an "x" between digits is a separator; replacing every "x" used to
        # turn words such as "box" or "xl" into "bo" / "l".
        value_tokens = set(_DIMENSION_SEPARATOR_RE.sub(" ", value_lower.replace("-", " ")).split())
        normalized_value = (
            ProductAttributeService.normalize_dimension_text(value) if _is_dimension_attr(attr_name) else ""
        )

        source_scores: Dict[str, float] = {}
        for source_name, source_text in source_map.items():
            source_lower = source_text.lower()
            if value_lower in source_lower:
                source_scores[source_name] = 1.0
                continue

            if normalized_value:
                if normalized_value == ProductAttributeService.normalize_dimension_text(source_text):
                    source_scores[source_name] = 0.95
                    continue
                if all(part in source_text for part in normalized_value.split("x")):
                    source_scores[source_name] = 0.85
                    continue

            if value_tokens:
                token_matches = sum(1 for token in value_tokens if token in source_lower)
                if token_matches:
                    source_scores[source_name] = token_matches / len(value_tokens)

        if not source_scores:
            return "Not found"

        max_score = max(source_scores.values())
        best_sources = [name for name, score in source_scores.items() if score == max_score]
        for preferred in ('title', 'short_desc', 'long_desc', 'ocr_text'):
            if preferred in best_sources:
                return preferred
        return best_sources[0]

    # ---------- image / OCR attributes ----------

    @staticmethod
    def _image_entries(value) -> List[Dict[str, str]]:
        """Wrap a scalar or list as ``[{"value": str(item), "source": "image"}, ...]``."""
        items = value if isinstance(value, list) else [value]
        return [{"value": str(item), "source": "image"} for item in items]

    @staticmethod
    def format_visual_attributes(visual_attributes: Dict) -> Dict:
        """Convert visual attributes to the array format with ``"image"`` as the source.

        Dict values are converted one level deep. Every leaf becomes a list of
        ``{"value", "source"}`` entries.
        """
        formatted = {}
        for key, value in visual_attributes.items():
            if isinstance(value, dict):
                formatted[key] = {
                    nested_key: ProductAttributeService._image_entries(nested_value)
                    for nested_key, nested_value in value.items()
                }
            else:
                formatted[key] = ProductAttributeService._image_entries(value)
        return formatted

    @staticmethod
    def _call_groq(payload: Dict, timeout: int = 30) -> str:
        """POST *payload* to the Groq chat endpoint and return the cleaned reply text.

        Raises:
            requests.exceptions.RequestException: On transport or HTTP errors.
            KeyError / IndexError: If the response lacks
                ``choices[0].message.content``.
        """
        headers = {
            "Authorization": f"Bearer {settings.GROQ_API_KEY}",
            "Content-Type": "application/json",
        }
        response = requests.post(settings.GROQ_API_URL, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()
        content = response.json()["choices"][0]["message"]["content"].strip()
        return ProductAttributeService._clean_json_response(content)

    @staticmethod
    def extract_attributes_from_ocr(ocr_results: Dict, model: str = None) -> Dict:
        """Extract structured attributes from OCR detections using the LLM.

        Args:
            ocr_results: Expected to contain ``'detected_text'``, a list of
                dicts with ``'text'`` and a numeric ``'confidence'``.
            model: Groq model name. Defaults to ``settings.SUPPORTED_MODELS[0]``.

        Returns:
            Attributes in the same array format as ``format_visual_attributes``
            (nested lists are now expanded consistently). Returns ``{}`` when
            there is no OCR text, and ``{"error": ...}`` on any failure.
        """
        if model is None:
            model = settings.SUPPORTED_MODELS[0]

        detected_text = ocr_results.get('detected_text', [])
        if not detected_text:
            return {}

        ocr_text = "\n".join(
            f"Text: {item['text']}, Confidence: {item['confidence']:.2f}" for item in detected_text
        )

        prompt = f"""
You are an AI model that extracts structured attributes from OCR text detected on product images.
Given the OCR detections below, infer the possible product attributes and return them as a clean JSON object.

OCR Text:
{ocr_text}

Extract relevant attributes like:
- brand
- model_number
- size (waist_size, length, etc.)
- collection
- any other relevant product information

Return a JSON object with only the attributes you can confidently identify.
If an attribute is not present, do not include it in the response.
"""

        payload = {
            "model": model,
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful AI that extracts structured data from OCR output. Return only valid JSON."
                },
                {"role": "user", "content": prompt}
            ],
            "temperature": 0.2,
            "max_tokens": 500
        }

        try:
            parsed = json.loads(ProductAttributeService._call_groq(payload))
            # A non-dict reply raises AttributeError here and is reported below.
            return ProductAttributeService.format_visual_attributes(
                {key: value for key, value in parsed.items() if key != "error"}
            )
        except Exception as e:
            logger.error("OCR attribute extraction failed: %s", e)
            return {"error": f"Failed to extract attributes from OCR: {str(e)}"}

    # ---------- embedding-based scoring ----------

    @staticmethod
    def calculate_attribute_relationships(
        mandatory_attrs: Dict[str, List[str]],
        product_text: str
    ) -> Dict[str, float]:
        """Compute cosine similarity between values of different attributes.

        Each distinct value is embedded once. The previous version also built
        an unused per-value "product text" score table and re-encoded values
        inside the innermost loop.

        Args:
            mandatory_attrs: Attribute name -> allowed values.
            product_text: Not used any more; kept for interface compatibility.

        Returns:
            Keys ``"attr1:val1->attr2:val2"`` in both directions, mapped to the
            cosine similarity of the two values.
        """
        value_embeddings = {}
        for values in mandatory_attrs.values():
            for val in values:
                if val not in value_embeddings:
                    value_embeddings[val] = EmbeddingCache.get_embedding(val, model_embedder)

        relationships: Dict[str, float] = {}
        attr_list = list(mandatory_attrs.keys())
        for i, attr1 in enumerate(attr_list):
            for attr2 in attr_list[i + 1:]:
                for val1 in mandatory_attrs[attr1]:
                    for val2 in mandatory_attrs[attr2]:
                        sim = float(util.cos_sim(value_embeddings[val1], value_embeddings[val2]).item())
                        relationships[f"{attr1}:{val1}->{attr2}:{val2}"] = sim
                        relationships[f"{attr2}:{val2}->{attr1}:{val1}"] = sim
        return relationships

    @staticmethod
    def calculate_value_clusters(
        values: List[str],
        scores: List[Tuple[str, float]],
        cluster_threshold: float = 0.4
    ) -> List[List[str]]:
        """Group values into semantic clusters, seeded in *scores* order.

        Each unvisited value (taken in the order of *scores*) starts a cluster
        and absorbs every other unvisited value whose embedding similarity to it
        is at least *cluster_threshold*.

        *scores* is usually sorted by score, so its positions differ from the
        positions in *values*. The similarity row is therefore looked up by
        value, not by the index in *scores*. Indexing by position in *scores*
        was a bug that compared against the wrong values.
        """
        if len(values) <= 1:
            return [[val] for val, _ in scores]

        size = len(values)
        embeddings = [EmbeddingCache.get_embedding(val, model_embedder) for val in values]
        similarity_matrix = np.zeros((size, size))
        for i in range(size):
            for j in range(i + 1, size):
                sim = float(util.cos_sim(embeddings[i], embeddings[j]).item())
                similarity_matrix[i][j] = sim
                similarity_matrix[j][i] = sim

        position: Dict[str, int] = {}
        for idx, val in enumerate(values):
            position.setdefault(val, idx)

        clusters = []
        visited = set()
        for val, _score in scores:
            i = position.get(val)
            if i is None:
                # Scored value that is not in *values*: keep it as its own cluster.
                clusters.append([val])
                continue
            if i in visited:
                continue
            cluster = [val]
            visited.add(i)
            for j in range(size):
                if j not in visited and similarity_matrix[i][j] >= cluster_threshold:
                    cluster.append(values[j])
                    visited.add(j)
            clusters.append(cluster)
        return clusters

    @staticmethod
    def get_dynamic_threshold(
        attr: str,
        val: str,
        base_score: float,
        extracted_attrs: Dict[str, List[Dict[str, str]]],
        relationships: Dict[str, float],
        mandatory_attrs: Dict[str, List[str]],
        base_threshold: float = 0.65,
        boost_factor: float = 0.15
    ) -> float:
        """Lower the acceptance threshold for *val* when it relates strongly to already-extracted values.

        If the largest relationship between ``attr:val`` and any value already
        extracted for another attribute exceeds 0.6, the threshold becomes
        ``base_threshold - boost_factor * relationship``. The result never goes
        below 0.3. ``base_score`` and ``mandatory_attrs`` are accepted for
        interface compatibility but are not used.
        """
        max_relationship = 0.0
        for other_attr, other_values in extracted_attrs.items():
            if other_attr == attr:
                continue
            for other_entry in other_values:
                key = f"{attr}:{val}->{other_attr}:{other_entry['value']}"
                if key in relationships:
                    max_relationship = max(max_relationship, relationships[key])

        threshold = base_threshold
        if max_relationship > 0.6:
            threshold = base_threshold - (boost_factor * max_relationship)
        return max(0.3, threshold)

    @staticmethod
    def get_adaptive_margin(
        scores: List[Tuple[str, float]],
        base_margin: float = 0.15,
        max_margin: float = 0.22
    ) -> float:
        """Widen the margin when the best score is weak and the top scores are close together.

        Expects *scores* sorted best-first. The margin only changes when the
        best score is below 0.5 and the spread of the top four scores is below
        0.30. The result is capped at *max_margin*.
        """
        if len(scores) < 2:
            return base_margin

        score_values = [score for _, score in scores]
        best_score = score_values[0]
        if best_score < 0.5:
            top_scores = score_values[:4]
            score_range = max(top_scores) - min(top_scores)
            if score_range < 0.30:
                score_factor = (0.5 - best_score) * 0.35
                adaptive = base_margin + score_factor + (0.30 - score_range) * 0.2
                return min(adaptive, max_margin)
        return base_margin

    @staticmethod
    def _lexical_evidence(product_text: str, label: str) -> float:
        """Return the fraction of *label*'s tokens that occur (as substrings) in *product_text*."""
        pt = product_text.lower()
        tokens = [t for t in label.lower().replace("-", " ").split() if t]
        if not tokens:
            return 0.0
        hits = sum(1 for t in tokens if t in pt)
        return hits / len(tokens)

    @staticmethod
    def _score_allowed_value(
        val: str,
        product_text: str,
        pt_emb,
        is_dimension_attr: bool,
        normalized_product_text: str,
        sem_weight: float,
        lex_weight: float
    ) -> float:
        """Score one allowed value against the product text (higher is better).

        For dimension attributes:
        - 1.0 when the value's ``"AxB"`` form equals the product text's;
        - 0.95 when both numbers appear in the text fewer than 20 characters apart.

        Otherwise the score is ``sem_weight * (best cosine similarity over a few
        context phrases) + lex_weight * (lexical overlap)``.
        """
        if is_dimension_attr:
            normalized_val = ProductAttributeService.normalize_dimension_text(val)
            if normalized_val:
                if normalized_val == normalized_product_text:
                    return 1.0
                first, second = normalized_val.split('x')
                text_lower = product_text.lower()
                idx1 = text_lower.find(first)
                idx2 = text_lower.find(second)
                if idx1 != -1 and idx2 != -1 and abs(idx2 - idx1) < 20:
                    return 0.95

        contexts = [val, f"for {val}", f"use in {val}", f"suitable for {val}", f"{val} room"]
        ctx_embs = [EmbeddingCache.get_embedding(c, model_embedder) for c in contexts]
        sem_sim = max(float(util.cos_sim(pt_emb, ce).item()) for ce in ctx_embs)
        lex_score = ProductAttributeService._lexical_evidence(product_text, val)
        return sem_weight * sem_sim + lex_weight * lex_score

    @staticmethod
    def normalize_against_product_text(
        product_text: str,
        mandatory_attrs: Dict[str, List[str]],
        source_map: Dict[str, str],
        threshold_abs: float = 0.65,
        margin: float = 0.15,
        allow_multiple: bool = False,
        sem_weight: float = 0.8,
        lex_weight: float = 0.2,
        extracted_attrs: Optional[Dict[str, List[Dict[str, str]]]] = None,
        relationships: Optional[Dict[str, float]] = None,
        use_dynamic_thresholds: bool = True,
        use_adaptive_margin: bool = True,
        use_semantic_clustering: bool = True
    ) -> dict:
        """Score each allowed value against *product_text* and pick the winner(s).

        Returns:
            Attribute name -> list of ``{"value", "source"}`` entries.
            Attributes with an empty allowed-values list get a single
            ``"Not Specified"`` entry instead of raising ``IndexError``.

        Selection rules:
        - If ``allow_multiple`` is false, or a dimension attribute scored at
          least 0.90, only the best value is returned.
        - Otherwise runner-up values are added when they pass the margin,
          threshold and cluster rules used previously.
        """
        if extracted_attrs is None:
            extracted_attrs = {}
        if relationships is None:
            relationships = {}

        pt_emb = EmbeddingCache.get_embedding(product_text, model_embedder)
        extracted = {}

        for attr, allowed_values in mandatory_attrs.items():
            if not allowed_values:
                extracted[attr] = [{"value": "Not Specified", "source": "Not found"}]
                continue

            is_dimension_attr = _is_dimension_attr(attr)
            normalized_product_text = (
                ProductAttributeService.normalize_dimension_text(product_text) if is_dimension_attr else ""
            )

            scores: List[Tuple[str, float]] = [
                (val, ProductAttributeService._score_allowed_value(
                    val, product_text, pt_emb, is_dimension_attr,
                    normalized_product_text, sem_weight, lex_weight
                ))
                for val in allowed_values
            ]
            # Stable sort: ties keep allowed_values order.
            scores.sort(key=lambda pair: pair[1], reverse=True)
            best_val, best_score = scores[0]

            if not allow_multiple or (is_dimension_attr and best_score >= 0.90):
                source = ProductAttributeService.find_value_source(best_val, source_map, attr)
                extracted[attr] = [{"value": best_val, "source": source}]
                continue

            effective_margin = margin
            if use_adaptive_margin:
                effective_margin = ProductAttributeService.get_adaptive_margin(scores, margin)

            clusters = []
            if use_semantic_clustering:
                clusters = ProductAttributeService.calculate_value_clusters(
                    allowed_values, scores, cluster_threshold=0.4
                )

            # A strong best value means runner-ups must also clear a threshold.
            strong_best = best_score >= threshold_abs
            min_score = 0.4 if is_dimension_attr else 0.3
            candidates = [best_val]
            for val, sc in scores[1:]:
                if sc < min_score:
                    continue

                if use_dynamic_thresholds and extracted_attrs:
                    dynamic_thresh = ProductAttributeService.get_dynamic_threshold(
                        attr, val, sc, extracted_attrs, relationships, mandatory_attrs, threshold_abs
                    )
                else:
                    dynamic_thresh = threshold_abs

                gap = best_score - sc
                within_margin = gap <= effective_margin
                in_cluster = any(best_val in c and val in c for c in clusters)

                if strong_best:
                    keep = within_margin and (sc >= dynamic_thresh or in_cluster)
                else:
                    keep = within_margin or (in_cluster and gap <= effective_margin * 2.0)
                if keep:
                    candidates.append(val)

            extracted[attr] = [
                {"value": candidate,
                 "source": ProductAttributeService.find_value_source(candidate, source_map, attr)}
                for candidate in candidates
            ]

        return extracted

    # ---------- LLM extraction ----------

    @staticmethod
    def _build_extraction_prompt(
        product_text: str,
        mandatory_attrs: Dict[str, List[str]],
        extract_additional: bool
    ) -> str:
        """Build the user prompt for ``extract_attributes``."""
        mandatory_attr_text = "\n".join(
            f"{attr_name}: {', '.join(allowed_values)}"
            for attr_name, allowed_values in mandatory_attrs.items()
        )

        additional_instruction = ""
        additional_rule = ""
        if extract_additional:
            additional_instruction = """
2. Extract ADDITIONAL attributes: Identify any other relevant attributes from the product text that are NOT in the mandatory list.
   Only include attributes where you can find actual values in the product text.
   Do NOT include attributes with "Not Specified" or empty values.
   Examples of attributes to look for (only if present): Brand, Material, Size, Color, Dimensions, Weight, Features, Style, Theme, Pattern, Finish, Care Instructions, etc."""
            additional_rule = (
                "- For additional attributes: ONLY include attributes where you found actual values in the "
                "product text. DO NOT include attributes with 'Not Specified', 'None', 'N/A', or empty values. "
                "If you cannot find a value for an attribute, simply don't include that attribute."
            )

        output_format = {
            "mandatory": {attr: "value or list of values" for attr in mandatory_attrs.keys()},
        }
        if extract_additional:
            output_format["additional"] = {
                "example_attribute_1": "actual value found",
                "example_attribute_2": "actual value found",
                "_note": "Only include attributes with actual values found in text",
            }

        # Kept at column 0 so the prompt text carries no source indentation.
        return f"""
You are an intelligent product attribute extractor that works with ANY product type.

TASK:
1. Extract MANDATORY attributes: For each mandatory attribute, select the most appropriate value(s) from the provided list. Choose the value(s) that best match the product description.
{additional_instruction}

Product Text:
{product_text}

Mandatory Attribute Lists (MUST select from these allowed values):
{mandatory_attr_text}

CRITICAL INSTRUCTIONS:
- Return ONLY valid JSON, nothing else
- No explanations, no markdown, no text before or after the JSON
- For mandatory attributes, choose the value(s) from the provided list that best match
- If a mandatory attribute cannot be determined from the product text, use "Not Specified"
- Prefer exact matches from the allowed values list over generic synonyms
- If multiple values are plausible, you MAY return more than one
{additional_rule}
- Be precise and only extract information that is explicitly stated or clearly implied

Required Output Format:
{json.dumps(output_format, indent=2)}
"""

    @staticmethod
    def _to_value_entries(value, source_map: Dict[str, str], attr: str) -> List[Dict]:
        """Convert an LLM attribute value into a list of ``{"value", "source"}`` entries.

        Handling:
        - list items that are dicts with a ``"value"`` key are kept as they are,
          with a ``"source"`` added if missing;
        - other list items are stringified;
        - a non-list value becomes a single stringified entry.
        """
        if not isinstance(value, list):
            text = str(value)
            return [{"value": text, "source": ProductAttributeService.find_value_source(text, source_map, attr)}]

        entries = []
        for item in value:
            if isinstance(item, dict) and "value" in item:
                if "source" not in item:
                    item["source"] = ProductAttributeService.find_value_source(
                        str(item["value"]), source_map, attr
                    )
                entries.append(item)
            else:
                text = str(item)
                entries.append(
                    {"value": text, "source": ProductAttributeService.find_value_source(text, source_map, attr)}
                )
        return entries

    @staticmethod
    def _clean_additional_attributes(additional: Dict, source_map: Dict[str, str]) -> Dict:
        """Drop empty or placeholder values and convert the rest to entry lists.

        Dropped values are falsy ones and strings such as ``"N/A"``, ``"None"``
        or ``"Not Specified"`` (case-insensitive, surrounding whitespace ignored).
        """
        cleaned = {}
        for key, value in additional.items():
            if not value:
                continue
            if isinstance(value, str) and value.strip().lower() in _EMPTY_VALUE_MARKERS:
                continue
            cleaned[key] = ProductAttributeService._to_value_entries(value, source_map, key)
        return cleaned

    @staticmethod
    def extract_attributes(
        product_text: str,
        mandatory_attrs: Dict[str, List[str]],
        source_map: Dict[str, str] = None,
        model: str = None,
        extract_additional: bool = True,
        multiple: Optional[List[str]] = None,
        threshold_abs: float = 0.65,
        margin: float = 0.15,
        use_dynamic_thresholds: bool = True,
        use_adaptive_margin: bool = True,
        use_semantic_clustering: bool = True,
        use_cache: bool = None  # ⚡ NEW: Can override global setting
    ) -> dict:
        """Extract attributes from product text using the Groq LLM and embedding scoring.

        The LLM reply supplies the ``additional`` section. Each mandatory
        attribute is then re-decided by ``normalize_against_product_text``.

        Args:
            use_cache: Overrides ``ENABLE_ATTRIBUTE_EXTRACTION_CACHE`` for this
                call. It is forced off when ``is_caching_enabled()`` is false.

        Returns:
            A dict with ``"mandatory"`` and, when requested, ``"additional"``.
            On failure, the structure from ``_create_error_response``. Only
            results without an ``"error"`` key are cached.
        """
        if model is None:
            model = settings.SUPPORTED_MODELS[0]
        if multiple is None:
            multiple = []
        if source_map is None:
            source_map = {}

        # ⚡ CACHE CONTROL: parameter wins, otherwise the global setting.
        if use_cache is None:
            use_cache = ENABLE_ATTRIBUTE_EXTRACTION_CACHE
        if not is_caching_enabled():
            use_cache = False

        if not product_text or product_text == "No product information available":
            return ProductAttributeService._create_error_response(
                "No product information provided", mandatory_attrs, extract_additional
            )

        cache_key = None
        if use_cache:
            # Every input that changes the result must be part of the key; the
            # old key (text + attributes only) served results across configurations.
            cache_key = ProductAttributeService._generate_cache_key(
                product_text,
                mandatory_attrs,
                extra={
                    "model": model,
                    "extract_additional": extract_additional,
                    "multiple": sorted(str(name) for name in multiple),
                    "threshold_abs": threshold_abs,
                    "margin": margin,
                    "use_dynamic_thresholds": use_dynamic_thresholds,
                    "use_adaptive_margin": use_adaptive_margin,
                    "use_semantic_clustering": use_semantic_clustering,
                    "source_map": source_map,
                },
            )
            cached_result = SimpleCache.get(cache_key)
            if cached_result:
                logger.info("✓ Cache hit (caching enabled)")
                return cached_result
        else:
            logger.info("⚠ Cache disabled - processing fresh")

        sections = 'mandatory and additional' if extract_additional else 'mandatory'
        payload = {
            "model": model,
            "messages": [
                {
                    "role": "system",
                    "content": f"You are a precise attribute extraction model. Return ONLY valid JSON with {sections} sections. No explanations, no markdown, no other text."
                },
                {
                    "role": "user",
                    "content": ProductAttributeService._build_extraction_prompt(
                        product_text, mandatory_attrs, extract_additional
                    )
                }
            ],
            "temperature": 0.0,
            "max_tokens": 1500
        }

        # Bound before the try so the JSONDecodeError handler can always refer to it.
        result_text = None
        try:
            result_text = ProductAttributeService._call_groq(payload)
            parsed = json.loads(result_text)
            parsed = ProductAttributeService._validate_response_structure(
                parsed, mandatory_attrs, extract_additional, source_map
            )

            if extract_additional and "additional" in parsed:
                additional = parsed["additional"]
                parsed["additional"] = (
                    ProductAttributeService._clean_additional_attributes(additional, source_map)
                    if isinstance(additional, dict) else {}
                )

            relationships = {}
            if use_dynamic_thresholds:
                relationships = ProductAttributeService.calculate_attribute_relationships(
                    mandatory_attrs, product_text
                )

            # Attributes are decided one at a time so later ones can use the
            # relationships to values already chosen for earlier ones.
            extracted_so_far = {}
            for attr, allowed_values in mandatory_attrs.items():
                result = ProductAttributeService.normalize_against_product_text(
                    product_text=product_text,
                    mandatory_attrs={attr: allowed_values},
                    source_map=source_map,
                    threshold_abs=threshold_abs,
                    margin=margin,
                    allow_multiple=attr in multiple,
                    extracted_attrs=extracted_so_far,
                    relationships=relationships,
                    use_dynamic_thresholds=use_dynamic_thresholds,
                    use_adaptive_margin=use_adaptive_margin,
                    use_semantic_clustering=use_semantic_clustering
                )
                parsed["mandatory"][attr] = result[attr]
                extracted_so_far[attr] = result[attr]

            # ⚡ CACHE THE RESULT (only if enabled, and never an error response)
            if use_cache and "error" not in parsed:
                SimpleCache.set(cache_key, parsed)
                logger.info("✓ Result cached")

            return parsed

        except requests.exceptions.RequestException as e:
            logger.error("Request exception: %s", e)
            return ProductAttributeService._create_error_response(
                str(e), mandatory_attrs, extract_additional
            )
        except json.JSONDecodeError as e:
            logger.error("JSON decode error: %s", e)
            return ProductAttributeService._create_error_response(
                f"Invalid JSON: {str(e)}", mandatory_attrs, extract_additional, result_text
            )
        except Exception as e:
            logger.exception("Unexpected error: %s", e)
            return ProductAttributeService._create_error_response(
                str(e), mandatory_attrs, extract_additional
            )

    @staticmethod
    def _clean_json_response(text: str) -> str:
        """Extract the JSON object from an LLM reply.

        Steps:
        1. If the reply does not already start with ``{`` or ``[``, take the
           contents of a Markdown code fence (optionally tagged ``json``).
        2. Slice from the first ``{`` to the last ``}``.

        The old order sliced the braces first, which removed the fences before
        they could be detected and left the fence handling as dead code.
        """
        cleaned = text.strip()
        if not cleaned.startswith(("{", "[")):
            fenced = re.search(r"```(?:json)?\s*(.*)```", cleaned, re.DOTALL | re.IGNORECASE)
            if fenced:
                cleaned = fenced.group(1).strip()

        start_idx = cleaned.find('{')
        end_idx = cleaned.rfind('}')
        if start_idx != -1 and end_idx > start_idx:
            cleaned = cleaned[start_idx:end_idx + 1]
        return cleaned

    @staticmethod
    def _validate_response_structure(
        parsed: dict,
        mandatory_attrs: Dict[str, List[str]],
        extract_additional: bool,
        source_map: Dict[str, str] = None
    ) -> dict:
        """Coerce the LLM reply into ``{"mandatory": {...}, "additional": {...}}``.

        Cases:
        - A flat reply, with no ``"mandatory"`` key, is split by attribute name;
          a nested ``"additional"`` dict is merged into the additional section.
        - A reply that has ``"mandatory"`` but lacks ``"additional"`` keeps its
          mandatory section. Previously that whole section was mis-filed under
          ``additional``.
        - Mandatory values are converted to entry lists.
        - A non-dict reply yields an error response.
        """
        if source_map is None:
            source_map = {}

        if not isinstance(parsed, dict):
            return ProductAttributeService._create_error_response(
                "Invalid response structure", mandatory_attrs, extract_additional, str(parsed)
            )

        if "mandatory" not in parsed:
            mandatory_keys = set(mandatory_attrs.keys())
            restructured = {"mandatory": {k: v for k, v in parsed.items() if k in mandatory_keys}}
            if extract_additional:
                extra = {k: v for k, v in parsed.items() if k not in mandatory_keys and k != "additional"}
                nested_additional = parsed.get("additional")
                if isinstance(nested_additional, dict):
                    extra.update(nested_additional)
                restructured["additional"] = extra
            parsed = restructured
        elif extract_additional and "additional" not in parsed:
            parsed["additional"] = {}

        mandatory = parsed.get("mandatory")
        if not isinstance(mandatory, dict):
            mandatory = {}
        parsed["mandatory"] = {
            attr: ProductAttributeService._to_value_entries(value, source_map, attr)
            for attr, value in mandatory.items()
        }
        return parsed

    @staticmethod
    def _create_error_response(
        error: str,
        mandatory_attrs: Dict[str, List[str]],
        extract_additional: bool,
        raw_output: Optional[str] = None
    ) -> dict:
        """Build the standard error payload.

        Every mandatory attribute is set to ``"Not Specified"`` with source
        ``"error"``, plus an ``"error"`` message, an empty ``"additional"``
        section when requested, and ``"raw_output"`` when provided.
        """
        response = {
            "mandatory": {attr: [{"value": "Not Specified", "source": "error"}] for attr in mandatory_attrs.keys()},
            "error": error
        }
        if extract_additional:
            response["additional"] = {}
        if raw_output:
            response["raw_output"] = raw_output
        return response

    @staticmethod
    def get_cache_stats() -> Dict:
        """Return the global caching flag plus statistics for both caches."""
        return {
            "global_caching_enabled": is_caching_enabled(),
            "simple_cache": SimpleCache.get_stats(),
            "embedding_cache": EmbeddingCache.get_stats()
        }

    @staticmethod
    def clear_all_caches():
        """Clear the attribute and embedding caches."""
        SimpleCache.clear()
        EmbeddingCache.clear()
        logger.info("All caches cleared")