@@ -0,0 +1,662 @@
+"""
+Enhanced Multi-API-Key Load Balancer with Smart Rate Limiting
+Optimized for Groq's free tier (30 RPM per key, 14K daily limit)
+"""
+
+import time
+import threading
+import requests
+import logging
+from typing import List, Dict, Optional
+from dataclasses import dataclass, field
+from datetime import datetime
+from django.conf import settings
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class APIKeyState:
+    """Tracks state and health of a single API key"""
+    key: str
+    name: str
+    requests_made: int = 0
+    last_request_time: float = 0
+    is_available: bool = True
+    rate_limit_reset_time: Optional[float] = None
+    consecutive_failures: int = 0
+    total_requests: int = 0
+    total_failures: int = 0
+    request_times: list = field(default_factory=list)
+    requests_per_minute: int = 25  # Conservative: 25 instead of 30
+    min_request_interval: float = 2.5  # Minimum 2.5s between requests per key
+
+    def can_make_request(self) -> bool:
+        """Check if key can make a request (rate limit + spacing)"""
+        now = time.time()
+
+        # Check minimum interval between requests
+        if self.last_request_time and (now - self.last_request_time) < self.min_request_interval:
+            return False
+
+        # Remove requests older than 1 minute
+        self.request_times = [t for t in self.request_times if now - t < 60]
+        return len(self.request_times) < self.requests_per_minute
+
+    def mark_success(self):
+        now = time.time()
+        self.requests_made += 1
+        self.total_requests += 1
+        self.last_request_time = now
+        self.request_times.append(now)
+        self.consecutive_failures = 0
+        self.is_available = True
+        self.rate_limit_reset_time = None
+
+        # Keep only last 60 seconds
+        self.request_times = [t for t in self.request_times if now - t < 60]
+
+    def mark_failure(self, is_rate_limit: bool = False, retry_after: Optional[int] = None):
+        self.consecutive_failures += 1
+        self.total_failures += 1
+
+        if is_rate_limit:
+            self.is_available = False
+            reset_time = time.time() + (retry_after or 65)  # 65s default
+            self.rate_limit_reset_time = reset_time
+            logger.warning(f"🚫 {self.name} rate limited until {datetime.fromtimestamp(reset_time).strftime('%H:%M:%S')}")
+
+        # Disable after 5 consecutive failures (increased from 3)
+        if self.consecutive_failures >= 5:
+            self.is_available = False
+            self.rate_limit_reset_time = time.time() + 120  # 2 min cooldown
+            logger.error(f"❌ {self.name} disabled (cooldown 2min)")
+
+    def check_availability(self) -> bool:
+        """Check if key is available"""
+        # Check rate limit reset
+        if self.rate_limit_reset_time and time.time() >= self.rate_limit_reset_time:
+            self.is_available = True
+            self.rate_limit_reset_time = None
+            self.consecutive_failures = 0
+            logger.info(f"✅ {self.name} recovered")
+            return True
+
+        if not self.is_available:
+            return False
+
+        return self.can_make_request()
+
+    def get_stats(self) -> Dict:
+        success_count = self.total_requests - self.total_failures
+        success_rate = (success_count / max(self.total_requests, 1)) * 100
+
+        return {
+            "name": self.name,
+            "total_requests": self.total_requests,
+            "total_failures": self.total_failures,
+            "success_rate": round(success_rate, 2),
+            "is_available": self.check_availability(),
+            "consecutive_failures": self.consecutive_failures,
+            "current_rpm": len(self.request_times),
+            "max_rpm": self.requests_per_minute,
+            "time_since_last_request": round(time.time() - self.last_request_time, 1) if self.last_request_time else None
+        }
+
+
+class MultiKeyLLMLoadBalancer:
+    """Enhanced load balancer with smart rate limiting"""
+
+    def __init__(self, api_keys: List[Dict[str, str]], strategy: str = "round_robin"):
+        if not api_keys:
+            raise ValueError("At least one API key required")
+
+        self.keys = [APIKeyState(key=k['key'], name=k['name']) for k in api_keys]
+        self.strategy = strategy
+        self.current_index = 0
+        self.lock = threading.Lock()
+        self.total_requests = 0
+        self.total_failures = 0
+        self.global_last_request = 0
+        self.min_global_interval = 0.5  # 500ms between ANY requests
+
+        logger.info(f"🔑 Load balancer initialized: {len(self.keys)} keys, '{strategy}' strategy")
+
+    def get_next_key(self) -> Optional[APIKeyState]:
+        """Get next available key with global rate limiting"""
+        with self.lock:
+            # Enforce minimum global interval
+            now = time.time()
+            time_since_last = now - self.global_last_request
+            if time_since_last < self.min_global_interval:
+                wait_time = self.min_global_interval - time_since_last
+                time.sleep(wait_time)
+
+            if self.strategy == "least_loaded":
+                return self._least_loaded_select()
+            else:
+                return self._round_robin_select()
+
+    def _round_robin_select(self) -> Optional[APIKeyState]:
+        """Round-robin with availability check"""
+        attempts = 0
+        total_keys = len(self.keys)
+
+        while attempts < total_keys:
+            key = self.keys[self.current_index]
+            self.current_index = (self.current_index + 1) % total_keys
+
+            if key.check_availability():
+                return key
+
+            attempts += 1
+
+        return self._wait_for_available_key()
+
+    def _least_loaded_select(self) -> Optional[APIKeyState]:
+        """Select least loaded key"""
+        available = [k for k in self.keys if k.check_availability()]
+
+        if not available:
+            return self._wait_for_available_key()
+
+        available.sort(key=lambda k: (len(k.request_times), k.last_request_time))
+        return available[0]
+
+    def _wait_for_available_key(self, max_wait: float = 5.0) -> Optional[APIKeyState]:
+        """Wait for next available key (with timeout)"""
+        keys_with_reset = [k for k in self.keys if k.rate_limit_reset_time]
+
+        if not keys_with_reset:
+            # Check if any key just needs spacing
+            now = time.time()
+            for key in self.keys:
+                if key.is_available:
+                    wait = key.min_request_interval - (now - key.last_request_time)
+                    if 0 < wait < max_wait:
+                        logger.info(f"⏳ Waiting {wait:.1f}s for {key.name}...")
+                        time.sleep(wait + 0.1)
+                        return key if key.check_availability() else None
+            return None
+
+        keys_with_reset.sort(key=lambda k: k.rate_limit_reset_time)
+        next_key = keys_with_reset[0]
+        wait = max(0, next_key.rate_limit_reset_time - time.time())
+
+        if 0 < wait < max_wait:
+            logger.info(f"⏳ Waiting {wait:.1f}s for {next_key.name}...")
+            time.sleep(wait + 0.5)
+            return next_key if next_key.check_availability() else None
+
+        return None
+
+    def mark_success(self, key: APIKeyState):
+        with self.lock:
+            key.mark_success()
+            self.total_requests += 1
+            self.global_last_request = time.time()
+
+    def mark_failure(self, key: APIKeyState, is_rate_limit: bool = False, retry_after: Optional[int] = None):
+        with self.lock:
+            key.mark_failure(is_rate_limit, retry_after)
+            self.total_failures += 1
+
+    def get_stats(self) -> Dict:
+        with self.lock:
+            available_count = sum(1 for k in self.keys if k.check_availability())
+            success_rate = ((self.total_requests - self.total_failures) / max(self.total_requests, 1)) * 100
+
+            return {
+                "total_keys": len(self.keys),
+                "available_keys": available_count,
+                "strategy": self.strategy,
+                "total_requests": self.total_requests,
+                "total_failures": self.total_failures,
+                "success_rate": round(success_rate, 2),
+                "keys": [k.get_stats() for k in self.keys]
+            }
+
+    def call_llm(self, payload: dict, api_url: str, max_retries: Optional[int] = None) -> str:
+        """Make LLM call with smart retry and failover"""
+        if max_retries is None:
+            max_retries = len(self.keys) * 3
+
+        attempt = 0
+        last_error = None
+        keys_tried = set()
+
+        while attempt < max_retries:
+            key_state = self.get_next_key()
+
+            if not key_state:
+                if len(keys_tried) >= len(self.keys):
+                    # All keys tried, wait longer
+                    logger.warning("⏳ All keys exhausted. Waiting 3s...")
+                    time.sleep(3)
+                    keys_tried.clear()
+
+                attempt += 1
+                continue
+
+            keys_tried.add(key_state.name)
+
+            try:
+                headers = {
+                    "Authorization": f"Bearer {key_state.key}",
+                    "Content-Type": "application/json"
+                }
+
+                logger.debug(f"🔑 {key_state.name} (attempt {attempt + 1}/{max_retries})")
+
+                response = requests.post(
+                    api_url,
+                    headers=headers,
+                    json=payload,
+                    timeout=30
+                )
+
+                if response.status_code == 429:
+                    retry_after = int(response.headers.get('Retry-After', 65))
+                    self.mark_failure(key_state, is_rate_limit=True, retry_after=retry_after)
+                    attempt += 1
+                    time.sleep(1)  # Brief pause before next key
+                    continue
+
+                response.raise_for_status()
+
+                # Success
+                self.mark_success(key_state)
+                content = response.json()["choices"][0]["message"]["content"]
+                logger.debug(f"✅ Success via {key_state.name}")
+                return content
+
+            except requests.exceptions.HTTPError as e:
+                # Note: use "is not None" here; a 4xx/5xx Response is falsy, so a bare
+                # "if e.response" would skip the rate-limit branch exactly when it matters.
+                if e.response is not None and e.response.status_code == 429:
+                    retry_after = int(e.response.headers.get('Retry-After', 65))
+                    self.mark_failure(key_state, is_rate_limit=True, retry_after=retry_after)
+                else:
+                    self.mark_failure(key_state)
+                    logger.error(f"❌ HTTP error {key_state.name}: {e}")
+                last_error = e
+                attempt += 1
+                time.sleep(0.5)
+
+            except Exception as e:
+                self.mark_failure(key_state)
+                logger.error(f"❌ Error {key_state.name}: {e}")
+                last_error = e
+                attempt += 1
+                time.sleep(0.5)
+
+        stats = self.get_stats()
+        error_msg = (
+            f"LLM failed after {max_retries} attempts. "
+            f"Available: {stats['available_keys']}/{stats['total_keys']}. "
+            f"Error: {last_error}"
+        )
+        logger.error(f"💥 {error_msg}")
+        raise RuntimeError(error_msg)
+
+
+# Global instance
+_load_balancer: Optional[MultiKeyLLMLoadBalancer] = None
+_load_balancer_lock = threading.Lock()
+
+
+def get_llm_load_balancer() -> MultiKeyLLMLoadBalancer:
+    """Get singleton load balancer"""
+    global _load_balancer
+
+    if _load_balancer is None:
+        with _load_balancer_lock:
+            if _load_balancer is None:
+                api_keys = getattr(settings, 'GROQ_API_KEYS', None)
+
+                if not api_keys:
+                    single_key = getattr(settings, 'GROQ_API_KEY', None)
+                    if single_key:
+                        api_keys = [{'key': single_key, 'name': 'groq_key_1'}]
+
+                if not api_keys:
+                    raise ValueError("No GROQ API keys configured")
+
+                strategy = getattr(settings, 'LLM_LOAD_BALANCER_STRATEGY', 'round_robin')
+                _load_balancer = MultiKeyLLMLoadBalancer(api_keys, strategy=strategy)
+
+    return _load_balancer
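+
+
+# Expected Django settings, for reference. This is an illustrative sketch based on the
+# lookups above: the key values are placeholders, and the URL assumes Groq's
+# OpenAI-compatible chat-completions endpoint.
+#
+# GROQ_API_KEYS = [
+#     {"key": "gsk_...", "name": "groq_key_1"},
+#     {"key": "gsk_...", "name": "groq_key_2"},
+# ]
+# GROQ_API_URL = "https://api.groq.com/openai/v1/chat/completions"
+# LLM_LOAD_BALANCER_STRATEGY = "round_robin"  # or "least_loaded"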
+
+
+def reset_load_balancer():
+    """Reset load balancer"""
+    global _load_balancer
+    with _load_balancer_lock:
+        _load_balancer = None
+
+
+def call_llm_with_load_balancer(payload: dict) -> str:
+    """Drop-in replacement for _call_llm"""
+    balancer = get_llm_load_balancer()
+    api_url = getattr(settings, 'GROQ_API_URL')
+    return balancer.call_llm(payload, api_url)
+
+
+def get_load_balancer_stats() -> Dict:
+    """Get stats"""
+    try:
+        return get_llm_load_balancer().get_stats()
+    except Exception as e:
+        return {"error": str(e), "total_keys": 0, "available_keys": 0}
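+
+
+# Example call, for reference. This is an illustrative sketch: the payload mirrors the
+# OpenAI-compatible chat-completions shape that call_llm() parses, and the model id is
+# a placeholder that should match whatever Groq model the project actually uses.
+#
+# reply = call_llm_with_load_balancer({
+#     "model": "llama-3.1-8b-instant",  # placeholder model id
+#     "messages": [{"role": "user", "content": "Say hello in one sentence."}],
+#     "temperature": 0.2,
+# })
+# print(reply)
+# print(get_load_balancer_stats())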
+
+
+# """
+# Ultra-Safe Sequential Load Balancer with Adaptive Rate Limiting
+# Guaranteed to work with strict API rate limits
+# """
+
+# import time
+# import threading
+# import requests
+# import logging
+# from typing import List, Dict, Optional
+# from dataclasses import dataclass
+# from datetime import datetime
+# from django.conf import settings
+
+# logger = logging.getLogger(__name__)
+
+
+# @dataclass
+# class APIKeyState:
+#     """Simple key state tracker"""
+#     key: str
+#     name: str
+#     last_used: float = 0
+#     total_requests: int = 0
+#     total_failures: int = 0
+#     consecutive_failures: int = 0
+#     disabled_until: float = 0
+
+#     def is_available(self) -> bool:
+#         """Check if key is available RIGHT NOW"""
+#         now = time.time()
+
+#         # Check if disabled
+#         if self.disabled_until > now:
+#             return False
+
+#         # Require 5 seconds between requests on SAME key
+#         if self.last_used > 0:
+#             elapsed = now - self.last_used
+#             if elapsed < 5.0:
+#                 return False
+
+#         return True
+
+#     def get_wait_time(self) -> float:
+#         """How long until this key is available?"""
+#         now = time.time()
+
+#         if self.disabled_until > now:
+#             return self.disabled_until - now
+
+#         if self.last_used > 0:
+#             elapsed = now - self.last_used
+#             if elapsed < 5.0:
+#                 return 5.0 - elapsed
+
+#         return 0
+
+#     def mark_success(self):
+#         self.last_used = time.time()
+#         self.total_requests += 1
+#         self.consecutive_failures = 0
+#         self.disabled_until = 0
+#         logger.info(f"✅ {self.name} success (total: {self.total_requests})")
+
+#     def mark_failure(self, is_rate_limit: bool = False):
+#         self.last_used = time.time()
+#         self.total_requests += 1
+#         self.total_failures += 1
+#         self.consecutive_failures += 1
+
+#         if is_rate_limit:
+#             # Rate limit: wait 90 seconds
+#             self.disabled_until = time.time() + 90
+#             logger.error(f"🚫 {self.name} RATE LIMITED → disabled for 90s")
+#         elif self.consecutive_failures >= 2:
+#             # 2 failures: wait 60 seconds
+#             self.disabled_until = time.time() + 60
+#             logger.error(f"❌ {self.name} FAILED {self.consecutive_failures}x → disabled for 60s")
+
+
+# class UltraSafeLoadBalancer:
+#     """
+#     Ultra-conservative load balancer
+#     - Minimum 5 seconds between requests on same key
+#     - Minimum 1 second between ANY requests (global)
+#     - Automatic waiting for key availability
+#     - No parallel requests
+#     """
+
+#     def __init__(self, api_keys: List[Dict[str, str]]):
+#         if not api_keys:
+#             raise ValueError("At least one API key required")
+
+#         self.keys = [APIKeyState(key=k['key'], name=k['name']) for k in api_keys]
+#         self.current_index = 0
+#         self.lock = threading.Lock()
+#         self.last_global_request = 0
+#         self.min_global_interval = 1.0  # 1 second between ANY requests
+
+#         logger.info(f"🔑 Ultra-safe balancer: {len(self.keys)} keys, 5s per-key interval, 1s global interval")
+
+#     def _enforce_global_rate_limit(self):
+#         """Ensure minimum time between ANY requests"""
+#         with self.lock:
+#             if self.last_global_request > 0:
+#                 elapsed = time.time() - self.last_global_request
+#                 if elapsed < self.min_global_interval:
+#                     wait = self.min_global_interval - elapsed
+#                     logger.debug(f"⏱️ Global rate limit: waiting {wait:.1f}s")
+#                     time.sleep(wait)
+#             self.last_global_request = time.time()
+
+#     def get_next_key(self, max_wait: float = 30.0) -> Optional[APIKeyState]:
+#         """Get next available key, waiting if necessary"""
+#         start_time = time.time()
+
+#         while (time.time() - start_time) < max_wait:
+#             with self.lock:
+#                 # Try round-robin
+#                 for _ in range(len(self.keys)):
+#                     key = self.keys[self.current_index]
+#                     self.current_index = (self.current_index + 1) % len(self.keys)
+
+#                     if key.is_available():
+#                         return key
+
+#                 # No keys available - find the one that will be ready soonest
+#                 wait_times = [(k.get_wait_time(), k) for k in self.keys]
+#                 wait_times.sort()
+
+#                 if wait_times:
+#                     min_wait, next_key = wait_times[0]
+
+#                     if min_wait > 0 and min_wait < 15:
+#                         logger.info(f"⏳ All keys busy. Waiting {min_wait:.1f}s for {next_key.name}...")
+#                         time.sleep(min_wait + 0.2)
+#                         continue
+
+#             time.sleep(0.5)
+
+#         # Timeout
+#         logger.error(f"❌ No keys available after {max_wait}s wait")
+#         return None
+
+#     def call_llm(self, payload: dict, api_url: str, retry_count: int = 0) -> str:
+#         """
+#         Make LLM call with ONE key
+#         Retries with SAME key after waiting if it fails
+#         """
+#         # Enforce global rate limit FIRST
+#         self._enforce_global_rate_limit()
+
+#         # Get available key
+#         key = self.get_next_key(max_wait=30.0)
+
+#         if not key:
+#             raise RuntimeError("No API keys available after 30s wait")
+
+#         try:
+#             headers = {
+#                 "Authorization": f"Bearer {key.key}",
+#                 "Content-Type": "application/json"
+#             }
+
+#             logger.info(f"🔑 Request via {key.name} (retry: {retry_count})")
+
+#             response = requests.post(
+#                 api_url,
+#                 headers=headers,
+#                 json=payload,
+#                 timeout=45
+#             )
+
+#             # Check rate limit
+#             if response.status_code == 429:
+#                 key.mark_failure(is_rate_limit=True)
+
+#                 if retry_count < 2:
+#                     logger.warning(f"⚠️ Rate limit hit, retrying with different key...")
+#                     time.sleep(2)
+#                     return self.call_llm(payload, api_url, retry_count + 1)
+#                 else:
+#                     raise RuntimeError(f"Rate limit on {key.name} after {retry_count} retries")
+
+#             # Check for other errors
+#             response.raise_for_status()
+
+#             # Success!
+#             key.mark_success()
+#             content = response.json()["choices"][0]["message"]["content"]
+#             return content
+
+#         except requests.exceptions.HTTPError as e:
+#             if e.response and e.response.status_code == 429:
+#                 key.mark_failure(is_rate_limit=True)
+#             else:
+#                 key.mark_failure(is_rate_limit=False)
+
+#             if retry_count < 2:
+#                 logger.warning(f"⚠️ HTTP error, retrying... ({e})")
+#                 time.sleep(2)
+#                 return self.call_llm(payload, api_url, retry_count + 1)
+#             else:
+#                 raise RuntimeError(f"HTTP error after {retry_count} retries: {e}")
+
+#         except requests.exceptions.Timeout as e:
+#             key.mark_failure(is_rate_limit=False)
+
+#             if retry_count < 1:
+#                 logger.warning(f"⏱️ Timeout, retrying...")
+#                 time.sleep(3)
+#                 return self.call_llm(payload, api_url, retry_count + 1)
+#             else:
+#                 raise RuntimeError(f"Timeout after {retry_count} retries: {e}")
+
+#         except Exception as e:
+#             key.mark_failure(is_rate_limit=False)
+#             raise RuntimeError(f"Unexpected error with {key.name}: {e}")
+
+#     def get_stats(self) -> Dict:
+#         with self.lock:
+#             available = sum(1 for k in self.keys if k.is_available())
+#             total_reqs = sum(k.total_requests for k in self.keys)
+#             total_fails = sum(k.total_failures for k in self.keys)
+#             success_rate = ((total_reqs - total_fails) / max(total_reqs, 1)) * 100
+
+#             return {
+#                 "total_keys": len(self.keys),
+#                 "available_keys": available,
+#                 "total_requests": total_reqs,
+#                 "total_failures": total_fails,
+#                 "success_rate": round(success_rate, 2),
+#                 "keys": [
+#                     {
+#                         "name": k.name,
+#                         "total_requests": k.total_requests,
+#                         "total_failures": k.total_failures,
+#                         "consecutive_failures": k.consecutive_failures,
+#                         "is_available": k.is_available(),
+#                         "wait_time": round(k.get_wait_time(), 1)
+#                     }
+#                     for k in self.keys
+#                 ]
+#             }
+
+
+# # Global singleton
+# _balancer: Optional[UltraSafeLoadBalancer] = None
+# _balancer_lock = threading.Lock()
+
+
+# def get_llm_load_balancer() -> UltraSafeLoadBalancer:
+#     """Get singleton balancer"""
+#     global _balancer
+
+#     if _balancer is None:
+#         with _balancer_lock:
+#             if _balancer is None:
+#                 api_keys = getattr(settings, 'GROQ_API_KEYS', None)
+
+#                 if not api_keys:
+#                     single_key = getattr(settings, 'GROQ_API_KEY', None)
+#                     if single_key:
+#                         api_keys = [{'key': single_key, 'name': 'groq_key_1'}]
+
+#                 if not api_keys:
+#                     raise ValueError("No GROQ_API_KEYS configured in settings")
+
+#                 _balancer = UltraSafeLoadBalancer(api_keys)
+
+#     return _balancer
+
+
+# def reset_load_balancer():
+#     """Reset balancer (for testing)"""
+#     global _balancer
+#     with _balancer_lock:
+#         _balancer = None
+#         logger.info("🔄 Balancer reset")
+
+
+# def call_llm_with_load_balancer(payload: dict) -> str:
+#     """
+#     Call LLM with ultra-safe rate limiting
+#     This is the drop-in replacement for services.py
+#     """
+#     balancer = get_llm_load_balancer()
+#     api_url = getattr(settings, 'GROQ_API_URL')
+#     return balancer.call_llm(payload, api_url)
+
+
+# def get_load_balancer_stats() -> Dict:
+#     """Get balancer stats"""
+#     try:
+#         return get_llm_load_balancer().get_stats()
+#     except Exception as e:
+#         return {"error": str(e)}