import datetime
import io
import math
import os
import time
from collections import defaultdict, deque

import numpy as np
import torch
import torch.distributed as dist


def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
    """Decay the learning rate with a half-cycle cosine from init_lr to min_lr."""
    lr = (init_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * epoch / max_epoch)) + min_lr
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
    """Warm up the learning rate linearly from init_lr to max_lr."""
    lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max_step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
    """Decay the learning rate by decay_rate every epoch, clamped at min_lr."""
    lr = max(min_lr, init_lr * (decay_rate**epoch))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
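
# Usage sketch (illustrative only, not called anywhere in this file): the
# schedules above are typically combined so that warmup_lr_schedule runs for
# the first few steps and cosine_lr_schedule updates once per epoch; `model`,
# `loader`, and every hyperparameter value below are placeholder assumptions.
#
#   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
#   for epoch in range(max_epoch):
#       cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr=1e-4, min_lr=1e-6)
#       for step, batch in enumerate(loader):
#           if epoch == 0 and step < warmup_steps:
#               warmup_lr_schedule(optimizer, step, warmup_steps, init_lr=1e-6, max_lr=1e-4)
#           loss = model(batch)
#           loss.backward()
#           optimizer.step()
#           optimizer.zero_grad()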


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)
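
# Usage sketch (illustrative only): a SmoothedValue tracks one scalar over a
# sliding window, and fmt controls how it renders via __str__.
#
#   loss_meter = SmoothedValue(window_size=20, fmt='{median:.4f} ({global_avg:.4f})')
#   loss_meter.update(0.73)
#   loss_meter.update(0.68)
#   print(str(loss_meter))        # e.g. "0.6800 (0.7050)"
#   print(loss_meter.global_avg)  # average over every value ever passed to update()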


class MetricLogger(object):
    """Collect named SmoothedValue meters and print progress/ETA lines while iterating."""

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def global_avg(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {:.4f}".format(name, meter.global_avg)
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))
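
# Usage sketch (illustrative only): `data_loader`, `model`, and `optimizer` are
# placeholders for whatever the training script defines.
#
#   metric_logger = MetricLogger(delimiter="  ")
#   metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
#   for batch in metric_logger.log_every(data_loader, print_freq=50, header='Train:'):
#       loss = model(batch)
#       metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]['lr'])
#   metric_logger.synchronize_between_processes()
#   print('Averaged stats:', metric_logger.global_avg())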


class AttrDict(dict):
    """dict subclass whose items are also accessible as attributes."""

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def compute_acc(logits, label, reduction='mean'):
    ret = (torch.argmax(logits, dim=1) == label).float()
    if reduction == 'none':
        return ret.detach()
    elif reduction == 'mean':
        return ret.mean().item()


def compute_n_params(model, return_str=True):
    """Count the model's parameters, optionally formatted as a human-readable string."""
    tot = 0
    for p in model.parameters():
        w = 1
        for x in p.shape:
            w *= x
        tot += w
    if return_str:
        if tot >= 1e6:
            return '{:.1f}M'.format(tot / 1e6)
        else:
            return '{:.1f}K'.format(tot / 1e3)
    else:
        return tot
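
# Usage sketch (illustrative only): `model` is a placeholder torch.nn.Module and
# the tensor shapes are arbitrary.
#
#   config = AttrDict({'image_size': 384, 'batch_size': 32})
#   print(config.image_size)                              # attribute access == dict lookup
#
#   logits = torch.randn(8, 10)
#   labels = torch.randint(0, 10, (8,))
#   print(compute_acc(logits, labels))                    # mean accuracy in [0, 1]
#   print(compute_acc(logits, labels, reduction='none'))  # per-sample 0/1 tensor
#   print(compute_n_params(model))                        # e.g. '223.7M'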


def setup_for_distributed(is_master):
    """
    Disable printing on non-master processes; a call can still force output
    by passing force=True.
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)
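
# Usage sketch (illustrative only): these helpers make rank-dependent work safe
# to call unconditionally from every process; `model` is a placeholder.
#
#   if is_main_process():
#       print('world size:', get_world_size())
#   save_on_master({'model': model.state_dict()}, 'checkpoint.pth')  # writes only on rank 0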


def init_distributed_mode(args):
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}, world size {}): {}'.format(
        args.rank, args.world_size, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
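
# Usage sketch (illustrative only): init_distributed_mode expects an
# argparse-style namespace with a dist_url attribute; RANK, WORLD_SIZE and
# LOCAL_RANK are normally injected by the launcher (e.g. torchrun).
#
#   import argparse
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--dist_url', default='env://')
#   args = parser.parse_args()
#   init_distributed_mode(args)   # sets args.rank, args.gpu, args.distributed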