1. ホーム
  2. tensorflow

tensorflow(6) mnist.train.next_batch() 関数解析

2022-02-11 13:34:57
<パス

データを1バッチずつ生成する必要があるtensorflowのfeed_dictの原理。

1. データセットクラス

データ処理部分をクラスとして記述し、init関数でいくつかのパラメータを定義します。

class DataSet(object):

  def __init__(self,
               images,
               labels,.....)
    self._images = images
    self._labels = labels
    self._epochs_completed = 0 # how many epochs have been gone through
    self._index_in_epoch = 0 # index in an epoch
    self._num_examples # is the total number of samples in the training data

2. next_batch機能

next_batch関数の各呼び出しが最後の位置をまだ覚えていることをどのように保証しますか? tensorflowソースコードはデータセット入力をクラスとして書き、self._index_in_epochは最後の位置を覚えているクラス変数と等価です。
次の関数は、大きく3つの部分に分かれています。
最初のエポックをどうするか。
各エポックの終わりが次のエポックの始まりに合流するのをどうするか。
非最初のエポック& 非終了をどうするか。
このように分ける主な理由は、各エポックの最初に、インデックスがシャッフルされるからである。

def next_batch(self, batch_size, fake_data=False, shuffle=True):
    start = self._index_in_epoch #self._index_in_epoch All calls, total number of samples used, equivalent to a global variable #start The first batch is 0, the rest is the same as self._index_in_epoch, and if more than one epoch is used. The rest is the same as self._index_in_epoch, and if it exceeds one epoch, it is reassigned below.
    # Shuffle for the first epoch The first epoch needs to be shuffled
    if self._epochs_completed == 0 and start == 0 and shuffle:
      perm0 = numpy.array(self._num_examples) # Generate an np.array of all sample lengths
      numpy.random.shuffle(perm0)
      self._images = self.images[perm0]
      self._labels = self.labels[perm0]
    # Go to the next epoch


    if start + batch_size > self._num_examples: # End of epoch and beginning of next epoch
      # Finished epoch
      self._epochs_completed += 1
      # Get the rest examples in this epoch
      rest_num_examples = self._num_examples - start # Last not enough for a batch and a few left
      images_rest_part = self._images[start:self._num_examples]
      labels_rest_part = self._labels[start:self._num_examples]
      # Shuffle the data
      if shuffle: 
        perm = numpy.range(self._num_examples)
        numpy.random.shuffle(perm)
        self._images = self.images[perm]
        self._labels = self.labels[perm]
      # Start next epoch
      start = 0
      self._index_in_epoch = batch_size - rest_num_examples
      end = self._index_in_epoch
      images_new_part = self._images[start:end] 
      labels_new_part = self._labels[start:end]
      return numpy.concatenate((images_rest_part, images_new_part), axis=0) , numpy.concatenate((labels_rest_part, labels_new_part), axis=0)    
    else: # Except for the first epoch, and the beginning of each epoch, the rest of the middle batch is handled
      self._index_in_epoch += batch_size # start = index_in_epoch
      end = self._index_in_epoch # end is simple, it's index_in_epoch plus batch_size 
      return self._images[start:end], self._labels[start:end] # in data x,y