1. ホーム
  2. python

[解決済み] Tensorflowのnext_batchを自分のデータに対して実装する方法

2022-02-10 07:58:10

質問

TENSORFLOW MNIST チュートリアル mnist.train.next_batch(100) 関数は非常に便利です。今、私自身は簡単な分類を実装しようとしています。私の学習データはnumpyの配列にあります。次のバッチを与えるために、どのように私自身のデータに対して同様の関数を実装することができますか?

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
Xtr, Ytr = loadData()
for it in range(1000):
    batch_x = Xtr.next_batch(100)
    batch_y = Ytr.next_batch(100)

解決方法は?

貼ってあるリンクに書いてあります。 トレーニングセットから100個のランダムなデータ点をバッチとして取得します。 . 私の例では、(あなたの例のようなメソッドではなく)グローバル関数を使用しているので、構文に違いがあります。

私の関数では、必要なサンプル数とデータ配列を渡す必要があります。

以下は正しいコードで、サンプルに正しいラベルが付くようにします。

import numpy as np

def next_batch(num, data, labels):
    '''
    Return a total of `num` random samples and labels. 
    '''
    idx = np.arange(0 , len(data))
    np.random.shuffle(idx)
    idx = idx[:num]
    data_shuffle = [data[ i] for i in idx]
    labels_shuffle = [labels[ i] for i in idx]

    return np.asarray(data_shuffle), np.asarray(labels_shuffle)

Xtr, Ytr = np.arange(0, 10), np.arange(0, 100).reshape(10, 10)
print(Xtr)
print(Ytr)

Xtr, Ytr = next_batch(5, Xtr, Ytr)
print('\n5 random samples')
print(Xtr)
print(Ytr)

そしてデモラン。

[0 1 2 3 4 5 6 7 8 9]
[[ 0  1  2  3  4  5  6  7  8  9]
 [10 11 12 13 14 15 16 17 18 19]
 [20 21 22 23 24 25 26 27 28 29]
 [30 31 32 33 34 35 36 37 38 39]
 [40 41 42 43 44 45 46 47 48 49]
 [50 51 52 53 54 55 56 57 58 59]
 [60 61 62 63 64 65 66 67 68 69]
 [70 71 72 73 74 75 76 77 78 79]
 [80 81 82 83 84 85 86 87 88 89]
 [90 91 92 93 94 95 96 97 98 99]]

5 random samples
[9 1 5 6 7]
[[90 91 92 93 94 95 96 97 98 99]
 [10 11 12 13 14 15 16 17 18 19]
 [50 51 52 53 54 55 56 57 58 59]
 [60 61 62 63 64 65 66 67 68 69]
 [70 71 72 73 74 75 76 77 78 79]]