AI
efficientnet + amp BCE
cepiloth
2023. 7. 6. 09:25
728x90
반응형
EfficientNet 모델 구성
from torch import nn
import timm
import torch
class EffNet(nn.Module):
    """timm backbone whose classification head is replaced by a new linear layer.

    Args:
        backbone: Model name handed to ``timm.create_model``.
        n_out: Number of output units of the replacement linear head.
        is_sigmoid: When true, ``forward`` applies an element-wise sigmoid to
            the head output; otherwise the raw head output is returned.
        pretrained: Forwarded to ``timm.create_model``. Defaults to ``True``,
            which was the previously hard-coded value.
    """

    def __init__(self, backbone: str, n_out: int, is_sigmoid: bool,
                 pretrained: bool = True) -> None:
        super().__init__()
        self.model = timm.create_model(model_name=backbone, pretrained=pretrained)
        # Replace the stock head with one sized for this task. This relies on
        # the backbone exposing its head as `.classifier`.
        self.model.classifier = nn.Linear(self.model.classifier.in_features, n_out)
        self.is_sigmoid = is_sigmoid

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone and optionally squash its output with a sigmoid."""
        x = self.model(x)
        if self.is_sigmoid:
            # Functional call: no need to construct an nn.Sigmoid module on
            # every forward pass.
            x = torch.sigmoid(x)
        return x
NVIDIA apex(amp) 설치
# Fetch the NVIDIA apex sources, which provide the `apex.amp` module.
git clone https://www.github.com/nvidia/apex
# Build and install from inside the cloned checkout.
cd apex
# NOTE(review): invoking `setup.py install` directly is deprecated by recent
# setuptools; check the apex README for the currently recommended pip command.
python setup.py install
amp 구성
# Opt-in mixed precision through NVIDIA apex, switched on by
# config['TRAINER']['amp'].
# NOTE(review): apex.amp is marked deprecated upstream in favour of the
# built-in torch.cuda.amp; consider migrating.
if config['TRAINER']['amp']:
    from apex import amp

    # Force torch.sigmoid to run in float32 under amp. Registration has to
    # happen before amp.initialize() so the patch is in place when amp sets up.
    # Presumably needed because the BCE loss is computed on sigmoid outputs;
    # confirm against the training loop.
    amp.register_float_function(torch, 'sigmoid')
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
728x90
반응형