1. ホーム
  2. python

[解決済み] scikit-learnの決定木から決定規則を抽出する方法は?

2022-04-15 12:26:49

質問

決定木の学習済みツリーから、基本となる決定規則(または「決定パス」)をテキストリストとして抽出することは可能ですか?

のようなものです。

if A>0.4 then if B<0.2 then if C>0.8 then class='X'

解決方法は?

この回答は、他の回答よりも正しいと思います。

from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print "def tree({}):".format(", ".join(feature_names))

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print "{}if {} <= {}:".format(indent, name, threshold)
            recurse(tree_.children_left[node], depth + 1)
            print "{}else:  # if {} > {}".format(indent, name, threshold)
            recurse(tree_.children_right[node], depth + 1)
        else:
            print "{}return {}".format(indent, tree_.value[node])

    recurse(0, 1)

これは有効なPython関数をプリントアウトするものです。以下は、入力である0から10の間の数を返そうとする木の出力例です。

def tree(f0):
  if f0 <= 6.0:
    if f0 <= 1.5:
      return [[ 0.]]
    else:  # if f0 > 1.5
      if f0 <= 4.5:
        if f0 <= 3.5:
          return [[ 3.]]
        else:  # if f0 > 3.5
          return [[ 4.]]
      else:  # if f0 > 4.5
        return [[ 5.]]
  else:  # if f0 > 6.0
    if f0 <= 8.5:
      if f0 <= 7.5:
        return [[ 7.]]
      else:  # if f0 > 7.5
        return [[ 8.]]
    else:  # if f0 > 8.5
      return [[ 9.]]

他の回答で見かけたつまずきを紹介します。

  1. 使用方法 tree_.threshold == -2 を使用して、ノードが葉であるかどうかを決定するのは良いアイデアではありません。しきい値が-2の本当の判定ノードだったらどうする?その代わりに tree.feature または tree.children_* .
  2. ライン features = [feature_names[i] for i in tree_.feature] は私のバージョンの sklearn ではクラッシュします。 tree.tree_.feature が -2 (特にリーフノード) である。
  3. 再帰関数内に複数のif文がある必要はなく、1つで良い。