1. ホーム
  2. artificial-intelligence

[解決済み] ニューラルネットワークの訓練セット、検証セット、テストセットの違いは何ですか?

2022-04-24 23:24:14

質問

を使っています。 本ライブラリ を使用して、学習エージェントを実装しています。

学習用ケースは作成しましたが、検証用とテスト用のセットがはっきりわかりません。

と先生はおっしゃいます。

70%を訓練ケース、10%をテストケース、残りの20%を検証ケースにします。

編集

トレーニング用にこのようなコードを用意したのですが、どのタイミングで 停止 をトレーニングします。

  def train(self, train, validation, N=0.3, M=0.1):
    # N: learning rate
    # M: momentum factor
    accuracy = list()
    while(True):
        error = 0.0
        for p in train:
            input, target = p
            self.update(input)
            error = error + self.backPropagate(target, N, M)
        print "validation"
        total = 0
        for p in validation:
            input, target = p
            output = self.update(input)
            total += sum([abs(target - output) for target, output in zip(target, output)]) #calculates sum of absolute diference between target and output

        accuracy.append(total)
        print min(accuracy)
        print sum(accuracy[-5:])/5
        #if i % 100 == 0:
        print 'error %-14f' % error
        if ? < ?:
            break

編集

20回の学習反復の後、検証データで平均誤差0.2が得られますが、これは80%になるはずですよね?

平均誤差=検証データの入力/検証データの大きさが与えられたときの、検証目標と出力の差の絶対値の総和。

1
        avg error 0.520395 
        validation
        0.246937882684
2
        avg error 0.272367   
        validation
        0.228832420879
3
        avg error 0.249578    
        validation
        0.216253590304
        ...
22
        avg error 0.227753
        validation
        0.200239244714
23
        avg error 0.227905    
        validation
        0.199875013416

解決方法は?

学習時に学習用セットと検証用セットを使用する。

for each epoch
    for each training data instance
        propagate error through the network
        adjust the weights
        calculate the accuracy over training data
    for each validation data instance
        calculate the accuracy over the validation data
    if the threshold validation accuracy is met
        exit training
    else
        continue training

学習が終わったら、テストセットに対して実行し、精度が十分であることを確認します。

トレーニングセット このデータセットは、ニューラルネットワークの重みを調整するために使用されます。

バリデーションセット このデータセットは、オーバーフィッティングを最小限に抑えるために使用されます。このデータセットでネットワークの重みを調整するのではなく、学習データセットに対する精度の向上が、ネットワークに見せたことのない、あるいは少なくともネットワークが学習したことのないデータセット(すなわち検証データセット)に対する精度の向上を実際にもたらすかどうかを検証しているに過ぎないのです。もし、トレーニングデータセットに対する精度が上がっても、バリデーションデータセットに対する精度が変わらないか下がるようであれば、ニューラルネットワークをオーバーフィットさせているので、トレーニングを中止すべきなのです。

テストセット このデータセットは、ネットワークの実際の予測能力を確認するために、最終解のテストにのみ使用されます。