| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662 |
- """
- Enhanced Multi-API-Key Load Balancer with Smart Rate Limiting
- Optimized for Groq's free tier (30 RPM per key, 14K daily limit)
- """
- import time
- import threading
- import requests
- import logging
- from typing import List, Dict, Optional
- from dataclasses import dataclass, field
- from datetime import datetime
- from django.conf import settings
- logger = logging.getLogger(__name__)
@dataclass
class APIKeyState:
    """Tracks rate-limit state and health of a single API key.

    Defaults are tuned for Groq's free tier: a conservative 25 requests per
    minute (below the 30 RPM cap) and 2.5s minimum spacing between requests
    on the same key.
    """
    key: str                                        # the raw API key (secret)
    name: str                                       # human-readable label for logs/stats
    requests_made: int = 0                          # successful requests on this key
    last_request_time: float = 0                    # epoch seconds of last success
    is_available: bool = True                       # False while rate-limited / cooling down
    rate_limit_reset_time: Optional[float] = None   # epoch seconds when the key recovers
    consecutive_failures: int = 0
    total_requests: int = 0                         # every attempt, success OR failure
    total_failures: int = 0
    request_times: List[float] = field(default_factory=list)  # success timestamps, last 60s
    requests_per_minute: int = 25   # Conservative: 25 instead of 30
    min_request_interval: float = 2.5  # Minimum 2.5s between requests per key

    def can_make_request(self) -> bool:
        """Check if key can make a request (rate limit + spacing).

        Also prunes ``request_times`` entries older than one minute as a
        side effect.
        """
        now = time.time()

        # Check minimum interval between requests
        if self.last_request_time and (now - self.last_request_time) < self.min_request_interval:
            return False

        # Remove requests older than 1 minute, then compare against the RPM cap
        self.request_times = [t for t in self.request_times if now - t < 60]
        return len(self.request_times) < self.requests_per_minute

    def mark_success(self):
        """Record a successful request and clear any failure state."""
        now = time.time()
        self.requests_made += 1
        self.total_requests += 1
        self.last_request_time = now
        self.request_times.append(now)
        self.consecutive_failures = 0
        self.is_available = True
        self.rate_limit_reset_time = None

        # Keep only last 60 seconds
        self.request_times = [t for t in self.request_times if now - t < 60]

    def mark_failure(self, is_rate_limit: bool = False, retry_after: Optional[int] = None):
        """Record a failed request; disable the key on rate limit or repeated errors.

        Args:
            is_rate_limit: True when the API returned HTTP 429.
            retry_after: server-suggested cooldown in seconds (defaults to 65).
        """
        # BUG FIX: a failure is still an attempt. Without this, total_requests
        # counted only successes and get_stats() could report a negative
        # success_rate (success_count = total_requests - total_failures).
        self.total_requests += 1
        self.consecutive_failures += 1
        self.total_failures += 1

        if is_rate_limit:
            self.is_available = False
            reset_time = time.time() + (retry_after or 65)  # 65s default
            self.rate_limit_reset_time = reset_time
            logger.warning(f"🚫 {self.name} rate limited until {datetime.fromtimestamp(reset_time).strftime('%H:%M:%S')}")

        # Disable after 5 consecutive failures (increased from 3)
        if self.consecutive_failures >= 5:
            self.is_available = False
            self.rate_limit_reset_time = time.time() + 120  # 2 min cooldown
            logger.error(f"❌ {self.name} disabled (cooldown 2min)")

    def check_availability(self) -> bool:
        """Check if key is available, re-enabling it once its cooldown expires."""
        # Check rate limit reset
        if self.rate_limit_reset_time and time.time() >= self.rate_limit_reset_time:
            self.is_available = True
            self.rate_limit_reset_time = None
            self.consecutive_failures = 0
            logger.info(f"✅ {self.name} recovered")
            return True

        if not self.is_available:
            return False

        return self.can_make_request()

    def get_stats(self) -> Dict:
        """Return a snapshot dict of this key's counters and availability."""
        success_count = self.total_requests - self.total_failures
        success_rate = (success_count / max(self.total_requests, 1)) * 100

        return {
            "name": self.name,
            "total_requests": self.total_requests,
            "total_failures": self.total_failures,
            "success_rate": round(success_rate, 2),
            "is_available": self.check_availability(),
            "consecutive_failures": self.consecutive_failures,
            "current_rpm": len(self.request_times),
            "max_rpm": self.requests_per_minute,
            "time_since_last_request": round(time.time() - self.last_request_time, 1) if self.last_request_time else None
        }
class MultiKeyLLMLoadBalancer:
    """Distributes LLM calls across multiple API keys with smart rate limiting.

    Strategies:
        "round_robin"  - cycle through keys, skipping unavailable ones.
        "least_loaded" - pick the key with the fewest requests in the last minute.
    """

    def __init__(self, api_keys: List[Dict[str, str]], strategy: str = "round_robin"):
        """api_keys: list of ``{'key': ..., 'name': ...}`` dicts; must be non-empty."""
        if not api_keys:
            raise ValueError("At least one API key required")

        self.keys = [APIKeyState(key=k['key'], name=k['name']) for k in api_keys]
        self.strategy = strategy
        self.current_index = 0
        self.lock = threading.Lock()
        self.total_requests = 0          # every attempt (success or failure)
        self.total_failures = 0
        self.global_last_request = 0
        self.min_global_interval = 0.5   # 500ms between ANY requests

        logger.info(f"🔑 Load balancer initialized: {len(self.keys)} keys, '{strategy}' strategy")

    @staticmethod
    def _retry_after_seconds(response) -> int:
        """Parse the Retry-After header as seconds, defaulting to 65.

        Per HTTP spec the header may also be an HTTP-date; in that case (or any
        non-numeric value) int() would raise, so fall back to the default.
        """
        try:
            return int(response.headers.get('Retry-After', 65))
        except (TypeError, ValueError):
            return 65

    def get_next_key(self) -> Optional[APIKeyState]:
        """Get next available key with global rate limiting.

        NOTE(review): sleeps while holding the lock, which serializes all
        callers — presumably intentional to enforce the global interval.
        """
        with self.lock:
            # Enforce minimum global interval between ANY two requests
            now = time.time()
            time_since_last = now - self.global_last_request
            if time_since_last < self.min_global_interval:
                wait_time = self.min_global_interval - time_since_last
                time.sleep(wait_time)

            if self.strategy == "least_loaded":
                return self._least_loaded_select()
            else:
                return self._round_robin_select()

    def _round_robin_select(self) -> Optional[APIKeyState]:
        """Round-robin with availability check."""
        attempts = 0
        total_keys = len(self.keys)

        while attempts < total_keys:
            key = self.keys[self.current_index]
            self.current_index = (self.current_index + 1) % total_keys

            if key.check_availability():
                return key

            attempts += 1

        return self._wait_for_available_key()

    def _least_loaded_select(self) -> Optional[APIKeyState]:
        """Select the available key with the fewest recent requests."""
        available = [k for k in self.keys if k.check_availability()]

        if not available:
            return self._wait_for_available_key()

        # Fewest requests in the last minute; ties broken by least-recently used
        available.sort(key=lambda k: (len(k.request_times), k.last_request_time))
        return available[0]

    def _wait_for_available_key(self, max_wait: float = 5.0) -> Optional[APIKeyState]:
        """Wait (bounded by max_wait) for the next key to become usable."""
        keys_with_reset = [k for k in self.keys if k.rate_limit_reset_time]

        if not keys_with_reset:
            # No key is rate-limited; check if any key just needs spacing
            now = time.time()
            for key in self.keys:
                if key.is_available:
                    wait = key.min_request_interval - (now - key.last_request_time)
                    if 0 < wait < max_wait:
                        logger.info(f"⏳ Waiting {wait:.1f}s for {key.name}...")
                        time.sleep(wait + 0.1)
                        return key if key.check_availability() else None
            return None

        # Wait for whichever rate-limited key recovers soonest
        keys_with_reset.sort(key=lambda k: k.rate_limit_reset_time)
        next_key = keys_with_reset[0]
        wait = max(0, next_key.rate_limit_reset_time - time.time())

        if 0 < wait < max_wait:
            logger.info(f"⏳ Waiting {wait:.1f}s for {next_key.name}...")
            time.sleep(wait + 0.5)
            return next_key if next_key.check_availability() else None

        return None

    def mark_success(self, key: APIKeyState):
        """Record a successful call on ``key`` and bump global counters."""
        with self.lock:
            key.mark_success()
            self.total_requests += 1
            self.global_last_request = time.time()

    def mark_failure(self, key: APIKeyState, is_rate_limit: bool = False, retry_after: Optional[int] = None):
        """Record a failed call on ``key`` and bump global counters."""
        with self.lock:
            key.mark_failure(is_rate_limit, retry_after)
            # BUG FIX: count the attempt too — otherwise get_stats() divides
            # failures by a successes-only total and success_rate is wrong
            # (it could even go negative).
            self.total_requests += 1
            self.total_failures += 1

    def get_stats(self) -> Dict:
        """Return aggregate and per-key statistics."""
        with self.lock:
            available_count = sum(1 for k in self.keys if k.check_availability())
            success_rate = ((self.total_requests - self.total_failures) / max(self.total_requests, 1)) * 100

            return {
                "total_keys": len(self.keys),
                "available_keys": available_count,
                "strategy": self.strategy,
                "total_requests": self.total_requests,
                "total_failures": self.total_failures,
                "success_rate": round(success_rate, 2),
                "keys": [k.get_stats() for k in self.keys]
            }

    def call_llm(self, payload: dict, api_url: str, max_retries: Optional[int] = None) -> str:
        """Make an LLM call with smart retry and key failover.

        Args:
            payload: JSON body for the chat-completions endpoint.
            api_url: full endpoint URL.
            max_retries: attempt budget; defaults to 3 passes over all keys.

        Returns:
            The ``content`` string of the first choice's message.

        Raises:
            RuntimeError: when every attempt fails.
        """
        if max_retries is None:
            max_retries = len(self.keys) * 3

        attempt = 0
        last_error = None
        keys_tried = set()

        while attempt < max_retries:
            key_state = self.get_next_key()

            if not key_state:
                if len(keys_tried) >= len(self.keys):
                    # All keys tried, wait longer before another full pass
                    logger.warning("⏳ All keys exhausted. Waiting 3s...")
                    time.sleep(3)
                    keys_tried.clear()

                attempt += 1
                continue

            keys_tried.add(key_state.name)

            try:
                headers = {
                    "Authorization": f"Bearer {key_state.key}",
                    "Content-Type": "application/json"
                }

                logger.debug(f"🔑 {key_state.name} (attempt {attempt + 1}/{max_retries})")

                response = requests.post(
                    api_url,
                    headers=headers,
                    json=payload,
                    timeout=30
                )

                if response.status_code == 429:
                    retry_after = self._retry_after_seconds(response)
                    self.mark_failure(key_state, is_rate_limit=True, retry_after=retry_after)
                    attempt += 1
                    time.sleep(1)  # Brief pause before next key
                    continue

                response.raise_for_status()

                # Success
                self.mark_success(key_state)
                content = response.json()["choices"][0]["message"]["content"]
                logger.debug(f"✅ Success via {key_state.name}")
                return content

            except requests.exceptions.HTTPError as e:
                # BUG FIX: requests.Response is falsy for 4xx/5xx statuses, so
                # `if e.response and ...` never matched a 429 here. Must use an
                # explicit None check.
                if e.response is not None and e.response.status_code == 429:
                    retry_after = self._retry_after_seconds(e.response)
                    self.mark_failure(key_state, is_rate_limit=True, retry_after=retry_after)
                else:
                    self.mark_failure(key_state)
                logger.error(f"❌ HTTP error {key_state.name}: {e}")
                last_error = e
                attempt += 1
                time.sleep(0.5)

            except Exception as e:
                self.mark_failure(key_state)
                logger.error(f"❌ Error {key_state.name}: {e}")
                last_error = e
                attempt += 1
                time.sleep(0.5)

        stats = self.get_stats()
        error_msg = (
            f"LLM failed after {max_retries} attempts. "
            f"Available: {stats['available_keys']}/{stats['total_keys']}. "
            f"Error: {last_error}"
        )
        logger.error(f"💥 {error_msg}")
        raise RuntimeError(error_msg)
# Global instance
# Process-wide singleton, created lazily by get_llm_load_balancer().
_load_balancer: Optional[MultiKeyLLMLoadBalancer] = None
# Guards first-time construction of the singleton (double-checked locking).
_load_balancer_lock = threading.Lock()
def get_llm_load_balancer() -> MultiKeyLLMLoadBalancer:
    """Return the process-wide load balancer, building it on first use.

    Key config is read from Django settings: GROQ_API_KEYS (list of
    {'key', 'name'} dicts), falling back to a single GROQ_API_KEY.
    Double-checked locking keeps the hot path lock-free.
    """
    global _load_balancer

    if _load_balancer is not None:
        return _load_balancer

    with _load_balancer_lock:
        if _load_balancer is None:
            key_configs = getattr(settings, 'GROQ_API_KEYS', None)

            # Fall back to a single-key configuration
            if not key_configs:
                fallback_key = getattr(settings, 'GROQ_API_KEY', None)
                if fallback_key:
                    key_configs = [{'key': fallback_key, 'name': 'groq_key_1'}]

            if not key_configs:
                raise ValueError("No GROQ API keys configured")

            chosen_strategy = getattr(settings, 'LLM_LOAD_BALANCER_STRATEGY', 'round_robin')
            _load_balancer = MultiKeyLLMLoadBalancer(key_configs, strategy=chosen_strategy)

    return _load_balancer
def reset_load_balancer():
    """Discard the singleton balancer so the next access rebuilds it.

    Intended for tests or after a settings change; all per-key statistics
    and cooldown state are lost.
    """
    global _load_balancer
    with _load_balancer_lock:
        _load_balancer = None
def call_llm_with_load_balancer(payload: dict) -> str:
    """Drop-in replacement for _call_llm that routes through the key balancer.

    Args:
        payload: JSON body for the chat-completions request.

    Returns:
        The generated message content string.

    Raises:
        AttributeError: if GROQ_API_URL is missing from Django settings.
        RuntimeError: if every key/retry is exhausted (from the balancer).
    """
    balancer = get_llm_load_balancer()
    # Idiom fix: getattr() with no default is just attribute access; a missing
    # setting still fails loudly here, now with a clearer spelling.
    api_url = settings.GROQ_API_URL
    return balancer.call_llm(payload, api_url)
def get_load_balancer_stats() -> Dict:
    """Best-effort snapshot of balancer statistics; never raises.

    On any failure (e.g. no keys configured yet) returns an error dict
    with zeroed key counts instead of propagating the exception.
    """
    try:
        balancer = get_llm_load_balancer()
        stats = balancer.get_stats()
    except Exception as exc:
        return {"error": str(exc), "total_keys": 0, "available_keys": 0}
    return stats
- # """
- # Ultra-Safe Sequential Load Balancer with Adaptive Rate Limiting
- # Guaranteed to work with strict API rate limits
- # """
- # import time
- # import threading
- # import requests
- # import logging
- # from typing import List, Dict, Optional
- # from dataclasses import dataclass
- # from datetime import datetime
- # from django.conf import settings
- # logger = logging.getLogger(__name__)
- # @dataclass
- # class APIKeyState:
- # """Simple key state tracker"""
- # key: str
- # name: str
- # last_used: float = 0
- # total_requests: int = 0
- # total_failures: int = 0
- # consecutive_failures: int = 0
- # disabled_until: float = 0
-
- # def is_available(self) -> bool:
- # """Check if key is available RIGHT NOW"""
- # now = time.time()
-
- # # Check if disabled
- # if self.disabled_until > now:
- # return False
-
- # # Require 5 seconds between requests on SAME key
- # if self.last_used > 0:
- # elapsed = now - self.last_used
- # if elapsed < 5.0:
- # return False
-
- # return True
-
- # def get_wait_time(self) -> float:
- # """How long until this key is available?"""
- # now = time.time()
-
- # if self.disabled_until > now:
- # return self.disabled_until - now
-
- # if self.last_used > 0:
- # elapsed = now - self.last_used
- # if elapsed < 5.0:
- # return 5.0 - elapsed
-
- # return 0
-
- # def mark_success(self):
- # self.last_used = time.time()
- # self.total_requests += 1
- # self.consecutive_failures = 0
- # self.disabled_until = 0
- # logger.info(f"✅ {self.name} success (total: {self.total_requests})")
-
- # def mark_failure(self, is_rate_limit: bool = False):
- # self.last_used = time.time()
- # self.total_requests += 1
- # self.total_failures += 1
- # self.consecutive_failures += 1
-
- # if is_rate_limit:
- # # Rate limit: wait 90 seconds
- # self.disabled_until = time.time() + 90
- # logger.error(f"🚫 {self.name} RATE LIMITED → disabled for 90s")
- # elif self.consecutive_failures >= 2:
- # # 2 failures: wait 60 seconds
- # self.disabled_until = time.time() + 60
- # logger.error(f"❌ {self.name} FAILED {self.consecutive_failures}x → disabled for 60s")
- # class UltraSafeLoadBalancer:
- # """
- # Ultra-conservative load balancer
- # - Minimum 5 seconds between requests on same key
- # - Minimum 1 second between ANY requests (global)
- # - Automatic waiting for key availability
- # - No parallel requests
- # """
-
- # def __init__(self, api_keys: List[Dict[str, str]]):
- # if not api_keys:
- # raise ValueError("At least one API key required")
-
- # self.keys = [APIKeyState(key=k['key'], name=k['name']) for k in api_keys]
- # self.current_index = 0
- # self.lock = threading.Lock()
- # self.last_global_request = 0
- # self.min_global_interval = 1.0 # 1 second between ANY requests
-
- # logger.info(f"🔑 Ultra-safe balancer: {len(self.keys)} keys, 5s per-key interval, 1s global interval")
-
- # def _enforce_global_rate_limit(self):
- # """Ensure minimum time between ANY requests"""
- # with self.lock:
- # if self.last_global_request > 0:
- # elapsed = time.time() - self.last_global_request
- # if elapsed < self.min_global_interval:
- # wait = self.min_global_interval - elapsed
- # logger.debug(f"⏱️ Global rate limit: waiting {wait:.1f}s")
- # time.sleep(wait)
- # self.last_global_request = time.time()
-
- # def get_next_key(self, max_wait: float = 30.0) -> Optional[APIKeyState]:
- # """Get next available key, waiting if necessary"""
- # start_time = time.time()
-
- # while (time.time() - start_time) < max_wait:
- # with self.lock:
- # # Try round-robin
- # for _ in range(len(self.keys)):
- # key = self.keys[self.current_index]
- # self.current_index = (self.current_index + 1) % len(self.keys)
-
- # if key.is_available():
- # return key
-
- # # No keys available - find the one that will be ready soonest
- # wait_times = [(k.get_wait_time(), k) for k in self.keys]
- # wait_times.sort()
-
- # if wait_times:
- # min_wait, next_key = wait_times[0]
-
- # if min_wait > 0 and min_wait < 15:
- # logger.info(f"⏳ All keys busy. Waiting {min_wait:.1f}s for {next_key.name}...")
- # time.sleep(min_wait + 0.2)
- # continue
-
- # time.sleep(0.5)
-
- # # Timeout
- # logger.error(f"❌ No keys available after {max_wait}s wait")
- # return None
-
- # def call_llm(self, payload: dict, api_url: str, retry_count: int = 0) -> str:
- # """
- # Make LLM call with ONE key
- # Retries with SAME key after waiting if it fails
- # """
- # # Enforce global rate limit FIRST
- # self._enforce_global_rate_limit()
-
- # # Get available key
- # key = self.get_next_key(max_wait=30.0)
-
- # if not key:
- # raise RuntimeError("No API keys available after 30s wait")
-
- # try:
- # headers = {
- # "Authorization": f"Bearer {key.key}",
- # "Content-Type": "application/json"
- # }
-
- # logger.info(f"🔑 Request via {key.name} (retry: {retry_count})")
-
- # response = requests.post(
- # api_url,
- # headers=headers,
- # json=payload,
- # timeout=45
- # )
-
- # # Check rate limit
- # if response.status_code == 429:
- # key.mark_failure(is_rate_limit=True)
-
- # if retry_count < 2:
- # logger.warning(f"⚠️ Rate limit hit, retrying with different key...")
- # time.sleep(2)
- # return self.call_llm(payload, api_url, retry_count + 1)
- # else:
- # raise RuntimeError(f"Rate limit on {key.name} after {retry_count} retries")
-
- # # Check for other errors
- # response.raise_for_status()
-
- # # Success!
- # key.mark_success()
- # content = response.json()["choices"][0]["message"]["content"]
- # return content
-
- # except requests.exceptions.HTTPError as e:
- # if e.response and e.response.status_code == 429:
- # key.mark_failure(is_rate_limit=True)
- # else:
- # key.mark_failure(is_rate_limit=False)
-
- # if retry_count < 2:
- # logger.warning(f"⚠️ HTTP error, retrying... ({e})")
- # time.sleep(2)
- # return self.call_llm(payload, api_url, retry_count + 1)
- # else:
- # raise RuntimeError(f"HTTP error after {retry_count} retries: {e}")
-
- # except requests.exceptions.Timeout as e:
- # key.mark_failure(is_rate_limit=False)
-
- # if retry_count < 1:
- # logger.warning(f"⏱️ Timeout, retrying...")
- # time.sleep(3)
- # return self.call_llm(payload, api_url, retry_count + 1)
- # else:
- # raise RuntimeError(f"Timeout after {retry_count} retries: {e}")
-
- # except Exception as e:
- # key.mark_failure(is_rate_limit=False)
- # raise RuntimeError(f"Unexpected error with {key.name}: {e}")
-
- # def get_stats(self) -> Dict:
- # with self.lock:
- # available = sum(1 for k in self.keys if k.is_available())
- # total_reqs = sum(k.total_requests for k in self.keys)
- # total_fails = sum(k.total_failures for k in self.keys)
- # success_rate = ((total_reqs - total_fails) / max(total_reqs, 1)) * 100
-
- # return {
- # "total_keys": len(self.keys),
- # "available_keys": available,
- # "total_requests": total_reqs,
- # "total_failures": total_fails,
- # "success_rate": round(success_rate, 2),
- # "keys": [
- # {
- # "name": k.name,
- # "total_requests": k.total_requests,
- # "total_failures": k.total_failures,
- # "consecutive_failures": k.consecutive_failures,
- # "is_available": k.is_available(),
- # "wait_time": round(k.get_wait_time(), 1)
- # }
- # for k in self.keys
- # ]
- # }
- # # Global singleton
- # _balancer: Optional[UltraSafeLoadBalancer] = None
- # _balancer_lock = threading.Lock()
- # def get_llm_load_balancer() -> UltraSafeLoadBalancer:
- # """Get singleton balancer"""
- # global _balancer
-
- # if _balancer is None:
- # with _balancer_lock:
- # if _balancer is None:
- # api_keys = getattr(settings, 'GROQ_API_KEYS', None)
-
- # if not api_keys:
- # single_key = getattr(settings, 'GROQ_API_KEY', None)
- # if single_key:
- # api_keys = [{'key': single_key, 'name': 'groq_key_1'}]
-
- # if not api_keys:
- # raise ValueError("No GROQ_API_KEYS configured in settings")
-
- # _balancer = UltraSafeLoadBalancer(api_keys)
-
- # return _balancer
- # def reset_load_balancer():
- # """Reset balancer (for testing)"""
- # global _balancer
- # with _balancer_lock:
- # _balancer = None
- # logger.info("🔄 Balancer reset")
- # def call_llm_with_load_balancer(payload: dict) -> str:
- # """
- # Call LLM with ultra-safe rate limiting
- # This is the drop-in replacement for services.py
- # """
- # balancer = get_llm_load_balancer()
- # api_url = getattr(settings, 'GROQ_API_URL')
- # return balancer.call_llm(payload, api_url)
- # def get_load_balancer_stats() -> Dict:
- # """Get balancer stats"""
- # try:
- # return get_llm_load_balancer().get_stats()
- # except Exception as e:
- # return {"error": str(e)}
|