[解決済み] PyTorchのnn.Linearのクラス定義は何ですか?
質問
PyTorchで以下のようなコードを持っています。
import torch.nn as nn
import torch.nn.functional as F
class Network(nn.Module):
def __init__(self):
super().__init__()
self.hidden = nn.Linear(784, 256)
self.output = nn.Linear(256, 10)
def forward(self, x):
x = F.sigmoid(self.hidden(x))
x = F.softmax(self.output(x), dim=1)
return x
私の質問です。これは何ですか?
self.hidden
?
から返されます。
nn.Linear
を取ることができます。
x
を引数として与えます。の目的は一体何なのでしょうか?
self.hidden
?
解決方法は?
<ブロッククオートpytorchのnn.Linearのクラス定義は何ですか?
から ドキュメンテーション :
CLASS torch.nn.Linear(in_features, out_features, bias=True)
入力されたデータに線形変換を適用する。
y = x*W^T + b
パラメータです。
- in_features - 各入力サンプルのサイズ (すなわち x のサイズ)
- アウトフィーチャー - 各出力サンプルの大きさ(つまりyの大きさ)。
- バイアス - Falseに設定すると、レイヤーは加法性バイアスを学習しません。デフォルトは True
なお、ウェイト
W
形状を持つ
(out_features, in_features)
と偏り
b
形がある
(out_features)
. これらはランダムに初期化され,後で変更することができます(例えば,ニューラルネットワークの学習中に,何らかの最適化アルゴリズムによって更新されます).
ニューラルネットワークでは
self.hidden = nn.Linear(784, 256)
を定義しています。
隠された
(入力層と出力層の間にあることを意味する)。
完全連結線形層
を入力とする
x
形状の
(batch_size, 784)
ここで、バッチサイズは一度にネットワークに渡される(1つのテンソルとして)入力(それぞれサイズ784)の数であり、線形方程式によって変換する。
y = x*W^T + b
をテンソルに変換します。
y
形状の
(batch_size, 256)
. さらにシグモイド関数で変換される。
x = F.sigmoid(self.hidden(x))
(これは
nn.Linear
が追加されます)。
具体的な例を見てみましょう。
import torch
import torch.nn as nn
x = torch.tensor([[1.0, -1.0],
[0.0, 1.0],
[0.0, 0.0]])
in_features = x.shape[1] # = 2
out_features = 2
m = nn.Linear(in_features, out_features)
ここで
x
は3つの入力を含む(すなわち、バッチサイズは3)。
x[0]
,
x[1]
と
x[3]
で、それぞれサイズ 2、出力は形状
(batch size, out_features) = (3, 2)
.
パラメータ(重みと偏り)の値は
>>> m.weight
tensor([[-0.4500, 0.5856],
[-0.1807, -0.4963]])
>>> m.bias
tensor([ 0.2223, -0.6114])
(ランダムに初期化されるため、上記とは異なる値が得られる可能性が高いです)
と出力されます。
>>> y = m(x)
tensor([[-0.8133, -0.2959],
[ 0.8079, -1.1077],
[ 0.2223, -0.6114]])
と(裏では)計算されています。
y = x.matmul(m.weight.t()) + m.bias # y = x*W^T + b
すなわち
y[i,j] == x[i,0] * m.weight[j,0] + x[i,1] * m.weight[j,1] + m.bias[j]
ここで
i
は区間
[0, batch_size)
と
j
で
[0, out_features)
.
関連
-
pythonサイクルタスクスケジューリングツール スケジュール詳解
-
[解決済み】NameError: 名前 'self' が定義されていません。
-
[解決済み] for'ループでインデックスにアクセスする?
-
[解決済み] リスト内のアイテムのインデックスを検索する
-
[解決済み] Pythonのリストメソッドであるappendとextendの違いは何ですか?
-
[解決済み] __init__.py は何のためにあるのですか?
-
[解決済み] Pythonで静的なクラス変数は可能ですか?
-
[解決済み] Could not find or load main class "とはどういう意味ですか?
-
[解決済み】if __name__ == "__main__": は何をするのでしょうか?
-
[解決済み】__str__と__repr__の違いは何ですか?
最新
-
nginxです。[emerg] 0.0.0.0:80 への bind() に失敗しました (98: アドレスは既に使用中です)
-
htmlページでギリシャ文字を使うには
-
ピュアhtml+cssでの要素読み込み効果
-
純粋なhtml + cssで五輪を実現するサンプルコード
-
ナビゲーションバー・ドロップダウンメニューのHTML+CSSサンプルコード
-
タイピング効果を実現するピュアhtml+css
-
htmlの選択ボックスのプレースホルダー作成に関する質問
-
html css3 伸縮しない 画像表示効果
-
トップナビゲーションバーメニュー作成用HTML+CSS
-
html+css 実装 サイバーパンク風ボタン
おすすめ
-
PythonによるLeNetネットワークモデルの学習と予測
-
python call matlab メソッドの詳細
-
Pythonの学習とデータマイニングのために知っておくべきターミナルコマンドのトップ10
-
Pythonを使って簡単なzipファイルの解凍パスワードを手作業で解く
-
風力制御におけるKS原理を深く理解するためのpythonアルゴリズム
-
Pythonの画像ファイル処理用ライブラリ「Pillow」(グラフィックの詳細)
-
[解決済み】csv.Error:イテレータはバイトではなく文字列を返すべき
-
[解決済み】LogisticRegression: Pythonでsklearnを使用して、未知のラベルタイプ: '連続'を使用しています。
-
[解決済み] TypeError: 'DataFrame' オブジェクトは呼び出し可能ではない
-
[解決済み】django インポートエラー - core.managementという名前のモジュールがない