pytorch 1.6
이상에서는 torch.cuda.amp
패키지를 이용해 편리하게 AMP를 적용할 수 있다.
* AMP는 float32
와 float16
을 함께 적절히 사용함으로써 빠른 연산 속도와 효율적인 메모리 활용이 가능하도록 해준다.
내 딥러닝 모델에 AMP를 적용하려면 앞서 말한 패키지에서 두 가지를 import 해야 한다.
from torch.cuda.amp import autocast, GradScaler
우선 GradScaler
는 mixed-precision으로 학습을 진행할 때 gradient scaling을 수행하는 역할을 한다.
즉 backpropagation 중 gradient가 너무 작아지는 것을 방지하여 floating-point 형식으로 정확하게 표현될 수 없는 숫자가 되는 문제를 방지한다.
Mixed-precision은 single-precision format으로 master weights를 저장하고, half-precision format에서 연산을 수행한다.
* mixed precision 학습 시에는 single-precision, half-precision의 두 가지 weights을 사용하는데, 이때 float32(single-precision)
weights를 master weights라고 한다.
특히 half-precision 형식의 제한된 표현 범위로 인해 weights 또는 gradient에 underflow나 overflow가 발생할 수 있다.
underflow 문제를 방지하기 위해, gradient scaling은 backward pass 전에 네트워크의 loss에 scale factor를 곱한다.
그리고 gradient가 계산된 이후에는 scale factor로 다시 나눠서 정확한 크기로 돌려놓는다.
def train():
scaler = GradScaler() # initialize a scaler
for epoch in range(epochs):
for batch in enumerate(train_dataloader):
...
optimizer.zero_grad()
with autocast():
output = model(image)
loss = loss_fn(output, target)
scaler.scale(loss).backward() # scale the loss and backward
scaler.step(optimizer) # unscale the gradients, update the weights
scaler.update() # prepare the scale for the next iteration
...
scheduler.step()
다음 autocast
는 mixed-precision 학습을 위한 context manager이다.
위 코드에서와 같이 필요한 부분만 half-precision으로 실행할 수 있도록 해준다.
수동으로 연산마다 어떤 precision을 적용할지 일일이 지정하지 않아도 되므로 매우 편리하다.
* 예를 들어 linear, convolutional layers 같은 경우 half-precision으로 연산하면 큰 오차 없이 성능 개선이 가능하다.
그러나 reduction 같은 연산의 경우, single-precision 연산이 필요할 수 있다.
당연한 말이지만, mixed-precision으로 학습했다고 해서 모델을 저장했을 때 파일 크기가 줄어들지는 않는다.
master weights가 저장되기 때문이다.
만약 모델 파일 용량을 줄이고 싶다면 양자화(Quantization)를 해야 한다.
양자화란 weights의 precision을 줄이는 것과는 별개의 프로세스이며, 주로 학습 후에 수행된다.
양자화를 수행할 경우 단지 모델 파일 크기만 줄어드는 것이 아니라, 추론 시 latency를 개선할 수 있다.
이에 대해서는 다른 글에서 다루기로 한다.
'공부하며 성장하기 > 인공지능 AI' 카테고리의 다른 글
Object Detection에서 mAP_0.5와 mAP_0.5:0.95의 의미 (0) | 2023.12.14 |
---|---|
DeeplabV3+ 모델 전이 학습(transfer learning) 쉽게 구현하기 (0) | 2023.06.15 |
K-means Clustering (0) | 2023.05.21 |
Lasso regression (0) | 2023.05.21 |
Random Forest (0) | 2023.05.20 |