@@ -0,0 +1,1193 @@
+# # ==================== visual_processing_service.py ====================
|
|
|
+# import torch
|
|
|
+# import cv2
|
|
|
+# import numpy as np
|
|
|
+# import requests
|
|
|
+# from io import BytesIO
|
|
|
+# from PIL import Image
|
|
|
+# from typing import Dict, List, Optional, Tuple
|
|
|
+# import logging
|
|
|
+# from transformers import CLIPProcessor, CLIPModel, AutoImageProcessor, AutoModelForImageClassification
|
|
|
+# from sklearn.cluster import KMeans
|
|
|
+# import webcolors
|
|
|
+
|
|
|
+# logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+
|
|
|
+# class VisualProcessingService:
|
|
|
+# """Service for extracting visual attributes from product images using CLIP and computer vision."""
|
|
|
+
|
|
|
+# def __init__(self):
|
|
|
+# self.clip_model = None
|
|
|
+# self.clip_processor = None
|
|
|
+# self.classification_model = None
|
|
|
+# self.classification_processor = None
|
|
|
+
|
|
|
+# def _get_clip_model(self):
|
|
|
+# """Lazy load CLIP model."""
|
|
|
+# if self.clip_model is None:
|
|
|
+# logger.info("Loading CLIP model...")
|
|
|
+# self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
|
|
+# self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
|
|
+# self.clip_model.eval()
|
|
|
+# return self.clip_model, self.clip_processor
|
|
|
+
|
|
|
+# def _get_classification_model(self):
|
|
|
+# """Lazy load image classification model for product categories."""
|
|
|
+# if self.classification_model is None:
|
|
|
+# logger.info("Loading classification model...")
|
|
|
+# # Using Google's general-purpose ViT (google/vit-base-patch16-224, ImageNet classes), not a fashion-specific checkpoint
|
|
|
+# self.classification_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
|
|
|
+# self.classification_model = AutoModelForImageClassification.from_pretrained(
|
|
|
+# "google/vit-base-patch16-224"
|
|
|
+# )
|
|
|
+# self.classification_model.eval()
|
|
|
+# return self.classification_model, self.classification_processor
|
|
|
+
|
|
|
+# def download_image(self, image_url: str) -> Optional[Image.Image]:
|
|
|
+# """Download image from URL."""
|
|
|
+# try:
|
|
|
+# response = requests.get(image_url, timeout=10)
|
|
|
+# response.raise_for_status()
|
|
|
+# image = Image.open(BytesIO(response.content)).convert('RGB')
|
|
|
+# return image
|
|
|
+# except Exception as e:
|
|
|
+# logger.error(f"Error downloading image from {image_url}: {str(e)}")
|
|
|
+# return None
|
|
|
+
|
|
|
+# def extract_dominant_colors(self, image: Image.Image, n_colors: int = 3) -> List[Dict]:
|
|
|
+# """Extract dominant colors from image using K-means clustering."""
|
|
|
+# try:
|
|
|
+# # Resize image for faster processing
|
|
|
+# img_small = image.resize((150, 150))
|
|
|
+# img_array = np.array(img_small)
|
|
|
+
|
|
|
+# # Reshape to pixels
|
|
|
+# pixels = img_array.reshape(-1, 3)
|
|
|
+
|
|
|
+# # Apply K-means
|
|
|
+# kmeans = KMeans(n_clusters=n_colors, random_state=42, n_init=10)
|
|
|
+# kmeans.fit(pixels)
|
|
|
+
|
|
|
+# colors = []
|
|
|
+# for center in kmeans.cluster_centers_:
|
|
|
+# rgb = tuple(center.astype(int))
|
|
|
+# color_name = self._get_color_name(rgb)
|
|
|
+# colors.append({
|
|
|
+# "name": color_name,
|
|
|
+# "rgb": rgb,
|
|
|
+# "percentage": float(np.sum(kmeans.labels_ == len(colors)) / len(kmeans.labels_) * 100)
|
|
|
+# })
|
|
|
+
|
|
|
+# # Sort by percentage
|
|
|
+# colors.sort(key=lambda x: x['percentage'], reverse=True)
|
|
|
+# return colors
|
|
|
+
|
|
|
+# except Exception as e:
|
|
|
+# logger.error(f"Error extracting colors: {str(e)}")
|
|
|
+# return []
|
|
|
+
|
|
|
+# def _get_color_name(self, rgb: Tuple[int, int, int]) -> str:
|
|
|
+# """Convert RGB to closest color name."""
|
|
|
+# try:
|
|
|
+# # Try to get exact match
|
|
|
+# color_name = webcolors.rgb_to_name(rgb)
|
|
|
+# return color_name
|
|
|
+# except ValueError:
|
|
|
+# # Find closest color
|
|
|
+# min_distance = float('inf')
|
|
|
+# closest_name = 'unknown'
|
|
|
+
|
|
|
+# for name in webcolors.CSS3_NAMES_TO_HEX:
|
|
|
+# hex_color = webcolors.CSS3_NAMES_TO_HEX[name]
|
|
|
+# r, g, b = webcolors.hex_to_rgb(hex_color)
|
|
|
+# distance = sum((c1 - c2) ** 2 for c1, c2 in zip(rgb, (r, g, b)))
|
|
|
+
|
|
|
+# if distance < min_distance:
|
|
|
+# min_distance = distance
|
|
|
+# closest_name = name
|
|
|
+
|
|
|
+# return closest_name
|
|
|
+
|
|
|
+# def classify_with_clip(self, image: Image.Image, candidates: List[str], attribute_name: str) -> Dict:
|
|
|
+# """Use CLIP to classify image against candidate labels."""
|
|
|
+# try:
|
|
|
+# model, processor = self._get_clip_model()
|
|
|
+
|
|
|
+# # Prepare inputs
|
|
|
+# inputs = processor(
|
|
|
+# text=candidates,
|
|
|
+# images=image,
|
|
|
+# return_tensors="pt",
|
|
|
+# padding=True
|
|
|
+# )
|
|
|
+
|
|
|
+# # Get predictions
|
|
|
+# with torch.no_grad():
|
|
|
+# outputs = model(**inputs)
|
|
|
+# logits_per_image = outputs.logits_per_image
|
|
|
+# probs = logits_per_image.softmax(dim=1)
|
|
|
+
|
|
|
+# # Get top predictions
|
|
|
+# top_probs, top_indices = torch.topk(probs[0], k=min(3, len(candidates)))
|
|
|
+
|
|
|
+# results = []
|
|
|
+# for prob, idx in zip(top_probs, top_indices):
|
|
|
+# if prob.item() > 0.15: # Confidence threshold
|
|
|
+# results.append({
|
|
|
+# "value": candidates[idx.item()],
|
|
|
+# "confidence": float(prob.item())
|
|
|
+# })
|
|
|
+
|
|
|
+# return {
|
|
|
+# "attribute": attribute_name,
|
|
|
+# "predictions": results
|
|
|
+# }
|
|
|
+
|
|
|
+# except Exception as e:
|
|
|
+# logger.error(f"Error in CLIP classification: {str(e)}")
|
|
|
+# return {"attribute": attribute_name, "predictions": []}
|
|
|
+
|
|
|
+# def detect_patterns(self, image: Image.Image) -> Dict:
|
|
|
+# """Detect patterns in the image using edge detection and texture analysis."""
|
|
|
+# try:
|
|
|
+# # Convert to OpenCV format
|
|
|
+# img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
|
|
+# gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
|
|
|
+
|
|
|
+# # Calculate edge density
|
|
|
+# edges = cv2.Canny(gray, 50, 150)
|
|
|
+# edge_density = np.sum(edges > 0) / (edges.shape[0] * edges.shape[1])
|
|
|
+
|
|
|
+# # Calculate texture variance
|
|
|
+# laplacian = cv2.Laplacian(gray, cv2.CV_64F)
|
|
|
+# texture_variance = laplacian.var()
|
|
|
+
|
|
|
+# # Determine pattern type based on metrics
|
|
|
+# pattern_candidates = []
|
|
|
+
|
|
|
+# if edge_density > 0.15:
|
|
|
+# pattern_candidates.append("geometric")
|
|
|
+# if texture_variance > 500:
|
|
|
+# pattern_candidates.append("textured")
|
|
|
+# if edge_density < 0.05 and texture_variance < 200:
|
|
|
+# pattern_candidates.append("solid")
|
|
|
+
|
|
|
+# # Use CLIP for more detailed pattern detection
|
|
|
+# pattern_types = [
|
|
|
+# "solid color", "striped", "checkered", "polka dot", "floral",
|
|
|
+# "geometric", "abstract", "graphic print", "camouflage", "paisley"
|
|
|
+# ]
|
|
|
+
|
|
|
+# clip_result = self.classify_with_clip(image, pattern_types, "pattern")
|
|
|
+
|
|
|
+# return clip_result
|
|
|
+
|
|
|
+# except Exception as e:
|
|
|
+# logger.error(f"Error detecting patterns: {str(e)}")
|
|
|
+# return {"attribute": "pattern", "predictions": []}
|
|
|
+
|
|
|
+# def detect_material(self, image: Image.Image) -> Dict:
|
|
|
+# """Detect material type using CLIP."""
|
|
|
+# materials = [
|
|
|
+# "cotton", "polyester", "denim", "leather", "silk", "wool",
|
|
|
+# "linen", "satin", "velvet", "fleece", "knit", "jersey",
|
|
|
+# "canvas", "nylon", "suede", "corduroy"
|
|
|
+# ]
|
|
|
+
|
|
|
+# return self.classify_with_clip(image, materials, "material")
|
|
|
+
|
|
|
+# def detect_style(self, image: Image.Image) -> Dict:
|
|
|
+# """Detect style/occasion using CLIP."""
|
|
|
+# styles = [
|
|
|
+# "casual", "formal", "sporty", "business", "vintage", "modern",
|
|
|
+# "bohemian", "streetwear", "elegant", "preppy", "athletic",
|
|
|
+# "loungewear", "party", "workwear", "outdoor"
|
|
|
+# ]
|
|
|
+
|
|
|
+# return self.classify_with_clip(image, styles, "style")
|
|
|
+
|
|
|
+# def detect_fit(self, image: Image.Image) -> Dict:
|
|
|
+# """Detect clothing fit using CLIP."""
|
|
|
+# fits = [
|
|
|
+# "slim fit", "regular fit", "loose fit", "oversized",
|
|
|
+# "tight", "relaxed", "tailored", "athletic fit"
|
|
|
+# ]
|
|
|
+
|
|
|
+# return self.classify_with_clip(image, fits, "fit")
|
|
|
+
|
|
|
+# def detect_neckline(self, image: Image.Image, product_type: str) -> Dict:
|
|
|
+# """Detect neckline type for tops using CLIP."""
|
|
|
+# if product_type.lower() not in ['shirt', 't-shirt', 'top', 'blouse', 'dress', 'sweater']:
|
|
|
+# return {"attribute": "neckline", "predictions": []}
|
|
|
+
|
|
|
+# necklines = [
|
|
|
+# "crew neck", "v-neck", "round neck", "collar", "turtleneck",
|
|
|
+# "scoop neck", "boat neck", "off-shoulder", "square neck", "halter"
|
|
|
+# ]
|
|
|
+
|
|
|
+# return self.classify_with_clip(image, necklines, "neckline")
|
|
|
+
|
|
|
+# def detect_sleeve_type(self, image: Image.Image, product_type: str) -> Dict:
|
|
|
+# """Detect sleeve type using CLIP."""
|
|
|
+# if product_type.lower() not in ['shirt', 't-shirt', 'top', 'blouse', 'dress', 'sweater', 'jacket']:
|
|
|
+# return {"attribute": "sleeve_type", "predictions": []}
|
|
|
+
|
|
|
+# sleeves = [
|
|
|
+# "short sleeve", "long sleeve", "sleeveless", "three-quarter sleeve",
|
|
|
+# "cap sleeve", "flutter sleeve", "bell sleeve", "raglan sleeve"
|
|
|
+# ]
|
|
|
+
|
|
|
+# return self.classify_with_clip(image, sleeves, "sleeve_type")
|
|
|
+
|
|
|
+# def detect_product_type(self, image: Image.Image) -> Dict:
|
|
|
+# """Detect product type using CLIP."""
|
|
|
+# product_types = [
|
|
|
+# "t-shirt", "shirt", "dress", "pants", "jeans", "shorts",
|
|
|
+# "skirt", "jacket", "coat", "sweater", "hoodie", "blazer",
|
|
|
+# "suit", "jumpsuit", "romper", "cardigan", "vest", "top",
|
|
|
+# "blouse", "tank top", "polo shirt", "sweatshirt"
|
|
|
+# ]
|
|
|
+
|
|
|
+# return self.classify_with_clip(image, product_types, "product_type")
|
|
|
+
|
|
|
+# def detect_closure_type(self, image: Image.Image) -> Dict:
|
|
|
+# """Detect closure type using CLIP."""
|
|
|
+# closures = [
|
|
|
+# "button", "zipper", "snap", "hook and eye", "velcro",
|
|
|
+# "lace-up", "pull-on", "elastic", "tie", "buckle"
|
|
|
+# ]
|
|
|
+
|
|
|
+# return self.classify_with_clip(image, closures, "closure_type")
|
|
|
+
|
|
|
+# def detect_length(self, image: Image.Image, product_type: str) -> Dict:
|
|
|
+# """Detect garment length using CLIP."""
|
|
|
+# if product_type.lower() in ['pants', 'jeans', 'trousers']:
|
|
|
+# lengths = ["full length", "ankle length", "cropped", "capri", "shorts"]
|
|
|
+# elif product_type.lower() in ['skirt', 'dress']:
|
|
|
+# lengths = ["mini", "knee length", "midi", "maxi", "floor length"]
|
|
|
+# elif product_type.lower() in ['jacket', 'coat']:
|
|
|
+# lengths = ["waist length", "hip length", "thigh length", "knee length", "full length"]
|
|
|
+# else:
|
|
|
+# lengths = ["short", "regular", "long"]
|
|
|
+
|
|
|
+# return self.classify_with_clip(image, lengths, "length")
|
|
|
+
|
|
|
+# def process_image(self, image_url: str, product_type_hint: Optional[str] = None) -> Dict:
|
|
|
+# """
|
|
|
+# Main method to process image and extract all visual attributes.
|
|
|
+# """
|
|
|
+# try:
|
|
|
+# # Download image
|
|
|
+# image = self.download_image(image_url)
|
|
|
+# if image is None:
|
|
|
+# return {
|
|
|
+# "visual_attributes": {},
|
|
|
+# "error": "Failed to download image"
|
|
|
+# }
|
|
|
+
|
|
|
+# # Extract all attributes
|
|
|
+# visual_attributes = {}
|
|
|
+
|
|
|
+# # 1. Product Type Detection
|
|
|
+# logger.info("Detecting product type...")
|
|
|
+# product_type_result = self.detect_product_type(image)
|
|
|
+# if product_type_result["predictions"]:
|
|
|
+# visual_attributes["product_type"] = product_type_result["predictions"][0]["value"]
|
|
|
+# detected_product_type = visual_attributes["product_type"]
|
|
|
+# else:
|
|
|
+# detected_product_type = product_type_hint or "unknown"
|
|
|
+
|
|
|
+# # 2. Color Detection
|
|
|
+# logger.info("Extracting colors...")
|
|
|
+# colors = self.extract_dominant_colors(image, n_colors=3)
|
|
|
+# if colors:
|
|
|
+# visual_attributes["primary_color"] = colors[0]["name"]
|
|
|
+# visual_attributes["color_palette"] = [c["name"] for c in colors]
|
|
|
+# visual_attributes["color_details"] = colors
|
|
|
+
|
|
|
+# # 3. Pattern Detection
|
|
|
+# logger.info("Detecting patterns...")
|
|
|
+# pattern_result = self.detect_patterns(image)
|
|
|
+# if pattern_result["predictions"]:
|
|
|
+# visual_attributes["pattern"] = pattern_result["predictions"][0]["value"]
|
|
|
+
|
|
|
+# # 4. Material Detection
|
|
|
+# logger.info("Detecting material...")
|
|
|
+# material_result = self.detect_material(image)
|
|
|
+# if material_result["predictions"]:
|
|
|
+# visual_attributes["material"] = material_result["predictions"][0]["value"]
|
|
|
+
|
|
|
+# # 5. Style Detection
|
|
|
+# logger.info("Detecting style...")
|
|
|
+# style_result = self.detect_style(image)
|
|
|
+# if style_result["predictions"]:
|
|
|
+# visual_attributes["style"] = style_result["predictions"][0]["value"]
|
|
|
+
|
|
|
+# # 6. Fit Detection
|
|
|
+# logger.info("Detecting fit...")
|
|
|
+# fit_result = self.detect_fit(image)
|
|
|
+# if fit_result["predictions"]:
|
|
|
+# visual_attributes["fit"] = fit_result["predictions"][0]["value"]
|
|
|
+
|
|
|
+# # 7. Neckline Detection (if applicable)
|
|
|
+# logger.info("Detecting neckline...")
|
|
|
+# neckline_result = self.detect_neckline(image, detected_product_type)
|
|
|
+# if neckline_result["predictions"]:
|
|
|
+# visual_attributes["neckline"] = neckline_result["predictions"][0]["value"]
|
|
|
+
|
|
|
+# # 8. Sleeve Type Detection (if applicable)
|
|
|
+# logger.info("Detecting sleeve type...")
|
|
|
+# sleeve_result = self.detect_sleeve_type(image, detected_product_type)
|
|
|
+# if sleeve_result["predictions"]:
|
|
|
+# visual_attributes["sleeve_type"] = sleeve_result["predictions"][0]["value"]
|
|
|
+
|
|
|
+# # 9. Closure Type Detection
|
|
|
+# logger.info("Detecting closure type...")
|
|
|
+# closure_result = self.detect_closure_type(image)
|
|
|
+# if closure_result["predictions"]:
|
|
|
+# visual_attributes["closure_type"] = closure_result["predictions"][0]["value"]
|
|
|
+
|
|
|
+# # 10. Length Detection
|
|
|
+# logger.info("Detecting length...")
|
|
|
+# length_result = self.detect_length(image, detected_product_type)
|
|
|
+# if length_result["predictions"]:
|
|
|
+# visual_attributes["length"] = length_result["predictions"][0]["value"]
|
|
|
+
|
|
|
+# # Format response
|
|
|
+# return {
|
|
|
+# "visual_attributes": visual_attributes,
|
|
|
+# "detailed_predictions": {
|
|
|
+# "product_type": product_type_result,
|
|
|
+# "pattern": pattern_result,
|
|
|
+# "material": material_result,
|
|
|
+# "style": style_result,
|
|
|
+# "fit": fit_result,
|
|
|
+# "neckline": neckline_result,
|
|
|
+# "sleeve_type": sleeve_result,
|
|
|
+# "closure_type": closure_result,
|
|
|
+# "length": length_result
|
|
|
+# }
|
|
|
+# }
|
|
|
+
|
|
|
+# except Exception as e:
|
|
|
+# logger.error(f"Error processing image: {str(e)}")
|
|
|
+# return {
|
|
|
+# "visual_attributes": {},
|
|
|
+# "error": str(e)
|
|
|
+# }
+
+
+# # ==================== visual_processing_service_optimized.py ====================
|
|
|
+# """
|
|
|
+# Optimized version with:
|
|
|
+# - Result caching
|
|
|
+# - Batch processing support
|
|
|
+# - Memory management
|
|
|
+# - Error recovery
|
|
|
+# - Performance monitoring
|
|
|
+# """
|
|
|
+
|
|
|
+# import torch
|
|
|
+# import cv2
|
|
|
+# import numpy as np
|
|
|
+# import requests
|
|
|
+# from io import BytesIO
|
|
|
+# from PIL import Image
|
|
|
+# from typing import Dict, List, Optional, Tuple
|
|
|
+# import logging
|
|
|
+# import time
|
|
|
+# import hashlib
|
|
|
+# from functools import lru_cache
|
|
|
+# from transformers import CLIPProcessor, CLIPModel
|
|
|
+# from sklearn.cluster import KMeans
|
|
|
+# import webcolors
|
|
|
+
|
|
|
+# logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+
|
|
|
+# class VisualProcessingService:
|
|
|
+# """Optimized service for extracting visual attributes from product images."""
|
|
|
+
|
|
|
+# # Class-level model caching (shared across instances)
|
|
|
+# _clip_model = None
|
|
|
+# _clip_processor = None
|
|
|
+# _model_device = None
|
|
|
+
|
|
|
+# def __init__(self, use_cache: bool = True, cache_ttl: int = 3600):
|
|
|
+# self.use_cache = use_cache
|
|
|
+# self.cache_ttl = cache_ttl
|
|
|
+# self._cache = {}
|
|
|
+
|
|
|
+# @classmethod
|
|
|
+# def _get_device(cls):
|
|
|
+# """Get optimal device (GPU if available, else CPU)."""
|
|
|
+# if cls._model_device is None:
|
|
|
+# cls._model_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
+# logger.info(f"Using device: {cls._model_device}")
|
|
|
+# return cls._model_device
|
|
|
+
|
|
|
+# @classmethod
|
|
|
+# def _get_clip_model(cls):
|
|
|
+# """Lazy load CLIP model with class-level caching."""
|
|
|
+# if cls._clip_model is None:
|
|
|
+# logger.info("Loading CLIP model...")
|
|
|
+# start_time = time.time()
|
|
|
+
|
|
|
+# try:
|
|
|
+# cls._clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
|
|
+# cls._clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
|
|
+
|
|
|
+# device = cls._get_device()
|
|
|
+# cls._clip_model.to(device)
|
|
|
+# cls._clip_model.eval()
|
|
|
+
|
|
|
+# logger.info(f"CLIP model loaded in {time.time() - start_time:.2f}s")
|
|
|
+# except Exception as e:
|
|
|
+# logger.error(f"Failed to load CLIP model: {str(e)}")
|
|
|
+# raise
|
|
|
+
|
|
|
+# return cls._clip_model, cls._clip_processor
|
|
|
+
|
|
|
+# def _get_cache_key(self, image_url: str, operation: str) -> str:
|
|
|
+# """Generate cache key for results."""
|
|
|
+# url_hash = hashlib.md5(image_url.encode()).hexdigest()
|
|
|
+# return f"visual_{operation}_{url_hash}"
|
|
|
+
|
|
|
+# def _get_cached(self, key: str) -> Optional[Dict]:
|
|
|
+# """Get cached result if available and not expired."""
|
|
|
+# if not self.use_cache:
|
|
|
+# return None
|
|
|
+
|
|
|
+# if key in self._cache:
|
|
|
+# result, timestamp = self._cache[key]
|
|
|
+# if time.time() - timestamp < self.cache_ttl:
|
|
|
+# return result
|
|
|
+# else:
|
|
|
+# del self._cache[key]
|
|
|
+# return None
|
|
|
+
|
|
|
+# def _set_cached(self, key: str, value: Dict):
|
|
|
+# """Cache result with timestamp."""
|
|
|
+# if self.use_cache:
|
|
|
+# self._cache[key] = (value, time.time())
|
|
|
+
|
|
|
+# def download_image(self, image_url: str, max_size: Tuple[int, int] = (1024, 1024)) -> Optional[Image.Image]:
|
|
|
+# """Download and optionally resize image for faster processing."""
|
|
|
+# try:
|
|
|
+# response = requests.get(image_url, timeout=10)
|
|
|
+# response.raise_for_status()
|
|
|
+
|
|
|
+# image = Image.open(BytesIO(response.content)).convert('RGB')
|
|
|
+
|
|
|
+# # Resize if image is too large
|
|
|
+# if image.size[0] > max_size[0] or image.size[1] > max_size[1]:
|
|
|
+# image.thumbnail(max_size, Image.Resampling.LANCZOS)
|
|
|
+# logger.info(f"Resized image from original size to {image.size}")
|
|
|
+
|
|
|
+# return image
|
|
|
+
|
|
|
+# except Exception as e:
|
|
|
+# logger.error(f"Error downloading image from {image_url}: {str(e)}")
|
|
|
+# return None
|
|
|
+
|
|
|
+# @lru_cache(maxsize=100)
|
|
|
+# def _get_color_name_cached(self, rgb: Tuple[int, int, int]) -> str:
|
|
|
+# """Cached version of color name lookup."""
|
|
|
+# try:
|
|
|
+# return webcolors.rgb_to_name(rgb)
|
|
|
+# except ValueError:
|
|
|
+# min_distance = float('inf')
|
|
|
+# closest_name = 'unknown'
|
|
|
+
|
|
|
+# for name in webcolors.CSS3_NAMES_TO_HEX:
|
|
|
+# hex_color = webcolors.CSS3_NAMES_TO_HEX[name]
|
|
|
+# r, g, b = webcolors.hex_to_rgb(hex_color)
|
|
|
+# distance = sum((c1 - c2) ** 2 for c1, c2 in zip(rgb, (r, g, b)))
|
|
|
+
|
|
|
+# if distance < min_distance:
|
|
|
+# min_distance = distance
|
|
|
+# closest_name = name
|
|
|
+
|
|
|
+# return closest_name
|
|
|
+
|
|
|
+# def extract_dominant_colors(self, image: Image.Image, n_colors: int = 3) -> List[Dict]:
|
|
|
+# """Extract dominant colors with optimized K-means."""
|
|
|
+# try:
|
|
|
+# # Resize for faster processing
|
|
|
+# img_small = image.resize((150, 150))
|
|
|
+# img_array = np.array(img_small)
|
|
|
+# pixels = img_array.reshape(-1, 3)
|
|
|
+
|
|
|
+# # Sample pixels if too many
|
|
|
+# if len(pixels) > 10000:
|
|
|
+# indices = np.random.choice(len(pixels), 10000, replace=False)
|
|
|
+# pixels = pixels[indices]
|
|
|
+
|
|
|
+# # K-means with optimized parameters
|
|
|
+# kmeans = KMeans(
|
|
|
+# n_clusters=n_colors,
|
|
|
+# random_state=42,
|
|
|
+# n_init=5, # Reduced from 10 for speed
|
|
|
+# max_iter=100,
|
|
|
+# algorithm='elkan' # Faster for low dimensions
|
|
|
+# )
|
|
|
+# kmeans.fit(pixels)
|
|
|
+
|
|
|
+# colors = []
|
|
|
+# labels_counts = np.bincount(kmeans.labels_)
|
|
|
+
|
|
|
+# for i, center in enumerate(kmeans.cluster_centers_):
|
|
|
+# rgb = tuple(center.astype(int))
|
|
|
+# color_name = self._get_color_name_cached(rgb)
|
|
|
+# percentage = float(labels_counts[i] / len(kmeans.labels_) * 100)
|
|
|
+
|
|
|
+# colors.append({
|
|
|
+# "name": color_name,
|
|
|
+# "rgb": rgb,
|
|
|
+# "percentage": percentage
|
|
|
+# })
|
|
|
+
|
|
|
+# colors.sort(key=lambda x: x['percentage'], reverse=True)
|
|
|
+# return colors
|
|
|
+
|
|
|
+# except Exception as e:
|
|
|
+# logger.error(f"Error extracting colors: {str(e)}")
|
|
|
+# return []
|
|
|
+
|
|
|
+# def classify_with_clip(
|
|
|
+# self,
|
|
|
+# image: Image.Image,
|
|
|
+# candidates: List[str],
|
|
|
+# attribute_name: str,
|
|
|
+# confidence_threshold: float = 0.15
|
|
|
+# ) -> Dict:
|
|
|
+# """Optimized CLIP classification with batching."""
|
|
|
+# try:
|
|
|
+# model, processor = self._get_clip_model()
|
|
|
+# device = self._get_device()
|
|
|
+
|
|
|
+# # Prepare inputs
|
|
|
+# inputs = processor(
|
|
|
+# text=candidates,
|
|
|
+# images=image,
|
|
|
+# return_tensors="pt",
|
|
|
+# padding=True
|
|
|
+# )
|
|
|
+
|
|
|
+# # Move to device
|
|
|
+# inputs = {k: v.to(device) for k, v in inputs.items()}
|
|
|
+
|
|
|
+# # Get predictions with no_grad for speed
|
|
|
+# with torch.no_grad():
|
|
|
+# outputs = model(**inputs)
|
|
|
+# logits_per_image = outputs.logits_per_image
|
|
|
+# probs = logits_per_image.softmax(dim=1).cpu()
|
|
|
+
|
|
|
+# # Get top predictions
|
|
|
+# top_k = min(3, len(candidates))
|
|
|
+# top_probs, top_indices = torch.topk(probs[0], k=top_k)
|
|
|
+
|
|
|
+# results = []
|
|
|
+# for prob, idx in zip(top_probs, top_indices):
|
|
|
+# if prob.item() > confidence_threshold:
|
|
|
+# results.append({
|
|
|
+# "value": candidates[idx.item()],
|
|
|
+# "confidence": float(prob.item())
|
|
|
+# })
|
|
|
+
|
|
|
+# return {
|
|
|
+# "attribute": attribute_name,
|
|
|
+# "predictions": results
|
|
|
+# }
|
|
|
+
|
|
|
+# except Exception as e:
|
|
|
+# logger.error(f"Error in CLIP classification for {attribute_name}: {str(e)}")
|
|
|
+# return {"attribute": attribute_name, "predictions": []}
|
|
|
+
|
|
|
+# def batch_classify(
|
|
|
+# self,
|
|
|
+# image: Image.Image,
|
|
|
+# attribute_configs: List[Dict[str, any]]
|
|
|
+# ) -> Dict[str, Dict]:
|
|
|
+# """
|
|
|
+# Batch multiple CLIP classifications for efficiency.
|
|
|
+# attribute_configs: [{"name": "pattern", "candidates": [...], "threshold": 0.15}, ...]
|
|
|
+# """
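+#         # Hypothetical configs matching the format documented above (attribute
+#         # names, candidate labels and thresholds are illustrative only):
+#         #
+#         #   configs = [
+#         #       {"name": "pattern", "candidates": ["solid color", "striped"], "threshold": 0.2},
+#         #       {"name": "material", "candidates": ["cotton", "denim", "leather"]},
+#         #   ]
+#         #   results = self.batch_classify(image, configs)
+#         #   # -> {"pattern": {"attribute": "pattern", "predictions": [...]}, ...}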
|
|
|
+# results = {}
|
|
|
+
|
|
|
+# for config in attribute_configs:
|
|
|
+# attr_name = config["name"]
|
|
|
+# candidates = config["candidates"]
|
|
|
+# threshold = config.get("threshold", 0.15)
|
|
|
+
|
|
|
+# result = self.classify_with_clip(image, candidates, attr_name, threshold)
|
|
|
+# results[attr_name] = result
|
|
|
+
|
|
|
+# return results
|
|
|
+
|
|
|
+# def detect_patterns(self, image: Image.Image) -> Dict:
|
|
|
+# """Detect patterns using CLIP."""
|
|
|
+# pattern_types = [
|
|
|
+# "solid color", "striped", "checkered", "polka dot", "floral",
|
|
|
+# "geometric", "abstract", "graphic print", "camouflage", "paisley"
|
|
|
+# ]
|
|
|
+# return self.classify_with_clip(image, pattern_types, "pattern")
|
|
|
+
|
|
|
+# def process_image(
|
|
|
+# self,
|
|
|
+# image_url: str,
|
|
|
+# product_type_hint: Optional[str] = None,
|
|
|
+# attributes_to_extract: Optional[List[str]] = None
|
|
|
+# ) -> Dict:
|
|
|
+# """
|
|
|
+# Main method with caching and selective attribute extraction.
|
|
|
+
|
|
|
+# Args:
|
|
|
+# image_url: URL of the product image
|
|
|
+# product_type_hint: Optional hint about product type
|
|
|
+# attributes_to_extract: List of attributes to extract (None = all)
|
|
|
+# """
|
|
|
+# # Check cache
|
|
|
+# cache_key = self._get_cache_key(image_url, "full")
|
|
|
+# cached_result = self._get_cached(cache_key)
|
|
|
+# if cached_result:
|
|
|
+# logger.info(f"Returning cached result for {image_url}")
|
|
|
+# return cached_result
|
|
|
+
|
|
|
+# start_time = time.time()
|
|
|
+
|
|
|
+# try:
|
|
|
+# # Download image
|
|
|
+# image = self.download_image(image_url)
|
|
|
+# if image is None:
|
|
|
+# return {
|
|
|
+# "visual_attributes": {},
|
|
|
+# "error": "Failed to download image"
|
|
|
+# }
|
|
|
+
|
|
|
+# visual_attributes = {}
|
|
|
+# detailed_predictions = {}
|
|
|
+
|
|
|
+# # Default: extract all attributes
|
|
|
+# if attributes_to_extract is None:
|
|
|
+# attributes_to_extract = [
|
|
|
+# "product_type", "color", "pattern", "material",
|
|
|
+# "style", "fit", "neckline", "sleeve_type",
|
|
|
+# "closure_type", "length"
|
|
|
+# ]
|
|
|
+
|
|
|
+# # 1. Product Type Detection
|
|
|
+# if "product_type" in attributes_to_extract:
|
|
|
+# logger.info("Detecting product type...")
|
|
|
+# product_types = [
|
|
|
+# "t-shirt", "shirt", "dress", "pants", "jeans", "shorts",
|
|
|
+# "skirt", "jacket", "coat", "sweater", "hoodie", "blazer",
|
|
|
+# "top", "blouse"
|
|
|
+# ]
|
|
|
+# product_type_result = self.classify_with_clip(image, product_types, "product_type")
|
|
|
+# if product_type_result["predictions"]:
|
|
|
+# visual_attributes["product_type"] = product_type_result["predictions"][0]["value"]
|
|
|
+# detected_product_type = visual_attributes["product_type"]
|
|
|
+# else:
|
|
|
+# detected_product_type = product_type_hint or "unknown"
|
|
|
+# detailed_predictions["product_type"] = product_type_result
|
|
|
+# else:
|
|
|
+# detected_product_type = product_type_hint or "unknown"
|
|
|
+
|
|
|
+# # 2. Color Detection
|
|
|
+# if "color" in attributes_to_extract:
|
|
|
+# logger.info("Extracting colors...")
|
|
|
+# colors = self.extract_dominant_colors(image, n_colors=3)
|
|
|
+# if colors:
|
|
|
+# visual_attributes["primary_color"] = colors[0]["name"]
|
|
|
+# visual_attributes["color_palette"] = [c["name"] for c in colors]
|
|
|
+# visual_attributes["color_details"] = colors
|
|
|
+
|
|
|
+# # 3. Batch classify remaining attributes
|
|
|
+# batch_configs = []
|
|
|
+
|
|
|
+# if "pattern" in attributes_to_extract:
|
|
|
+# batch_configs.append({
|
|
|
+# "name": "pattern",
|
|
|
+# "candidates": [
|
|
|
+# "solid color", "striped", "checkered", "polka dot",
|
|
|
+# "floral", "geometric", "abstract", "graphic print"
|
|
|
+# ]
|
|
|
+# })
|
|
|
+
|
|
|
+# if "material" in attributes_to_extract:
|
|
|
+# batch_configs.append({
|
|
|
+# "name": "material",
|
|
|
+# "candidates": [
|
|
|
+# "cotton", "polyester", "denim", "leather", "silk",
|
|
|
+# "wool", "linen", "satin", "fleece", "knit"
|
|
|
+# ]
|
|
|
+# })
|
|
|
+
|
|
|
+# if "style" in attributes_to_extract:
|
|
|
+# batch_configs.append({
|
|
|
+# "name": "style",
|
|
|
+# "candidates": [
|
|
|
+# "casual", "formal", "sporty", "business", "vintage",
|
|
|
+# "modern", "streetwear", "elegant", "athletic"
|
|
|
+# ]
|
|
|
+# })
|
|
|
+
|
|
|
+# if "fit" in attributes_to_extract:
|
|
|
+# batch_configs.append({
|
|
|
+# "name": "fit",
|
|
|
+# "candidates": [
|
|
|
+# "slim fit", "regular fit", "loose fit", "oversized",
|
|
|
+# "relaxed", "tailored"
|
|
|
+# ]
|
|
|
+# })
|
|
|
+
|
|
|
+# # Product-type specific attributes
|
|
|
+# if detected_product_type.lower() in ['shirt', 't-shirt', 'top', 'blouse', 'dress', 'sweater']:
|
|
|
+# if "neckline" in attributes_to_extract:
|
|
|
+# batch_configs.append({
|
|
|
+# "name": "neckline",
|
|
|
+# "candidates": [
|
|
|
+# "crew neck", "v-neck", "round neck", "collar",
|
|
|
+# "turtleneck", "scoop neck", "boat neck"
|
|
|
+# ]
|
|
|
+# })
|
|
|
+
|
|
|
+# if "sleeve_type" in attributes_to_extract:
|
|
|
+# batch_configs.append({
|
|
|
+# "name": "sleeve_type",
|
|
|
+# "candidates": [
|
|
|
+# "short sleeve", "long sleeve", "sleeveless",
|
|
|
+# "three-quarter sleeve", "cap sleeve"
|
|
|
+# ]
|
|
|
+# })
|
|
|
+
|
|
|
+# if "closure_type" in attributes_to_extract:
|
|
|
+# batch_configs.append({
|
|
|
+# "name": "closure_type",
|
|
|
+# "candidates": [
|
|
|
+# "button", "zipper", "snap", "pull-on",
|
|
|
+# "lace-up", "elastic", "buckle"
|
|
|
+# ]
|
|
|
+# })
|
|
|
+
|
|
|
+# if "length" in attributes_to_extract:
|
|
|
+# if detected_product_type.lower() in ['pants', 'jeans', 'trousers']:
|
|
|
+# batch_configs.append({
|
|
|
+# "name": "length",
|
|
|
+# "candidates": ["full length", "ankle length", "cropped", "capri", "shorts"]
|
|
|
+# })
|
|
|
+# elif detected_product_type.lower() in ['skirt', 'dress']:
|
|
|
+# batch_configs.append({
|
|
|
+# "name": "length",
|
|
|
+# "candidates": ["mini", "knee length", "midi", "maxi", "floor length"]
|
|
|
+# })
|
|
|
+
|
|
|
+# # Execute batch classification
|
|
|
+# logger.info(f"Batch classifying {len(batch_configs)} attributes...")
|
|
|
+# batch_results = self.batch_classify(image, batch_configs)
|
|
|
+
|
|
|
+# # Process batch results
|
|
|
+# for attr_name, result in batch_results.items():
|
|
|
+# detailed_predictions[attr_name] = result
|
|
|
+# if result["predictions"]:
|
|
|
+# visual_attributes[attr_name] = result["predictions"][0]["value"]
|
|
|
+
|
|
|
+# # Format response
|
|
|
+# result = {
|
|
|
+# "visual_attributes": visual_attributes,
|
|
|
+# "detailed_predictions": detailed_predictions,
|
|
|
+# "processing_time": round(time.time() - start_time, 2)
|
|
|
+# }
|
|
|
+
|
|
|
+# # Cache result
|
|
|
+# self._set_cached(cache_key, result)
|
|
|
+
|
|
|
+# logger.info(f"Visual processing completed in {result['processing_time']}s")
|
|
|
+# return result
|
|
|
+
|
|
|
+# except Exception as e:
|
|
|
+# logger.error(f"Error processing image: {str(e)}")
|
|
|
+# return {
|
|
|
+# "visual_attributes": {},
|
|
|
+# "error": str(e),
|
|
|
+# "processing_time": round(time.time() - start_time, 2)
|
|
|
+# }
|
|
|
+
|
|
|
+# def clear_cache(self):
|
|
|
+# """Clear all cached results."""
|
|
|
+# self._cache.clear()
|
|
|
+# logger.info("Cache cleared")
|
|
|
+
|
|
|
+# def get_cache_stats(self) -> Dict:
|
|
|
+# """Get cache statistics."""
|
|
|
+# return {
|
|
|
+# "cache_size": len(self._cache),
|
|
|
+# "cache_enabled": self.use_cache,
|
|
|
+# "cache_ttl": self.cache_ttl
|
|
|
+# }
|
|
|
+
|
|
|
+# @classmethod
|
|
|
+# def cleanup_models(cls):
|
|
|
+# """Free up memory by unloading models."""
|
|
|
+# if cls._clip_model is not None:
|
|
|
+# del cls._clip_model
|
|
|
+# del cls._clip_processor
|
|
|
+# cls._clip_model = None
|
|
|
+# cls._clip_processor = None
|
|
|
+
|
|
|
+# if torch.cuda.is_available():
|
|
|
+# torch.cuda.empty_cache()
|
|
|
+
|
|
|
+# logger.info("Models unloaded and memory freed")
|
|
|
+
|
|
|
+
|
|
|
+# # ==================== Usage Example ====================
|
|
|
+
|
|
|
+# def example_usage():
|
|
|
+# """Example of how to use the optimized service."""
|
|
|
+
|
|
|
+# # Initialize service with caching
|
|
|
+# service = VisualProcessingService(use_cache=True, cache_ttl=3600)
|
|
|
+
|
|
|
+# # Process single image with all attributes
|
|
|
+# result1 = service.process_image("https://example.com/product1.jpg")
|
|
|
+# print("All attributes:", result1["visual_attributes"])
|
|
|
+
|
|
|
+# # Process with selective attributes (faster)
|
|
|
+# result2 = service.process_image(
|
|
|
+# "https://example.com/product2.jpg",
|
|
|
+# product_type_hint="t-shirt",
|
|
|
+# attributes_to_extract=["color", "pattern", "style"]
|
|
|
+# )
|
|
|
+# print("Selected attributes:", result2["visual_attributes"])
|
|
|
+
|
|
|
+# # Check cache stats
|
|
|
+# print("Cache stats:", service.get_cache_stats())
|
|
|
+
|
|
|
+# # Clear cache when needed
|
|
|
+# service.clear_cache()
|
|
|
+
|
|
|
+# # Cleanup models (call when shutting down)
|
|
|
+# VisualProcessingService.cleanup_models()
|
|
|
+
|
|
|
+
|
|
|
+# if __name__ == "__main__":
|
|
|
+# example_usage()
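+# # Cache-behaviour sketch (hypothetical URL): with use_cache=True, a second
+# # process_image() call for the same URL inside the TTL window is answered from the
+# # in-memory cache instead of re-running CLIP:
+# #
+# #     svc = VisualProcessingService(use_cache=True, cache_ttl=600)
+# #     first = svc.process_image("https://example.com/p.jpg")   # computed
+# #     second = svc.process_image("https://example.com/p.jpg")  # cache hit
+# #     print(svc.get_cache_stats())  # {'cache_size': 1, 'cache_enabled': True, 'cache_ttl': 600}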
+
+
+# ==================== visual_processing_service.py (FIXED) ====================
+import torch
+import cv2
+import numpy as np
+import requests
+from io import BytesIO
+from PIL import Image
+from typing import Dict, List, Optional, Tuple
+import logging
+import time
+from transformers import CLIPProcessor, CLIPModel
+from sklearn.cluster import KMeans
+
+logger = logging.getLogger(__name__)
+
+
+class VisualProcessingService:
+    """Service for extracting visual attributes from product images using CLIP."""
+
+    # Class-level caching (shared across instances)
+    _clip_model = None
+    _clip_processor = None
+    _device = None
+
+    def __init__(self):
+        pass
+
+    @classmethod
+    def _get_device(cls):
+        """Get optimal device."""
+        if cls._device is None:
+            cls._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            logger.info(f"Visual Processing using device: {cls._device}")
+        return cls._device
+
+    @classmethod
+    def _get_clip_model(cls):
+        """Lazy load CLIP model with class-level caching."""
+        if cls._clip_model is None:
+            logger.info("Loading CLIP model (this may take a few minutes on first use)...")
+            cls._clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+            cls._clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+            device = cls._get_device()
+            cls._clip_model.to(device)
+            cls._clip_model.eval()
+
+            logger.info("✓ CLIP model loaded successfully")
+        return cls._clip_model, cls._clip_processor
+
+    def download_image(self, image_url: str) -> Optional[Image.Image]:
+        """Download image from URL."""
+        try:
+            response = requests.get(image_url, timeout=10)
+            response.raise_for_status()
+            image = Image.open(BytesIO(response.content)).convert('RGB')
+            return image
+        except Exception as e:
+            logger.error(f"Error downloading image from {image_url}: {str(e)}")
+            return None
+
+    def extract_dominant_colors(self, image: Image.Image, n_colors: int = 3) -> List[Dict]:
+        """Extract dominant colors using K-means (FIXED webcolors issue)."""
+        try:
+            # Resize for faster processing
+            img_small = image.resize((150, 150))
+            img_array = np.array(img_small)
+            pixels = img_array.reshape(-1, 3)
+
+            # K-means clustering
+            kmeans = KMeans(n_clusters=n_colors, random_state=42, n_init=5)
+            kmeans.fit(pixels)
+
+            colors = []
+            labels_counts = np.bincount(kmeans.labels_)
+
+            for i, center in enumerate(kmeans.cluster_centers_):
+                rgb = tuple(int(c) for c in center)  # plain Python ints (JSON-friendly)
+                color_name = self._get_color_name_simple(rgb)
+                percentage = float(labels_counts[i] / len(kmeans.labels_) * 100)
+
+                colors.append({
+                    "name": color_name,
+                    "rgb": rgb,
+                    "percentage": percentage
+                })
+
+            colors.sort(key=lambda x: x['percentage'], reverse=True)
+            return colors
+
+        except Exception as e:
+            logger.error(f"Error extracting colors: {str(e)}")
+            return []
+
+    def _get_color_name_simple(self, rgb: Tuple[int, int, int]) -> str:
+        """
+        Simple color name detection without webcolors dependency.
+        Maps RGB to basic color names.
+        """
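+        # Worked examples of this heuristic (values are illustrative, not taken from
+        # real products): (30, 30, 30) -> 'black', (220, 60, 50) -> 'red',
+        # (120, 120, 120) -> 'gray'; anything unmatched falls back to the dominant channel.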
+        r, g, b = rgb
+
+        # Define basic color ranges
+        colors = {
+            'black': (r < 50 and g < 50 and b < 50),
+            'white': (r > 200 and g > 200 and b > 200),
+            'gray': (abs(r - g) < 30 and abs(g - b) < 30 and abs(r - b) < 30 and 50 <= r <= 200),
+            'red': (r > 150 and g < 100 and b < 100),
+            'green': (g > 150 and r < 100 and b < 100),
+            'blue': (b > 150 and r < 100 and g < 100),
+            'yellow': (r > 200 and g > 200 and b < 100),
+            'orange': (r > 200 and 100 < g < 200 and b < 100),
+            'purple': (r > 100 and b > 100 and g < 100),
+            'pink': (r > 200 and 100 < g < 200 and 100 < b < 200),
+            'brown': (50 < r < 150 and 30 < g < 100 and b < 80),
+            'cyan': (r < 100 and g > 150 and b > 150),
+        }
+
+        for color_name, condition in colors.items():
+            if condition:
+                return color_name
+
+        # Default fallback
+        if r > g and r > b:
+            return 'red'
+        elif g > r and g > b:
+            return 'green'
+        elif b > r and b > g:
+            return 'blue'
+        else:
+            return 'gray'
+
+    def classify_with_clip(
+        self,
+        image: Image.Image,
+        candidates: List[str],
+        attribute_name: str
+    ) -> Dict:
+        """Use CLIP to classify image against candidate labels."""
+        try:
+            model, processor = self._get_clip_model()
+            device = self._get_device()
+
+            # Prepare inputs
+            inputs = processor(
+                text=candidates,
+                images=image,
+                return_tensors="pt",
+                padding=True
+            )
+
+            # Move to device
+            inputs = {k: v.to(device) for k, v in inputs.items()}
+
+            # Get predictions
+            with torch.no_grad():
+                outputs = model(**inputs)
+                logits_per_image = outputs.logits_per_image
+                probs = logits_per_image.softmax(dim=1).cpu()
+
+            # Get top predictions
+            top_k = min(3, len(candidates))
+            top_probs, top_indices = torch.topk(probs[0], k=top_k)
+
+            results = []
+            for prob, idx in zip(top_probs, top_indices):
+                if prob.item() > 0.15: # Confidence threshold
+                    results.append({
+                        "value": candidates[idx.item()],
+                        "confidence": float(prob.item())
+                    })
+
+            return {
+                "attribute": attribute_name,
+                "predictions": results
+            }
+
+        except Exception as e:
+            logger.error(f"Error in CLIP classification for {attribute_name}: {str(e)}")
+            return {"attribute": attribute_name, "predictions": []}
+
+    def process_image(
+        self,
+        image_url: str,
+        product_type_hint: Optional[str] = None
+    ) -> Dict:
+        """
+        Main method to process image and extract visual attributes.
+        """
+        start_time = time.time()
+
+        try:
+            # Download image
+            image = self.download_image(image_url)
+            if image is None:
+                return {
+                    "visual_attributes": {},
+                    "error": "Failed to download image"
+                }
+
+            visual_attributes = {}
+            detailed_predictions = {}
+
+            # 1. Product Type Detection
+            product_types = [
+                "t-shirt", "shirt", "dress", "pants", "jeans", "shorts",
+                "skirt", "jacket", "coat", "sweater", "hoodie", "top"
+            ]
+            product_type_result = self.classify_with_clip(image, product_types, "product_type")
+            if product_type_result["predictions"]:
+                visual_attributes["product_type"] = product_type_result["predictions"][0]["value"]
+                detected_product_type = visual_attributes["product_type"]
+            else:
+                detected_product_type = product_type_hint or "unknown"
+            detailed_predictions["product_type"] = product_type_result
+
+            # 2. Color Detection
+            colors = self.extract_dominant_colors(image, n_colors=3)
+            if colors:
+                visual_attributes["primary_color"] = colors[0]["name"]
+                visual_attributes["color_palette"] = [c["name"] for c in colors]
+
+            # 3. Pattern Detection
+            patterns = ["solid color", "striped", "checkered", "graphic print", "floral", "geometric"]
+            pattern_result = self.classify_with_clip(image, patterns, "pattern")
+            if pattern_result["predictions"]:
+                visual_attributes["pattern"] = pattern_result["predictions"][0]["value"]
+            detailed_predictions["pattern"] = pattern_result
+
+            # 4. Material Detection
+            materials = ["cotton", "polyester", "denim", "leather", "silk", "wool", "linen"]
+            material_result = self.classify_with_clip(image, materials, "material")
+            if material_result["predictions"]:
+                visual_attributes["material"] = material_result["predictions"][0]["value"]
+            detailed_predictions["material"] = material_result
+
+            # 5. Style Detection
+            styles = ["casual", "formal", "sporty", "streetwear", "elegant", "vintage"]
+            style_result = self.classify_with_clip(image, styles, "style")
+            if style_result["predictions"]:
+                visual_attributes["style"] = style_result["predictions"][0]["value"]
+            detailed_predictions["style"] = style_result
+
+            # 6. Fit Detection
+            fits = ["slim fit", "regular fit", "loose fit", "oversized"]
+            fit_result = self.classify_with_clip(image, fits, "fit")
+            if fit_result["predictions"]:
+                visual_attributes["fit"] = fit_result["predictions"][0]["value"]
+            detailed_predictions["fit"] = fit_result
+
+            # 7. Neckline (for tops)
+            if detected_product_type.lower() in ['shirt', 't-shirt', 'top', 'dress']:
+                necklines = ["crew neck", "v-neck", "round neck", "collar"]
+                neckline_result = self.classify_with_clip(image, necklines, "neckline")
+                if neckline_result["predictions"]:
+                    visual_attributes["neckline"] = neckline_result["predictions"][0]["value"]
+                detailed_predictions["neckline"] = neckline_result
+
+            # 8. Sleeve Type (for tops)
+            if detected_product_type.lower() in ['shirt', 't-shirt', 'top']:
+                sleeves = ["short sleeve", "long sleeve", "sleeveless"]
+                sleeve_result = self.classify_with_clip(image, sleeves, "sleeve_type")
+                if sleeve_result["predictions"]:
+                    visual_attributes["sleeve_type"] = sleeve_result["predictions"][0]["value"]
+                detailed_predictions["sleeve_type"] = sleeve_result
+
+            # 9. Closure Type
+            closures = ["button", "zipper", "pull-on"]
+            closure_result = self.classify_with_clip(image, closures, "closure_type")
+            if closure_result["predictions"]:
+                visual_attributes["closure_type"] = closure_result["predictions"][0]["value"]
+            detailed_predictions["closure_type"] = closure_result
+
+            processing_time = time.time() - start_time
+
+            return {
+                "visual_attributes": visual_attributes,
+                "detailed_predictions": detailed_predictions,
+                "processing_time": round(processing_time, 2)
+            }
+
+        except Exception as e:
+            logger.error(f"Error processing image: {str(e)}")
+            return {
+                "visual_attributes": {},
+                "error": str(e),
+                "processing_time": round(time.time() - start_time, 2)
+            }
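+
+
+# ==================== Usage Example ====================
+# Minimal usage sketch for the fixed service above. The URL is a placeholder, not a
+# real product image, and the printed keys simply mirror the dict returned by
+# process_image(); adjust logging configuration to taste.
+
+def example_usage():
+    """Run the visual pipeline once against a placeholder image URL."""
+    logging.basicConfig(level=logging.INFO)
+
+    service = VisualProcessingService()
+
+    # Placeholder URL -- replace with a real, reachable product image
+    result = service.process_image(
+        "https://example.com/product.jpg",
+        product_type_hint="t-shirt"
+    )
+
+    if "error" in result:
+        print("Processing failed:", result["error"])
+    else:
+        print("Visual attributes:", result["visual_attributes"])
+        print("Processing time (s):", result["processing_time"])
+
+
+if __name__ == "__main__":
+    example_usage()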