- 원본 링크 : https://keras.io/guides/writing_a_custom_training_loop_in_jax/
- 최종 수정일 : 2024-09-18
JAX에서 처음부터 트레이닝 루프 작성하기 (Writing a training loop from scratch in JAX)
목차
저자: fchollet
생성일: 2023/06/25
최종편집일: 2023/06/25
설명: JAX에서 낮은 레벨의 트레이닝 및 평가 루프 작성하기
셋업
import os
# 이 가이드는 jax 백엔드에서만 실행할 수 있습니다.
os.environ["KERAS_BACKEND"] = "jax"
import jax
# tf.data를 사용하기 위해 TF를 임포트합니다.
import tensorflow as tf
import keras
import numpy as np
소개
Keras는 기본 트레이닝 및 평가 루프인 fit()
과 evaluate()
를 제공합니다. 이러한 사용 방법은 빌트인 메서드를 사용한 트레이닝 및 평가 가이드에서 다룹니다.
모델의 학습 알고리즘을 커스터마이즈하면서도 fit()
의 편리함을 활용하고 싶다면 (예를 들어, fit()
을 사용해 GAN을 트레이닝하려는 경우), Model
클래스를 서브클래싱하고, fit()
동안 반복적으로 호출되는 자체 train_step()
메서드를 구현할 수 있습니다.
이제, 트레이닝 및 평가에 대해 매우 낮은 레벨의 제어를 원한다면, 처음부터 직접 트레이닝 및 평가 루프를 작성해야 합니다. 이 가이드는 그것에 관한 것입니다.
첫 번째 엔드 투 엔드 예제
커스텀 트레이닝 루프를 작성하려면, 다음이 필요합니다:
- 트레이닝할 모델.
- 옵티마이저.
keras.optimizers
의 옵티마이저를 사용하거나,optax
패키지에서 사용할 수 있습니다. - 손실 함수.
- 데이터셋. JAX 생태계의 표준은
tf.data
를 통해 데이터를 로드하는 것이므로, 이를 사용할 것입니다.
이제 하나씩 설정해 보겠습니다.
먼저, 모델과 MNIST 데이터셋을 가져옵니다:
def get_model():
inputs = keras.Input(shape=(784,), name="digits")
x1 = keras.layers.Dense(64, activation="relu")(inputs)
x2 = keras.layers.Dense(64, activation="relu")(x1)
outputs = keras.layers.Dense(10, name="predictions")(x2)
model = keras.Model(inputs=inputs, outputs=outputs)
return model
model = get_model()
# 트레이닝 데이터셋 준비.
batch_size = 32
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784)).astype("float32")
x_test = np.reshape(x_test, (-1, 784)).astype("float32")
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)
# 검증을 위해 10,000개의 샘플을 예약합니다.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]
# 트레이닝 데이터셋 준비.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)
# 검증 데이터셋 준비.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)
다음으로, 손실 함수와 옵티마이저를 설정합니다. 이번에는 Keras 옵티마이저를 사용합니다.
# 손실 함수 인스턴스화.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
# 옵티마이저 인스턴스화.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
JAX에서 그래디언트 계산하기
커스텀 트레이닝 루프를 사용하여, 미니 배치 그래디언트로 모델을 트레이닝해봅시다.
JAX에서는 메타프로그래밍(metaprogramming)을 통해 그래디언트를 계산합니다. jax.grad
(또는 jax.value_and_grad
)를 함수에 호출하여, 그 함수에 대한 그래디언트 계산 함수를 생성합니다.
따라서 먼저 필요한 것은 손실 값을 반환하는 함수입니다. 이 함수를 사용하여 그래디언트 함수를 생성할 것입니다. 다음과 같은 형태입니다:
def compute_loss(x, y):
...
return loss
이와 같은 함수를 갖게 되면, 메타프로그래밍을 통해 다음과 같이 그래디언트를 계산할 수 있습니다:
grad_fn = jax.grad(compute_loss)
grads = grad_fn(x, y)
일반적으로는, 단순히 그래디언트 값만 얻는 것이 아니라 손실 값도 함께 얻고자 합니다. 이를 위해 jax.grad
대신 jax.value_and_grad
를 사용할 수 있습니다:
grad_fn = jax.value_and_grad(compute_loss)
loss, grads = grad_fn(x, y)
JAX 계산은 순수하게 stateless입니다
JAX에서는, 모든 것이 stateless 함수여야 하므로, 손실 계산 함수도 stateless여야 합니다. 이는 모든 Keras 변수(예: 가중치 텐서)를 함수의 입력으로 전달해야 하며, 순전파 동안 업데이트된 모든 변수를 함수의 출력으로 반환해야 함을 의미합니다. 함수는 부수 효과가 없어야 합니다.
순전파 동안 Keras 모델의 비트레이닝 변수는 업데이트될 수 있습니다. 이러한 변수는 예를 들어 RNG 시드 상태 변수나 BatchNormalization 통계일 수 있습니다. 우리는 이러한 변수들을 반환해야 합니다. 따라서 다음과 같은 함수가 필요합니다:
def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
...
return loss, non_trainable_variables
이러한 함수를 갖게 되면, value_and_grad
에서 has_aux
를 지정하여 그래디언트 함수를 얻을 수 있습니다. 이는 JAX에 손실 계산 함수가 손실 외에도 더 많은 출력을 반환한다고 알려줍니다. 손실은 항상 첫 번째 출력이어야 한다는 점에 유의하세요.
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
(loss, non_trainable_variables), grads = grad_fn(
trainable_variables, non_trainable_variables, x, y
)
이제 기본 사항을 정립했으니, compute_loss_and_updates
함수를 구현해봅시다. Keras 모델에는 stateless_call
메서드가 있는데, 이 메서드는 여기서 유용하게 사용할 수 있습니다. model.__call__
과 비슷하게 작동하지만, 모델의 모든 변수 값을 명시적으로 전달해야 하며, __call__
출력뿐만 아니라 (잠재적으로 업데이트된) 트레이닝 불가능한(non-trainable) 변수도 반환합니다.
def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
y_pred, non_trainable_variables = model.stateless_call(
trainable_variables, non_trainable_variables, x
)
loss = loss_fn(y, y_pred)
return loss, non_trainable_variables
그래디언트 함수를 구해봅시다:
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
트레이닝 스텝 함수
다음으로, 엔드 투 엔드 트레이닝 스텝을 구현해봅시다. 이 함수는 순전파를 실행하고, 손실을 계산하고, 그래디언트를 계산하며, 옵티마이저를 사용하여 트레이닝 가능한 변수를 업데이트하는 역할을 합니다. 이 함수도 stateless여야 하므로, 우리가 사용할 모든 상태 요소를 포함하는 state
튜플을 입력으로 받아야 합니다:
trainable_variables
및non_trainable_variables
: 모델의 변수들.optimizer_variables
: 옵티마이저의 상태 변수들, 예를 들어 모멘텀 누적기(momentum accumulators)와 같은 것들.
트레이닝 가능한 변수를 업데이트하기 위해, 옵티마이저의 stateless 메서드인 stateless_apply
를 사용합니다. 이는 optimizer.apply()
와 동등하지만, 항상 trainable_variables
와 optimizer_variables
를 전달해야 합니다. 이는 업데이트된 trainable_variables
와 업데이트된 optimizer_variables
를 반환합니다.
def train_step(state, data):
trainable_variables, non_trainable_variables, optimizer_variables = state
x, y = data
(loss, non_trainable_variables), grads = grad_fn(
trainable_variables, non_trainable_variables, x, y
)
trainable_variables, optimizer_variables = optimizer.stateless_apply(
grads, trainable_variables, optimizer_variables
)
# 업데이트된 상태 반환
return loss, (
trainable_variables,
non_trainable_variables,
optimizer_variables,
)
jax.jit
로 속도 향상시키기
기본적으로 JAX 연산은 TensorFlow eager 모드와 PyTorch eager 모드처럼, 즉시(eagerly) 실행됩니다. 그리고, TensorFlow eager 모드와 PyTorch eager 모드처럼 꽤 느립니다. 즉시 실행(eager) 모드는 디버깅 환경으로 더 적합하며, 실제 작업을 수행하는 방법으로는 적합하지 않습니다. 따라서, train_step
을 컴파일하여 빠르게 만들어봅시다.
stateless JAX 함수가 있는 경우, @jax.jit
데코레이터를 통해 이를 XLA로 컴파일할 수 있습니다. 첫 번째 실행 시 함수가 추적되고, 이후 실행에서는 추적된 그래프를 실행하게 됩니다. (이는 @tf.function(jit_compile=True)
와 유사합니다) 시도해 봅시다:
@jax.jit
def train_step(state, data):
trainable_variables, non_trainable_variables, optimizer_variables = state
x, y = data
(loss, non_trainable_variables), grads = grad_fn(
trainable_variables, non_trainable_variables, x, y
)
trainable_variables, optimizer_variables = optimizer.stateless_apply(
optimizer_variables, grads, trainable_variables
)
# 업데이트된 상태 반환
return loss, (
trainable_variables,
non_trainable_variables,
optimizer_variables,
)
이제 모델을 트레이닝할 준비가 되었습니다. 트레이닝 루프 자체는 간단합니다: loss, state = train_step(state, data)
를 반복적으로 호출하기만 하면 됩니다.
참고:
tf.data.Dataset
에서 생성된(yielded) TF 텐서를 JAX 함수에 전달하기 전에 NumPy로 변환합니다.- 모든 변수는 사전에 빌드되어야 합니다: 모델은 빌드되어야 하고, 옵티마이저도 빌드되어야 합니다. 우리가 사용하는 것은 Functional API 모델이므로 이미 빌드되어 있지만, 만약 서브클래스화된 모델이라면 데이터를 하나의 배치에 대해 호출하여 빌드해야 합니다.
# 옵티마이저 변수 빌드
optimizer.build(model.trainable_variables)
trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
state = trainable_variables, non_trainable_variables, optimizer_variables
# 트레이닝 루프
for step, data in enumerate(train_dataset):
data = (data[0].numpy(), data[1].numpy())
loss, state = train_step(state, data)
# 100 배치마다 로그 출력
if step % 100 == 0:
print(f"스텝 {step}에서의 트레이닝 손실 (1 배치 기준): {float(loss):.4f}")
print(f"지금까지 본 샘플 수: {(step + 1) * batch_size}")
결과를 보려면 클릭하세요.
Training loss (for 1 batch) at step 0: 156.4785
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 2.5526
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 1.8922
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 1.2381
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.4812
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 2.3339
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.5615
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.6471
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 1.6272
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.9416
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.8152
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.8838
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.1278
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 1.9234
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.3413
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.2429
Seen so far: 48032 samples
여기서 주목해야 할 중요한 점은 루프가 완전히 stateless하다는 것입니다. 모델에 첨부된 변수들(model.weights
)은 루프 동안 절대 업데이트되지 않습니다. 변수의 새로운 값은 오직 state
튜플에만 저장됩니다. 이는 모델을 저장하기 전에, 새로운 변수 값을 다시 모델에 연결해야 한다는 것을 의미합니다.
업데이트하려는 각 모델 변수에 대해, variable.assign(new_value)
를 호출하기만 하면 됩니다:
trainable_variables, non_trainable_variables, optimizer_variables = state
for variable, value in zip(model.trainable_variables, trainable_variables):
variable.assign(value)
for variable, value in zip(model.non_trainable_variables, non_trainable_variables):
variable.assign(value)
메트릭의 낮은 레벨 처리
이 기본 트레이닝 루프에 메트릭 모니터링을 추가해 봅시다.
이렇게 처음부터 작성한 트레이닝 루프에서도 빌트인 Keras 메트릭(또는 사용자가 작성한 커스텀 메트릭)을 쉽게 재사용할 수 있습니다. 흐름은 다음과 같습니다:
- 루프 시작 시 메트릭을 인스턴스화합니다.
train_step
인자와compute_loss_and_updates
인자에metric_variables
를 포함시킵니다.compute_loss_and_updates
함수에서metric.stateless_update_state()
를 호출합니다. 이는update_state()
와 같은 역할을 하지만, stateless인 것만 다릅니다.train_step
외부 (즉시 실행 범위)에서 메트릭의 현재 값을 표시해야 할 때, 새로운 메트릭 변수 값을 메트릭 객체에 연결하고,metric.result()
를 호출합니다.- 메트릭의 상태를 초기화해야 할 때(일반적으로 에포크가 끝날 때),
metric.reset_state()
를 호출합니다.
이 지식을 사용하여 트레이닝이 끝날 때, 트레이닝 및 검증 데이터에 대한 CategoricalAccuracy
를 계산해 보겠습니다:
# 새로운 모델 가져오기
model = get_model()
# 모델을 트레이닝할 옵티마이저 인스턴스화
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# 손실 함수 인스턴스화
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
# 메트릭 준비
train_acc_metric = keras.metrics.CategoricalAccuracy()
val_acc_metric = keras.metrics.CategoricalAccuracy()
def compute_loss_and_updates(
trainable_variables, non_trainable_variables, metric_variables, x, y
):
y_pred, non_trainable_variables = model.stateless_call(
trainable_variables, non_trainable_variables, x
)
loss = loss_fn(y, y_pred)
metric_variables = train_acc_metric.stateless_update_state(
metric_variables, y, y_pred
)
return loss, (non_trainable_variables, metric_variables)
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
@jax.jit
def train_step(state, data):
(
trainable_variables,
non_trainable_variables,
optimizer_variables,
metric_variables,
) = state
x, y = data
(loss, (non_trainable_variables, metric_variables)), grads = grad_fn(
trainable_variables, non_trainable_variables, metric_variables, x, y
)
trainable_variables, optimizer_variables = optimizer.stateless_apply(
optimizer_variables, grads, trainable_variables
)
# 업데이트된 상태 반환
return loss, (
trainable_variables,
non_trainable_variables,
optimizer_variables,
metric_variables,
)
평가 스텝 함수도 준비해 보겠습니다:
@jax.jit
def eval_step(state, data):
trainable_variables, non_trainable_variables, metric_variables = state
x, y = data
y_pred, non_trainable_variables = model.stateless_call(
trainable_variables, non_trainable_variables, x
)
loss = loss_fn(y, y_pred)
metric_variables = val_acc_metric.stateless_update_state(
metric_variables, y, y_pred
)
return loss, (
trainable_variables,
non_trainable_variables,
metric_variables,
)
다음은 우리의 루프입니다:
# 옵티마이저 변수 빌드
optimizer.build(model.trainable_variables)
trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
metric_variables = train_acc_metric.variables
state = (
trainable_variables,
non_trainable_variables,
optimizer_variables,
metric_variables,
)
# 트레이닝 루프
for step, data in enumerate(train_dataset):
data = (data[0].numpy(), data[1].numpy())
loss, state = train_step(state, data)
# 100 배치마다 로그 출력
if step % 100 == 0:
print(f"스텝 {step}에서의 트레이닝 손실 (1 배치 기준): {float(loss):.4f}")
_, _, _, metric_variables = state
for variable, value in zip(train_acc_metric.variables, metric_variables):
variable.assign(value)
print(f"트레이닝 정확도: {train_acc_metric.result()}")
print(f"지금까지 본 샘플 수: {(step + 1) * batch_size}")
metric_variables = val_acc_metric.variables
(
trainable_variables,
non_trainable_variables,
optimizer_variables,
metric_variables,
) = state
state = trainable_variables, non_trainable_variables, metric_variables
# 평가 루프
for step, data in enumerate(val_dataset):
data = (data[0].numpy(), data[1].numpy())
loss, state = eval_step(state, data)
# 100 배치마다 로그 출력
if step % 100 == 0:
print(f"스텝 {step}에서의 검증 손실 (1 배치 기준): {float(loss):.4f}")
_, _, metric_variables = state
for variable, value in zip(val_acc_metric.variables, metric_variables):
variable.assign(value)
print(f"검증 정확도: {val_acc_metric.result()}")
print(f"지금까지 본 샘플 수: {(step + 1) * batch_size}")
결과를 보려면 클릭하세요.
Training loss (for 1 batch) at step 0: 96.4990
Training accuracy: 0.0625
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 2.0447
Training accuracy: 0.6064356565475464
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 2.0184
Training accuracy: 0.6934079527854919
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 1.9111
Training accuracy: 0.7303779125213623
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 1.8042
Training accuracy: 0.7555330395698547
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 1.2200
Training accuracy: 0.7659056782722473
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 1.3437
Training accuracy: 0.7793781161308289
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 1.2409
Training accuracy: 0.789318859577179
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 1.6530
Training accuracy: 0.7977527976036072
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.4173
Training accuracy: 0.8060488104820251
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.5543
Training accuracy: 0.8100025057792664
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 1.2699
Training accuracy: 0.8160762786865234
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 1.2621
Training accuracy: 0.8213468194007874
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.8028
Training accuracy: 0.8257350325584412
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 1.0701
Training accuracy: 0.8298090696334839
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.3910
Training accuracy: 0.8336525559425354
Seen so far: 48032 samples
Validation loss (for 1 batch) at step 0: 0.2482
Validation accuracy: 0.835365355014801
Seen so far: 32 samples
Validation loss (for 1 batch) at step 100: 1.1641
Validation accuracy: 0.8388938903808594
Seen so far: 3232 samples
Validation loss (for 1 batch) at step 200: 0.1201
Validation accuracy: 0.8428196907043457
Seen so far: 6432 samples
Validation loss (for 1 batch) at step 300: 0.0755
Validation accuracy: 0.8471122980117798
Seen so far: 9632 samples
모델이 추적하는 손실의 낮은 레벨 처리
레이어와 모델은 순전파 중, self.add_loss(value)
를 호출하는 레이어에 의해 생성된 모든 손실을 재귀적으로 추적합니다. 그 결과 생성된 스칼라 손실 값들의 목록은, 순전파가 끝난 후 model.losses
속성을 통해 확인할 수 있습니다.
이러한 손실 요소들을 사용하고 싶다면, 이를 합산하여 트레이닝 스텝의 메인 손실에 추가해야 합니다.
활동 정규화 손실을 생성하는 다음 레이어를 고려해 보세요:
class ActivityRegularizationLayer(keras.layers.Layer):
def call(self, inputs):
self.add_loss(1e-2 * jax.numpy.sum(inputs))
return inputs
간단한 모델을 빌드해봅시다:
inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu")(inputs)
# activity 정규화 레이어를 추가합니다.
x = ActivityRegularizationLayer()(x)
x = keras.layers.Dense(64, activation="relu")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
model = keras.Model(inputs=inputs, outputs=outputs)
다음은 compute_loss_and_updates
함수가 어떻게 생겨야 하는지 보여줍니다:
model.stateless_call()
에return_losses=True
를 전달합니다.- 결과로 생성된
losses
를 합산하여, 메인 손실에 추가합니다.
def compute_loss_and_updates(
trainable_variables, non_trainable_variables, metric_variables, x, y
):
y_pred, non_trainable_variables, losses = model.stateless_call(
trainable_variables, non_trainable_variables, x, return_losses=True
)
loss = loss_fn(y, y_pred)
if losses:
loss += jax.numpy.sum(losses)
metric_variables = train_acc_metric.stateless_update_state(
metric_variables, y, y_pred
)
return loss, non_trainable_variables, metric_variables
이것으로 끝입니다!