@@ -1,5 +1,4 @@
-
-# ==================== visual_processing_service.py (FIXED - Smart Subcategory Detection) ====================
+# ==================== visual_processing_service.py (WITH CACHE CONTROL) ====================
 import torch
 import numpy as np
 import requests
@@ -10,18 +9,21 @@ import logging
 from transformers import CLIPProcessor, CLIPModel
 from sklearn.cluster import KMeans
 
+# ⚡ IMPORT CACHE CONFIGURATION
+from .cache_config import ENABLE_CLIP_MODEL_CACHE
+
 logger = logging.getLogger(__name__)
 
 import os
-os.environ['TOKENIZERS_PARALLELISM'] = 'false'  # Disable tokenizer warnings
+os.environ['TOKENIZERS_PARALLELISM'] = 'false'
 import warnings
-warnings.filterwarnings('ignore')  # Suppress all warnings
+warnings.filterwarnings('ignore')
 
 
 class VisualProcessingService:
     """Service for extracting visual attributes from product images using CLIP with smart subcategory detection."""
 
-    # Class-level caching (shared across instances)
+    # ⚡ Class-level caching (controlled by cache_config)
     _clip_model = None
     _clip_processor = None
     _device = None
@@ -129,7 +131,24 @@ class VisualProcessingService:
 
     @classmethod
     def _get_clip_model(cls):
-        """Lazy load CLIP model with class-level caching."""
+        """
+        Lazy load CLIP model with optional class-level caching.
+        ⚡ If caching is disabled, the model is still loaded but not persisted at class level.
+        """
+        # ⚡ CACHE CONTROL: If caching is disabled, always reload (no persistence)
+        if not ENABLE_CLIP_MODEL_CACHE:
+            logger.info("⚠ CLIP model caching is DISABLED - loading fresh instance")
+            model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+            processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+            device = cls._get_device()
+            model.to(device)
+            model.eval()
+
+            logger.info("✓ CLIP model loaded (no caching)")
+            return model, processor
+
+        # Caching is enabled - use class-level cache
         if cls._clip_model is None:
             logger.info("Loading CLIP model (this may take a few minutes on first use)...")
             cls._clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
@@ -139,9 +158,24 @@ class VisualProcessingService:
             cls._clip_model.to(device)
             cls._clip_model.eval()
 
-            logger.info("✓ CLIP model loaded successfully")
+            logger.info("✓ CLIP model loaded and cached successfully")
+        else:
+            logger.info("✓ Using cached CLIP model")
+
         return cls._clip_model, cls._clip_processor
 
+    @classmethod
+    def clear_clip_cache(cls):
+        """Clear the cached CLIP model to free memory."""
+        if cls._clip_model is not None:
+            del cls._clip_model
+            del cls._clip_processor
+            cls._clip_model = None
+            cls._clip_processor = None
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            logger.info("✓ CLIP model cache cleared")
+
     def download_image(self, image_url: str) -> Optional[Image.Image]:
         """Download image from URL."""
         try:
@@ -156,12 +190,10 @@ class VisualProcessingService:
     def extract_dominant_colors(self, image: Image.Image, n_colors: int = 3) -> List[Dict]:
         """Extract dominant colors using K-means clustering."""
         try:
-            # Resize for faster processing
             img_small = image.resize((150, 150))
             img_array = np.array(img_small)
             pixels = img_array.reshape(-1, 3)
 
-            # K-means clustering
            kmeans = KMeans(n_clusters=n_colors, random_state=42, n_init=5)
             kmeans.fit(pixels)
 
@@ -179,7 +211,6 @@ class VisualProcessingService:
                     "percentage": round(percentage, 2)
                 })
 
-            # Sort by percentage (most dominant first)
             colors.sort(key=lambda x: x['percentage'], reverse=True)
             return colors
 
@@ -191,7 +222,6 @@ class VisualProcessingService:
         """Map RGB values to basic color names."""
         r, g, b = rgb
 
-        # Define color ranges with priorities
         colors = {
             'black': (r < 50 and g < 50 and b < 50),
             'white': (r > 200 and g > 200 and b > 200),
@@ -212,7 +242,6 @@ class VisualProcessingService:
             if condition:
                 return color_name
 
-        # Fallback to dominant channel
         if r > g and r > b:
             return 'red'
         elif g > r and g > b:
@@ -234,14 +263,12 @@ class VisualProcessingService:
             model, processor = self._get_clip_model()
             device = self._get_device()
 
-            # ⚡ OPTIMIZATION: Process in smaller batches to avoid memory issues
-            batch_size = 16  # Process 16 candidates at a time
+            batch_size = 16
             all_results = []
 
             for i in range(0, len(candidates), batch_size):
                 batch_candidates = candidates[i:i + batch_size]
 
-                # Prepare inputs WITHOUT progress bars
                 inputs = processor(
                     text=batch_candidates,
                     images=image,
@@ -249,16 +276,13 @@ class VisualProcessingService:
                     padding=True
                 )
 
-                # Move to device
                 inputs = {k: v.to(device) for k, v in inputs.items()}
 
-                # Get predictions
                 with torch.no_grad():
                     outputs = model(**inputs)
                     logits_per_image = outputs.logits_per_image
                     probs = logits_per_image.softmax(dim=1).cpu()
 
-                # Collect results from this batch
                 for j, prob in enumerate(probs[0]):
                     if prob.item() > confidence_threshold:
                         all_results.append({
@@ -266,7 +290,6 @@ class VisualProcessingService:
                             "confidence": round(float(prob.item()), 3)
                         })
 
-            # Sort by confidence and return top 3
             all_results.sort(key=lambda x: x['confidence'], reverse=True)
 
             return {
@@ -278,16 +301,11 @@ class VisualProcessingService:
             logger.error(f"Error in CLIP classification for {attribute_name}: {str(e)}")
             return {"attribute": attribute_name, "predictions": []}
 
-
-
-
-
     def detect_category_and_subcategory(self, image: Image.Image) -> Tuple[str, str, str, float]:
         """
         Hierarchically detect category, subcategory, and specific product.
         Returns: (category, subcategory, product_type, confidence)
         """
-        # Step 1: Detect if it's clothing or something else
         main_categories = list(self.CATEGORY_ATTRIBUTES.keys())
         category_prompts = [f"a photo of {cat}" for cat in main_categories]
 
@@ -301,11 +319,9 @@ class VisualProcessingService:
 
         logger.info(f"Step 1 - Main category detected: {detected_category} (confidence: {category_confidence:.3f})")
 
-        # Step 2: For clothing, detect subcategory (tops/bottoms/dresses/outerwear)
         if detected_category == "clothing":
             subcategories = self.CATEGORY_ATTRIBUTES["clothing"]["subcategories"]
 
-            # Collect all products grouped by subcategory
             all_products = []
             product_to_subcategory = {}
 
@@ -315,7 +331,6 @@ class VisualProcessingService:
                     all_products.append(prompt)
                     product_to_subcategory[prompt] = subcat
 
-            # Step 3: Detect specific product type
             product_result = self.classify_with_clip(
                 image,
                 all_products,
@@ -336,11 +351,9 @@ class VisualProcessingService:
                 logger.warning("Could not detect specific product type for clothing")
                 return detected_category, "unknown", "unknown", category_confidence
 
-        # Step 3: For non-clothing categories, just detect product type
         else:
             category_data = self.CATEGORY_ATTRIBUTES[detected_category]
 
-            # Check if this category has subcategories or direct products
             if "products" in category_data:
                 products = category_data["products"]
                 product_prompts = [f"a photo of {p}" for p in products]
@@ -374,7 +387,6 @@ class VisualProcessingService:
         start_time = time.time()
 
         try:
-            # Download image
             image = self.download_image(image_url)
             if image is None:
                 return {
@@ -385,10 +397,8 @@ class VisualProcessingService:
             visual_attributes = {}
             detailed_predictions = {}
 
-            # Step 1: Detect category, subcategory, and product type
             category, subcategory, product_type, confidence = self.detect_category_and_subcategory(image)
 
-            # Low confidence check
             if confidence < 0.10:
                 logger.warning(f"Low confidence in detection ({confidence:.3f}). Returning basic attributes only.")
                 colors = self.extract_dominant_colors(image, n_colors=3)
@@ -403,13 +413,11 @@ class VisualProcessingService:
                     "processing_time": round(time.time() - start_time, 2)
                 }
 
-            # Add detected metadata
             visual_attributes["product_type"] = product_type
             visual_attributes["category"] = category
             if subcategory != "none" and subcategory != "unknown":
                 visual_attributes["subcategory"] = subcategory
 
-            # Step 2: Extract color information (universal)
             colors = self.extract_dominant_colors(image, n_colors=3)
             if colors:
                 visual_attributes["primary_color"] = colors[0]["name"]
@@ -419,7 +427,6 @@ class VisualProcessingService:
                     for c in colors
                 ]
 
-            # Step 3: Get the right attribute configuration based on subcategory
             attributes_config = None
 
             if category == "clothing":
@@ -434,7 +441,6 @@ class VisualProcessingService:
                 attributes_config = self.CATEGORY_ATTRIBUTES[category]["attributes"]
                 logger.info(f"Using attributes for category: {category}")
 
-            # Step 4: Extract category-specific attributes
             if attributes_config:
                 for attr_name, attr_values in attributes_config.items():
                     result = self.classify_with_clip(
@@ -446,11 +452,9 @@ class VisualProcessingService:
 
                     if result["predictions"]:
                         best_prediction = result["predictions"][0]
-                        # Only add attributes with reasonable confidence
                         if best_prediction["confidence"] > 0.20:
                             visual_attributes[attr_name] = best_prediction["value"]
 
-                        # Store detailed predictions for debugging
                         detailed_predictions[attr_name] = result
 
             processing_time = time.time() - start_time
@@ -461,7 +465,8 @@ class VisualProcessingService:
                 "visual_attributes": visual_attributes,
                 "detailed_predictions": detailed_predictions,
                 "detection_confidence": confidence,
-                "processing_time": round(processing_time, 2)
+                "processing_time": round(processing_time, 2),
+                "cache_status": "enabled" if ENABLE_CLIP_MODEL_CACHE else "disabled"
             }
 
         except Exception as e:
@@ -470,6 +475,4 @@ class VisualProcessingService:
                 "visual_attributes": {},
                 "error": str(e),
                 "processing_time": round(time.time() - start_time, 2)
-            }
-
-
+            }
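
Note (reviewer sketch, not part of the patch): the new import
`from .cache_config import ENABLE_CLIP_MODEL_CACHE` assumes a sibling
cache_config module that this diff does not show. A minimal module compatible
with that import might look like the following; the environment-variable name
and the "enabled by default" behavior are illustrative assumptions, not taken
from the patch:

    # cache_config.py -- hypothetical sketch, names are assumptions
    import os

    # Class-level caching of the CLIP model stays on unless the service
    # environment explicitly sets ENABLE_CLIP_MODEL_CACHE=false.
    ENABLE_CLIP_MODEL_CACHE = os.getenv("ENABLE_CLIP_MODEL_CACHE", "true").lower() != "false"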
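Usage sketch under the same caveats (the public method wrapping this pipeline
is assumed to be named process_image and the constructor is assumed to take no
arguments; neither signature appears in the hunks above):

    service = VisualProcessingService()
    result = service.process_image("https://example.com/product.jpg")
    print(result.get("cache_status"), result.get("visual_attributes", {}).get("product_type"))

    # With caching enabled the CLIP weights stay resident at class level;
    # the new classmethod releases them (and any CUDA memory) on demand.
    VisualProcessingService.clear_clip_cache()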