Tensorflow 102 - 이미지 분류(CNN)

다섯번째 딥러닝 완성 - LeNet

수업소개

LeNet-5 모델을 완성하고, CIFAR10 이미지 학습을 진행합니다. 

 

강의

실습

소스코드

colab |  backend.ai

###########################
# 라이브러리 사용
import tensorflow as tf
import pandas as pd

###########################
# 데이터를 준비합니다. 
(독립, 종속), _ = tf.keras.datasets.cifar10.load_data()
종속 = pd.get_dummies(종속.reshape(50000))
print(독립.shape, 종속.shape)

###########################
# 모델을 완성합니다. 
X = tf.keras.layers.Input(shape=[32, 32, 3])

H = tf.keras.layers.Conv2D(6, kernel_size=5, activation='swish')(X)
H = tf.keras.layers.MaxPool2D()(H)

H = tf.keras.layers.Conv2D(16, kernel_size=5, activation='swish')(H)
H = tf.keras.layers.MaxPool2D()(H)

H = tf.keras.layers.Flatten()(H)
H = tf.keras.layers.Dense(120, activation='swish')(H)
H = tf.keras.layers.Dense(84, activation='swish')(H)
Y = tf.keras.layers.Dense(10, activation='softmax')(H)

model = tf.keras.models.Model(X, Y)
model.compile(loss='categorical_crossentropy', metrics='accuracy')

###########################
# 모델을 학습하고
model.fit(독립, 종속, epochs=10)

###########################
# 모델을 이용합니다. 
pred = model.predict(독립[0:5])
pd.DataFrame(pred).round(2)

# 정답 확인
종속[0:5]

# 모델 확인
model.summary()

 

댓글

댓글 본문
  1. 솔나무
  2. songji
  3. Hotbrains
    완료~ 감사합니다.
  4. 이덕규
    완료!
  5. 정효빈
    완료했습니다!! batch도 적용해보도록할께요!!
  6. 이숙번
    dropout 기법만으로도 충분히 좋은 결과를 얻을 수는 있습니다.
    하지만, CiFAR10 이미지에서 90% 이상의 결과를 위해서는
    data augmentation을 포함하여 추가적인 regularization 기법들을 활용하셔야 할 거에요.
    대화보기
    • 답글 감사합니다!
      혹시나 해서 모델 피팅시 batch_size 값만 150으로 줘봤는데
      50 epochs 부터 99% 를 넘는 accuracy가 나오네요.

      그런데 이 정확도는 trainset 에 대해서만 측정된 정확도였네요..
      trainset과 testset을 분리한 후 testset에 대해 evalute한 결과 처참한 정확도..(50%대)가 나오네요
      오버피팅의 문제가 역시 있었나 봅니다.

      cifar10의 경우 data augumentation 이 정확도 향상을 위해 꼭필요한가요?
      아니면 dropout 같은 모델의 기법만으로도 90% 이상의 정확도가 나오는지 궁금합니다
      대화보기
      • 이선비
        현재 상황은 학습이 충분히 되지 않는 모델이 완성된 것입니다. underfit 상태라고 할 수 있죠.
        underfit인 이유는 여러가지가 있으니, 이유를 따로 공부해 보시면 좋을 것 같고요.

        underfit의 쉬운 해결을 위해 101 수업의 '모델을 위한 팁'을 수업을 준비해 두었습니다.
        BatchNormalization 기법을 사용해 보세요. 훨씬 학습이 잘 진행되는 것이 보일 겁니다.
        normalization 기법도 여러가지가 있는데, underfit을 해결하기 위한 여러 방법들중에 매우 효과적인 방법이니
        이 또한 추가로 공부해 보시길 권장드립니다. :)
        대화보기
        • 이선비
          95%가 한계치가 아닐거에요.
          어느 정도 모델 학습을 하실 수 있으시다면,
          "학습 데이터"에 대해서 100% 피팅을 목표로
          여러가지 기법을 공부하고 모델을 만들어 보세요.

          준비한 데이터에 대해 100 피팅을 할 수 있는 능력이 있다면,
          여러 overfitting을 방지하는 기법을 적용하여 쓸만한 모델을 완성할 수 있게 될 것입니다.

          :)
          대화보기
          • nann
            완료
          • cifar10 의 경우, 이전에 101에서 팁으로 알려주신 BatchNormalization을 적용해서 Activation Layer와 Dense Layer를 분리하는 것만으로도 정확도가 유의미하게 올라갑니다(epochs=10 기준 정학도 약 80%). 정확도를 더 올리고 싶다면 Dropout을 시킨다거나, 이미지 전처리를 더 시도하는 등의 노력이 필요할 것 같네요. 공부가 더 필요하네요 !

            ->별다른 추가 없이도 epochs 20만 해도 90퍼센트에 근접한 정확도가나오네요.. 참 똑똑하다..
            -> epochs 50-60 정도에서 정확도 95퍼센트가 한계치인듯합니다.
          • 고고고고
            완료
          • 소야
            @.@ 신기하네요.
          • hoddigi
            완료
          • 여어엉
            완료
          • 제사마
            Done
          • John
            완료
          • 과거로의여행
            아~~~ 또 졸린다~~~ ㅠㅠ.
            그래도 애해는 한 것 같은데...
            저녁에 또 들어보면 좀 더 이해가 되겠지요~~~ㅇ
            아자!아자!화이팅!!!
          • VIBOT
            ok
          • 완료
          • CrashOverride
            4일차 완료
          • briliant6424
            21/01/10 완료
          • 행여
            감사합니다~!♡
          • 유니엘
            완료.
            좋은 강의 감사합니다.
          • ukmadang
            대단합니다 완료!
          • younghwani
            댓글보고 BatchNormalization 적용해봤는데 훨씬 학습이 잘 되는 것 같아요!!
          • 강해리
            와 max polling이랑 네트워크 보고 어떻게 쌓는지 세세하게 알려주시니까 너무 이해하기 편하네요!!! ㅠㅠ 감사합니다 꿀같은 강좌에요!
          • 이선비
            텐서플로우 101 수업의 '모델을 위한 팁(https://opentutorials.org......242)' 내용을 참고해 보세요. batchNormalization 만 적용해도. 상당히 학습을 할 수 있고, 모델 사이즈까지 늘리면 (필터 수 조정, 노드 수 조정) 학습이 잘 될 겁니다. :)
          • noahhan
            약 130번정도 돌리면 0.65 ~ 0.68? 그정도 accuracy 나오네요
            어.......? 근데 190번째 돌리니 오히려 0.58대로 감소 ㅠㅠ
          • noahhan
            확실히 rgb 값까지 첨부된 파일이다보니 학습이 다소 느리네요
          버전 관리
          이선비
          현재 버전
          선택 버전
          graphittie 자세히 보기