U-Net과 유사한 아키텍처를 사용한 이미지 세그멘테이션 (Image segmentation with a U-Net-like architecture)

목차

  1. 데이터 다운로드
  2. 입력 이미지와 타겟 세그멘테이션 마스크 경로 준비
  3. 하나의 입력 이미지와 해당 세그멘테이션 마스크는 어떻게 생겼나요?
  4. 데이터 배치를 로드하고 벡터화하기 위한 데이터 세트 준비
  5. U-Net Xception 스타일 모델 준비
  6. 검증 분할을 별도로 설정
  7. 모델 트레이닝
  8. 예측 시각화

저자: fchollet
생성일: 2019/03/20
최종편집일: 2020/04/20
설명: Oxford Pets 데이터세트를 사용하여 처음부터 트레이닝된 이미지 세그멘테이션 모델입니다.

Colab에서 보기 · GitHub 소스

ⓘ 이 예제는 Keras 3을 사용합니다.


데이터 다운로드

# Download the Oxford-IIIT Pet images and trimap annotations, once each.
# Alternative URLs for the same-named archives, if the mirror below is unavailable:
#   https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz
#   https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz
# -f: fail on HTTP errors instead of saving an error page as the archive.
# -L: follow redirects.
!curl -fLO https://thor.robots.ox.ac.uk/datasets/pets/images.tar.gz
!curl -fLO https://thor.robots.ox.ac.uk/datasets/pets/annotations.tar.gz

!tar -xf images.tar.gz
!tar -xf annotations.tar.gz
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  755M  100  755M    0     0  21.3M      0  0:00:35  0:00:35 --:--:-- 22.2M
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 18.2M  100 18.2M    0     0  7977k      0  0:00:02  0:00:02 --:--:-- 7974k

입력 이미지와 타겟 세그멘테이션 마스크 경로 준비

import os

input_dir = "images/"
target_dir = "annotations/trimaps/"
img_size = (160, 160)  # (height, width) that every image and mask is resized to
num_classes = 3  # trimap labels 1, 2, 3 (shifted to 0, 1, 2 in get_dataset)
batch_size = 32


def _list_files(directory, extension):
    """Return the sorted paths in `directory` ending with `extension`, skipping dotfiles."""
    return sorted(
        os.path.join(directory, fname)
        for fname in os.listdir(directory)
        if fname.endswith(extension) and not fname.startswith(".")
    )


input_img_paths = _list_files(input_dir, ".jpg")
target_img_paths = _list_files(target_dir, ".png")

# Inputs and masks are paired purely by position after sorting. Verify that both
# lists describe the same samples; otherwise every later pair would be silently
# misaligned.
input_stems = [os.path.splitext(os.path.basename(p))[0] for p in input_img_paths]
target_stems = [os.path.splitext(os.path.basename(p))[0] for p in target_img_paths]
if input_stems != target_stems:
    raise ValueError(
        f"Input images ({len(input_img_paths)}) and target masks "
        f"({len(target_img_paths)}) do not pair up one-to-one by file name."
    )

print("Number of samples:", len(input_img_paths))

for input_path, target_path in zip(input_img_paths[:10], target_img_paths[:10]):
    print(input_path, "|", target_path)
Number of samples: 7390
images/Abyssinian_1.jpg | annotations/trimaps/Abyssinian_1.png
images/Abyssinian_10.jpg | annotations/trimaps/Abyssinian_10.png
images/Abyssinian_100.jpg | annotations/trimaps/Abyssinian_100.png
images/Abyssinian_101.jpg | annotations/trimaps/Abyssinian_101.png
images/Abyssinian_102.jpg | annotations/trimaps/Abyssinian_102.png
images/Abyssinian_103.jpg | annotations/trimaps/Abyssinian_103.png
images/Abyssinian_104.jpg | annotations/trimaps/Abyssinian_104.png
images/Abyssinian_105.jpg | annotations/trimaps/Abyssinian_105.png
images/Abyssinian_106.jpg | annotations/trimaps/Abyssinian_106.png
images/Abyssinian_107.jpg | annotations/trimaps/Abyssinian_107.png

하나의 입력 이미지와 해당 세그멘테이션 마스크는 어떻게 생겼나요?

from IPython.display import Image, display
from keras.utils import load_img
from PIL import ImageOps

sample_idx = 9  # 0-based index, i.e. the 10th image in sorted order

# Display the input image.
display(Image(filename=input_img_paths[sample_idx]))

# Display an auto-contrasted version of the matching target mask (one category per pixel).
# The raw labels are the small values 1, 2, 3, so autocontrast stretches them across
# the gray range to make them visible.
img = ImageOps.autocontrast(load_img(target_img_paths[sample_idx]))
display(img)

[출력 이미지: 입력 사진 (JPEG)]

[출력 이미지: 자동 대비가 적용된 타겟 세그멘테이션 마스크 (PNG)]


데이터 배치를 로드하고 벡터화하기 위한 데이터 세트 준비

import keras
import numpy as np
from tensorflow import data as tf_data
from tensorflow import image as tf_image
from tensorflow import io as tf_io


def get_dataset(
    batch_size,
    img_size,
    input_img_paths,
    target_img_paths,
    max_dataset_len=None,
):
    """Build a batched ``tf.data`` pipeline of (image, mask) pairs.

    Args:
        batch_size: Number of samples per batch.
        img_size: ``(height, width)`` that both images and masks are resized to.
        input_img_paths: Paths of the input images, aligned by index with
            ``target_img_paths``.
        target_img_paths: Paths of the trimap mask images.
        max_dataset_len: If truthy, only the first ``max_dataset_len`` pairs
            are used (handy for faster debugging runs).

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, masks)`` batches. Images have
        3 channels and are float32. Masks have 1 channel and labels 0, 1, 2.
        The dataset is not shuffled.
    """

    def load_img_masks(input_img_path, target_img_path):
        input_img = tf_io.read_file(input_img_path)
        # decode_png is used here for the .jpg inputs as well; TF's decode ops
        # also accept JPEG data.
        input_img = tf_io.decode_png(input_img, channels=3)
        input_img = tf_image.resize(input_img, img_size)
        # resize() (bilinear) already returns float32, so this conversion does
        # NOT rescale to [0, 1]: pixel values stay in [0, 255].
        input_img = tf_image.convert_image_dtype(input_img, "float32")

        target_img = tf_io.read_file(target_img_path)
        target_img = tf_io.decode_png(target_img, channels=1)
        # Nearest-neighbour keeps labels discrete (no interpolated class values).
        target_img = tf_image.resize(target_img, img_size, method="nearest")
        target_img = tf_image.convert_image_dtype(target_img, "uint8")

        # The raw labels are 1, 2, 3; subtracting 1 gives 0, 1, 2, the integer
        # class indices expected by sparse_categorical_crossentropy.
        target_img -= 1
        return input_img, target_img

    # Optionally cap the dataset size for faster debugging.
    if max_dataset_len:
        input_img_paths = input_img_paths[:max_dataset_len]
        target_img_paths = target_img_paths[:max_dataset_len]
    dataset = tf_data.Dataset.from_tensor_slices((input_img_paths, target_img_paths))
    dataset = dataset.map(load_img_masks, num_parallel_calls=tf_data.AUTOTUNE)
    # Prefetch so decoding of the next batch overlaps with work on the current one.
    return dataset.batch(batch_size).prefetch(tf_data.AUTOTUNE)

U-Net Xception 스타일 모델 준비

from keras import layers


def get_model(img_size, num_classes):
    """Build a U-Net-like encoder/decoder with Xception-style residual blocks.

    Args:
        img_size: ``(height, width)`` tuple; the model input shape is
            ``img_size + (3,)``.
        num_classes: Number of per-pixel classes; the final layer has this many
            softmax channels.

    Returns:
        An uncompiled ``keras.Model``. The encoder downsamples by 16 in total
        (one stride-2 entry conv plus three stride-2 poolings). The decoder
        upsamples by 16 (four ``UpSampling2D(2)``). The output spatial size
        therefore matches the input when height and width are divisible by 16,
        as with 160x160.
    """
    inputs = keras.Input(shape=img_size + (3,))

    ### [First half of the network: downsampling inputs] ###

    # Entry block: the stride-2 conv halves the spatial resolution.
    x = layers.Conv2D(32, 3, strides=2, padding="same")(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    # Blocks 1, 2, 3 are identical apart from the feature depth.
    for filters in [64, 128, 256]:
        # On the first iteration this ReLU directly follows the entry block's
        # ReLU, so it has no effect there.
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        # Halve the spatial size.
        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project residual: 1x1 stride-2 conv to match x's depth and size.
        residual = layers.Conv2D(filters, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    ### [Second half of the network: upsampling inputs] ###

    for filters in [256, 128, 64, 32]:
        # Stride-1 transposed convs keep the spatial size; the actual
        # upsampling is done by UpSampling2D below.
        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.UpSampling2D(2)(x)

        # Project residual: upsample, then 1x1 conv to match x's depth.
        residual = layers.UpSampling2D(2)(previous_block_activation)
        residual = layers.Conv2D(filters, 1, padding="same")(residual)
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    # Add a per-pixel classification layer
    outputs = layers.Conv2D(num_classes, 3, activation="softmax", padding="same")(x)

    # Define the model
    model = keras.Model(inputs, outputs)
    return model

# Build the segmentation model and print a layer-by-layer summary.
model = get_model(img_size=img_size, num_classes=num_classes)
model.summary()
Model: "functional_1"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)        ┃ Output Shape      ┃ Param # ┃ Connected to         ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│ input_layer         │ (None, 160, 160,  │       0 │ -                    │
│ (InputLayer)        │ 3)                │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d (Conv2D)     │ (None, 80, 80,    │     896 │ input_layer[0][0]    │
│                     │ 32)               │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalization │ (None, 80, 80,    │     128 │ conv2d[0][0]         │
│ (BatchNormalizatio… │ 32)               │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation          │ (None, 80, 80,    │       0 │ batch_normalization… │
│ (Activation)        │ 32)               │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_1        │ (None, 80, 80,    │       0 │ activation[0][0]     │
│ (Activation)        │ 32)               │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ separable_conv2d    │ (None, 80, 80,    │   2,400 │ activation_1[0][0]   │
│ (SeparableConv2D)   │ 64)               │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 80, 80,    │     256 │ separable_conv2d[0]… │
│ (BatchNormalizatio… │ 64)               │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_2        │ (None, 80, 80,    │       0 │ batch_normalization… │
│ (Activation)        │ 64)               │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ separable_conv2d_1  │ (None, 80, 80,    │   4,736 │ activation_2[0][0]   │
│ (SeparableConv2D)   │ 64)               │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 80, 80,    │     256 │ separable_conv2d_1[… │
│ (BatchNormalizatio… │ 64)               │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ max_pooling2d       │ (None, 40, 40,    │       0 │ batch_normalization… │
│ (MaxPooling2D)      │ 64)               │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_1 (Conv2D)   │ (None, 40, 40,    │   2,112 │ activation[0][0]     │
│                     │ 64)               │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ add (Add)           │ (None, 40, 40,    │       0 │ max_pooling2d[0][0], │
│                     │ 64)               │         │ conv2d_1[0][0]       │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_3        │ (None, 40, 40,    │       0 │ add[0][0]            │
│ (Activation)        │ 64)               │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ separable_conv2d_2  │ (None, 40, 40,    │   8,896 │ activation_3[0][0]   │
│ (SeparableConv2D)   │ 128)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 40, 40,    │     512 │ separable_conv2d_2[… │
│ (BatchNormalizatio… │ 128)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_4        │ (None, 40, 40,    │       0 │ batch_normalization… │
│ (Activation)        │ 128)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ separable_conv2d_3  │ (None, 40, 40,    │  17,664 │ activation_4[0][0]   │
│ (SeparableConv2D)   │ 128)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 40, 40,    │     512 │ separable_conv2d_3[… │
│ (BatchNormalizatio… │ 128)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ max_pooling2d_1     │ (None, 20, 20,    │       0 │ batch_normalization… │
│ (MaxPooling2D)      │ 128)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_2 (Conv2D)   │ (None, 20, 20,    │   8,320 │ add[0][0]            │
│                     │ 128)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ add_1 (Add)         │ (None, 20, 20,    │       0 │ max_pooling2d_1[0][… │
│                     │ 128)              │         │ conv2d_2[0][0]       │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_5        │ (None, 20, 20,    │       0 │ add_1[0][0]          │
│ (Activation)        │ 128)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ separable_conv2d_4  │ (None, 20, 20,    │  34,176 │ activation_5[0][0]   │
│ (SeparableConv2D)   │ 256)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 20, 20,    │   1,024 │ separable_conv2d_4[… │
│ (BatchNormalizatio… │ 256)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_6        │ (None, 20, 20,    │       0 │ batch_normalization… │
│ (Activation)        │ 256)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ separable_conv2d_5  │ (None, 20, 20,    │  68,096 │ activation_6[0][0]   │
│ (SeparableConv2D)   │ 256)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 20, 20,    │   1,024 │ separable_conv2d_5[… │
│ (BatchNormalizatio… │ 256)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ max_pooling2d_2     │ (None, 10, 10,    │       0 │ batch_normalization… │
│ (MaxPooling2D)      │ 256)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_3 (Conv2D)   │ (None, 10, 10,    │  33,024 │ add_1[0][0]          │
│                     │ 256)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ add_2 (Add)         │ (None, 10, 10,    │       0 │ max_pooling2d_2[0][… │
│                     │ 256)              │         │ conv2d_3[0][0]       │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_7        │ (None, 10, 10,    │       0 │ add_2[0][0]          │
│ (Activation)        │ 256)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_transpose    │ (None, 10, 10,    │ 590,080 │ activation_7[0][0]   │
│ (Conv2DTranspose)   │ 256)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 10, 10,    │   1,024 │ conv2d_transpose[0]… │
│ (BatchNormalizatio… │ 256)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_8        │ (None, 10, 10,    │       0 │ batch_normalization… │
│ (Activation)        │ 256)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_transpose_1  │ (None, 10, 10,    │ 590,080 │ activation_8[0][0]   │
│ (Conv2DTranspose)   │ 256)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 10, 10,    │   1,024 │ conv2d_transpose_1[… │
│ (BatchNormalizatio… │ 256)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ up_sampling2d_1     │ (None, 20, 20,    │       0 │ add_2[0][0]          │
│ (UpSampling2D)      │ 256)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ up_sampling2d       │ (None, 20, 20,    │       0 │ batch_normalization… │
│ (UpSampling2D)      │ 256)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_4 (Conv2D)   │ (None, 20, 20,    │  65,792 │ up_sampling2d_1[0][… │
│                     │ 256)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ add_3 (Add)         │ (None, 20, 20,    │       0 │ up_sampling2d[0][0], │
│                     │ 256)              │         │ conv2d_4[0][0]       │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_9        │ (None, 20, 20,    │       0 │ add_3[0][0]          │
│ (Activation)        │ 256)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_transpose_2  │ (None, 20, 20,    │ 295,040 │ activation_9[0][0]   │
│ (Conv2DTranspose)   │ 128)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 20, 20,    │     512 │ conv2d_transpose_2[… │
│ (BatchNormalizatio… │ 128)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_10       │ (None, 20, 20,    │       0 │ batch_normalization… │
│ (Activation)        │ 128)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_transpose_3  │ (None, 20, 20,    │ 147,584 │ activation_10[0][0]  │
│ (Conv2DTranspose)   │ 128)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 20, 20,    │     512 │ conv2d_transpose_3[… │
│ (BatchNormalizatio… │ 128)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ up_sampling2d_3     │ (None, 40, 40,    │       0 │ add_3[0][0]          │
│ (UpSampling2D)      │ 256)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ up_sampling2d_2     │ (None, 40, 40,    │       0 │ batch_normalization… │
│ (UpSampling2D)      │ 128)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_5 (Conv2D)   │ (None, 40, 40,    │  32,896 │ up_sampling2d_3[0][… │
│                     │ 128)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ add_4 (Add)         │ (None, 40, 40,    │       0 │ up_sampling2d_2[0][… │
│                     │ 128)              │         │ conv2d_5[0][0]       │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_11       │ (None, 40, 40,    │       0 │ add_4[0][0]          │
│ (Activation)        │ 128)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_transpose_4  │ (None, 40, 40,    │  73,792 │ activation_11[0][0]  │
│ (Conv2DTranspose)   │ 64)               │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 40, 40,    │     256 │ conv2d_transpose_4[… │
│ (BatchNormalizatio… │ 64)               │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_12       │ (None, 40, 40,    │       0 │ batch_normalization… │
│ (Activation)        │ 64)               │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_transpose_5  │ (None, 40, 40,    │  36,928 │ activation_12[0][0]  │
│ (Conv2DTranspose)   │ 64)               │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 40, 40,    │     256 │ conv2d_transpose_5[… │
│ (BatchNormalizatio… │ 64)               │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ up_sampling2d_5     │ (None, 80, 80,    │       0 │ add_4[0][0]          │
│ (UpSampling2D)      │ 128)              │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ up_sampling2d_4     │ (None, 80, 80,    │       0 │ batch_normalization… │
│ (UpSampling2D)      │ 64)               │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_6 (Conv2D)   │ (None, 80, 80,    │   8,256 │ up_sampling2d_5[0][… │
│                     │ 64)               │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ add_5 (Add)         │ (None, 80, 80,    │       0 │ up_sampling2d_4[0][… │
│                     │ 64)               │         │ conv2d_6[0][0]       │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_13       │ (None, 80, 80,    │       0 │ add_5[0][0]          │
│ (Activation)        │ 64)               │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_transpose_6  │ (None, 80, 80,    │  18,464 │ activation_13[0][0]  │
│ (Conv2DTranspose)   │ 32)               │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 80, 80,    │     128 │ conv2d_transpose_6[… │
│ (BatchNormalizatio… │ 32)               │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_14       │ (None, 80, 80,    │       0 │ batch_normalization… │
│ (Activation)        │ 32)               │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_transpose_7  │ (None, 80, 80,    │   9,248 │ activation_14[0][0]  │
│ (Conv2DTranspose)   │ 32)               │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 80, 80,    │     128 │ conv2d_transpose_7[… │
│ (BatchNormalizatio… │ 32)               │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ up_sampling2d_7     │ (None, 160, 160,  │       0 │ add_5[0][0]          │
│ (UpSampling2D)      │ 64)               │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ up_sampling2d_6     │ (None, 160, 160,  │       0 │ batch_normalization… │
│ (UpSampling2D)      │ 32)               │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_7 (Conv2D)   │ (None, 160, 160,  │   2,080 │ up_sampling2d_7[0][… │
│                     │ 32)               │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ add_6 (Add)         │ (None, 160, 160,  │       0 │ up_sampling2d_6[0][… │
│                     │ 32)               │         │ conv2d_7[0][0]       │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv2d_8 (Conv2D)   │ (None, 160, 160,  │     867 │ add_6[0][0]          │
│                     │ 3)                │         │                      │
└─────────────────────┴───────────────────┴─────────┴──────────────────────┘
 Total params: 2,058,979 (7.85 MB)
 Trainable params: 2,055,203 (7.84 MB)
 Non-trainable params: 3,776 (14.75 KB)

검증 분할을 별도로 설정

import random

# Split the image paths into training and validation sets.
val_samples = 1000
# Shuffle (input, target) pairs together so they cannot drift out of alignment.
# Shuffle a copy rather than the global lists in place, so that re-running this
# cell does not re-shuffle an already shuffled list and silently change the
# split. For equal-length lists this yields the same order as shuffling each
# list separately with Random(1337).
paired_paths = list(zip(input_img_paths, target_img_paths))
random.Random(1337).shuffle(paired_paths)
shuffled_input_paths = [inp for inp, _ in paired_paths]
shuffled_target_paths = [tgt for _, tgt in paired_paths]
train_input_img_paths = shuffled_input_paths[:-val_samples]
train_target_img_paths = shuffled_target_paths[:-val_samples]
val_input_img_paths = shuffled_input_paths[-val_samples:]
val_target_img_paths = shuffled_target_paths[-val_samples:]

# Instantiate a dataset for each split.
# `max_dataset_len` limits the number of training files for faster epochs;
# remove that argument to train on the full dataset.
train_dataset = get_dataset(
    batch_size,
    img_size,
    train_input_img_paths,
    train_target_img_paths,
    max_dataset_len=1000,
)
valid_dataset = get_dataset(
    batch_size, img_size, val_input_img_paths, val_target_img_paths
)

모델 트레이닝

# Configure the model for training.
# The targets are integer class indices rather than one-hot vectors, so we use
# the "sparse" version of categorical_crossentropy.
optimizer = keras.optimizers.Adam(1e-4)
model.compile(optimizer=optimizer, loss="sparse_categorical_crossentropy")

# Keep only the checkpoint with the best monitored (by default validation) loss.
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    "oxford_segmentation.keras", save_best_only=True
)
callbacks = [checkpoint_cb]

# Train the model, running validation at the end of every epoch.
epochs = 50
model.fit(
    train_dataset,
    validation_data=valid_dataset,
    epochs=epochs,
    callbacks=callbacks,
    verbose=2,
)
Epoch 1/50

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1700414690.172044 2226172 device_compiler.h:187] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 62s - 2s/step - loss: 1.6363 - val_loss: 2.2226
Epoch 2/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 94ms/step - loss: 0.9223 - val_loss: 1.8273
Epoch 3/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 82ms/step - loss: 0.7894 - val_loss: 2.0044
Epoch 4/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 83ms/step - loss: 0.7174 - val_loss: 2.3480
Epoch 5/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 82ms/step - loss: 0.6695 - val_loss: 2.7528
Epoch 6/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 83ms/step - loss: 0.6325 - val_loss: 3.1453
Epoch 7/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 84ms/step - loss: 0.6012 - val_loss: 3.5611
Epoch 8/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 87ms/step - loss: 0.5730 - val_loss: 4.0003
Epoch 9/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 85ms/step - loss: 0.5466 - val_loss: 4.4798
Epoch 10/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 86ms/step - loss: 0.5210 - val_loss: 5.0245
Epoch 11/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 87ms/step - loss: 0.4958 - val_loss: 5.5950
Epoch 12/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 87ms/step - loss: 0.4706 - val_loss: 6.1534
Epoch 13/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 85ms/step - loss: 0.4453 - val_loss: 6.6107
Epoch 14/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 83ms/step - loss: 0.4202 - val_loss: 6.8010
Epoch 15/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 84ms/step - loss: 0.3956 - val_loss: 6.6751
Epoch 16/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 83ms/step - loss: 0.3721 - val_loss: 6.0800
Epoch 17/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 84ms/step - loss: 0.3506 - val_loss: 5.1820
Epoch 18/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 82ms/step - loss: 0.3329 - val_loss: 4.0350
Epoch 19/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 4s - 114ms/step - loss: 0.3216 - val_loss: 3.0513
Epoch 20/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 94ms/step - loss: 0.3595 - val_loss: 2.2567
Epoch 21/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 100ms/step - loss: 0.4417 - val_loss: 1.5873
Epoch 22/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 101ms/step - loss: 0.3531 - val_loss: 1.5798
Epoch 23/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 96ms/step - loss: 0.3353 - val_loss: 1.5525
Epoch 24/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 95ms/step - loss: 0.3392 - val_loss: 1.4625
Epoch 25/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 95ms/step - loss: 0.3596 - val_loss: 0.8867
Epoch 26/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 94ms/step - loss: 0.3528 - val_loss: 0.8021
Epoch 27/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 92ms/step - loss: 0.3237 - val_loss: 0.7986
Epoch 28/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 89ms/step - loss: 0.3198 - val_loss: 0.8533
Epoch 29/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 84ms/step - loss: 0.3272 - val_loss: 1.0588
Epoch 30/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 88ms/step - loss: 0.3164 - val_loss: 1.1889
Epoch 31/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 85ms/step - loss: 0.2987 - val_loss: 0.9518
Epoch 32/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 87ms/step - loss: 0.2749 - val_loss: 0.9011
Epoch 33/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 84ms/step - loss: 0.2595 - val_loss: 0.8872
Epoch 34/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 87ms/step - loss: 0.2552 - val_loss: 1.0221
Epoch 35/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 82ms/step - loss: 0.2628 - val_loss: 1.1553
Epoch 36/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 85ms/step - loss: 0.2788 - val_loss: 2.1549
Epoch 37/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 94ms/step - loss: 0.2870 - val_loss: 1.6282
Epoch 38/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 89ms/step - loss: 0.2702 - val_loss: 1.3201
Epoch 39/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 91ms/step - loss: 0.2569 - val_loss: 1.2364
Epoch 40/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 106ms/step - loss: 0.2523 - val_loss: 1.3673
Epoch 41/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 86ms/step - loss: 0.2570 - val_loss: 1.3999
Epoch 42/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 87ms/step - loss: 0.2680 - val_loss: 0.9976
Epoch 43/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 83ms/step - loss: 0.2558 - val_loss: 1.0209
Epoch 44/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 85ms/step - loss: 0.2403 - val_loss: 1.3271
Epoch 45/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 83ms/step - loss: 0.2414 - val_loss: 1.1993
Epoch 46/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 84ms/step - loss: 0.2516 - val_loss: 1.0532
Epoch 47/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 83ms/step - loss: 0.2695 - val_loss: 1.1183
Epoch 48/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 87ms/step - loss: 0.2555 - val_loss: 1.0432
Epoch 49/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 82ms/step - loss: 0.2290 - val_loss: 0.9444
Epoch 50/50

Corrupt JPEG data: 240 extraneous bytes before marker 0xd9

32/32 - 3s - 83ms/step - loss: 0.1994 - val_loss: 1.2182

<keras.src.callbacks.history.History at 0x7fe01842dab0>

예측 시각화


# Generate predictions for every image in the validation set.

val_dataset = get_dataset(
    batch_size,
    img_size,
    val_input_img_paths,
    val_target_img_paths,
)
val_preds = model.predict(val_dataset)


def display_mask(i):
    """Display the model's predicted mask for validation sample ``i``.

    Each pixel's class is the argmax over the prediction channels. The
    resulting label map is auto-contrasted so the classes are visible.
    """
    predicted_labels = val_preds[i].argmax(axis=-1)
    mask_img = keras.utils.array_to_img(predicted_labels[..., np.newaxis])
    display(ImageOps.autocontrast(mask_img))


# Show results for the validation image at index 10 (0-based).
i = 10

# Display the input image at its original resolution.
display(Image(filename=val_input_img_paths[i]))

# Display the auto-contrasted ground-truth target mask.
img = ImageOps.autocontrast(load_img(val_target_img_paths[i]))
display(img)

# Display the mask predicted by our model. The model only ever sees inputs
# resized to img_size (160x160), so the predicted mask is 160x160 even though
# the images above are shown at their original size.
display_mask(i)
 32/32 ━━━━━━━━━━━━━━━━━━━━ 5s 100ms/step

[출력 이미지: 입력 사진 (JPEG)]

[출력 이미지: 자동 대비가 적용된 정답(ground-truth) 마스크 (PNG)]

[출력 이미지: 모델이 예측한 마스크 (PNG)]