1. ホーム
  2. python

[解決済み] TensorFlowで学習済みの単語埋め込み(word2vecやGlove)を利用する

2023-02-10 01:48:50

質問

最近、興味深い実装をレビューしました。 畳み込みテキスト分類 . しかし、私がレビューしたすべてのTensorFlowコードは、以下のようなランダムな(事前学習されていない)埋め込みベクトルを使用しています。

with tf.device('/cpu:0'), tf.name_scope("embedding"):
    W = tf.Variable(
        tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0),
        name="W")
    self.embedded_chars = tf.nn.embedding_lookup(W, self.input_x)
    self.embedded_chars_expanded = tf.expand_dims(self.embedded_chars, -1)

Word2vecやGloVeで学習した単語埋め込みの結果を、ランダムな単語埋め込みの代わりに利用する方法をご存知の方はいらっしゃいませんか?

どのように解決するのですか?

TensorFlowで事前に学習させたembeddingを使うには、いくつかの方法があります。例えば、エンベッディングをNumPyの配列で embedding というNumPyの配列に埋め込むとします。 vocab_size の行と embedding_dim 列からなるテンソル W への呼び出しで使用できるテンソルを作りたい。 tf.nn.embedding_lookup() .

  1. 単に作成する W tf.constant() を取ることで embedding を値として取ります。

    W = tf.constant(embedding, name="W")
    
    

    これは最も簡単な方法ですが,メモリ効率は良くありません. tf.constant() の値はメモリ上に複数回保存されるからです。なぜなら embedding は非常に大きくなる可能性があるので、この方法はおもちゃのような例だけに使うようにしましょう。

  2. 作成 W として tf.Variable として、NumPyの配列からそれを初期化するために tf.placeholder() :

    W = tf.Variable(tf.constant(0.0, shape=[vocab_size, embedding_dim]),
                    trainable=False, name="W")
    
    embedding_placeholder = tf.placeholder(tf.float32, [vocab_size, embedding_dim])
    embedding_init = W.assign(embedding_placeholder)
    
    # ...
    sess = tf.Session()
    
    sess.run(embedding_init, feed_dict={embedding_placeholder: embedding})
    
    

    のコピーを保存することを避けます。 embedding のコピーをグラフに保存することは避けられますが、一度に2つの行列のコピー(1つはNumPyの配列用、もう1つは tf.Variable ). なお、学習中は埋め込み行列を一定に保ちたいものと仮定していますので、その場合は W と共に作成されます。 trainable=False .

  3. エンベッディングが他のTensorFlowモデルの一部として学習されたものである場合は tf.train.Saver を使って他のモデルのチェックポイントファイルから値を読み込むことができます。つまり、エンベッディングマトリックスはPythonを完全にバイパスすることができます。作成 W をオプション2のように作成し、次のようにします。

    W = tf.Variable(...)
    
    embedding_saver = tf.train.Saver({"name_of_variable_in_other_model": W})
    
    # ...
    sess = tf.Session()
    embedding_saver.restore(sess, "checkpoint_filename.ckpt")