1. ホーム
  2. math

[解決済み] tf.truncated_normalとtf.random_normalの違いは何ですか?

2022-02-27 16:42:20

質問

tf.random_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None) は、正規分布からランダムな値を出力します。

tf.truncated_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None) は、切断された正規分布からランダムな値を出力します。

切り捨てられた正規分布」でググってみました。しかし、あまり理解できなかった。

どのように解決するのですか?

その ドキュメント が全てを語っている。 切り捨てられた正規分布のため。

平均と標準偏差を指定した正規分布から値を抽出し、平均から2標準偏差以上離れたサンプルは破棄して再抽出します。

おそらく、自分でグラフを描いてみると、その違いがよくわかると思います(%magicはjupyter notebookを使用しているためです)。

import tensorflow as tf
import matplotlib.pyplot as plt

%matplotlib inline  

n = 500000
A = tf.truncated_normal((n,))
B = tf.random_normal((n,))
with tf.Session() as sess:
    a, b = sess.run([A, B])

そして今

plt.hist(a, 100, (-4.2, 4.2));
plt.hist(b, 100, (-4.2, 4.2));


切り捨てられた正規分布を使用するポイントは、シグモイドのようなトーム関数の飽和(値が大きすぎたり小さすぎたりすると、ニューロンは学習を停止してしまう)を克服することである。