1. ホーム
  2. apache-spark

[解決済み] Spark SQL: カラムのリストに集約関数を適用する

2023-03-19 23:36:29

質問

データフレームのすべての列(またはリスト)に対して、集約関数を適用する方法はありますか? groupBy ? 言い換えれば、すべての列に対してこれを行うことを回避する方法はありますか。

df.groupBy("col1")
  .agg(sum("col2").alias("col2"), sum("col3").alias("col3"), ...)

どのように解決するのですか?

複数の列に対して集計関数を適用する方法は複数あります。

GroupedData クラスは、最も一般的な関数のために、以下のような多くのメソッドを提供します。 count , max , min , meansum であり,以下のように直接利用することができる.

  • Pythonです。

    df = sqlContext.createDataFrame(
        [(1.0, 0.3, 1.0), (1.0, 0.5, 0.0), (-1.0, 0.6, 0.5), (-1.0, 5.6, 0.2)],
        ("col1", "col2", "col3"))
    
    df.groupBy("col1").sum()
    
    ## +----+---------+-----------------+---------+
    ## |col1|sum(col1)|        sum(col2)|sum(col3)|
    ## +----+---------+-----------------+---------+
    ## | 1.0|      2.0|              0.8|      1.0|
    ## |-1.0|     -2.0|6.199999999999999|      0.7|
    ## +----+---------+-----------------+---------+
    
    
  • Scala

    val df = sc.parallelize(Seq(
      (1.0, 0.3, 1.0), (1.0, 0.5, 0.0),
      (-1.0, 0.6, 0.5), (-1.0, 5.6, 0.2))
    ).toDF("col1", "col2", "col3")
    
    df.groupBy($"col1").min().show
    
    // +----+---------+---------+---------+
    // |col1|min(col1)|min(col2)|min(col3)|
    // +----+---------+---------+---------+
    // | 1.0|      1.0|      0.3|      0.0|
    // |-1.0|     -1.0|      0.6|      0.2|
    // +----+---------+---------+---------+
    
    

オプションで、集約する列のリストを渡すことができます。

df.groupBy("col1").sum("col2", "col3")

カラムをキー、関数を値とする辞書/マップを渡すことも可能です。

  • Python

    exprs = {x: "sum" for x in df.columns}
    df.groupBy("col1").agg(exprs).show()
    
    ## +----+---------+
    ## |col1|avg(col3)|
    ## +----+---------+
    ## | 1.0|      0.5|
    ## |-1.0|     0.35|
    ## +----+---------+
    
    
  • Scala

    val exprs = df.columns.map((_ -> "mean")).toMap
    df.groupBy($"col1").agg(exprs).show()
    
    // +----+---------+------------------+---------+
    // |col1|avg(col1)|         avg(col2)|avg(col3)|
    // +----+---------+------------------+---------+
    // | 1.0|      1.0|               0.4|      0.5|
    // |-1.0|     -1.0|3.0999999999999996|     0.35|
    // +----+---------+------------------+---------+
    
    

最後にvarargsを使用することができます。

  • Python

    from pyspark.sql.functions import min
    
    exprs = [min(x) for x in df.columns]
    df.groupBy("col1").agg(*exprs).show()
    
    
  • Scala

    import org.apache.spark.sql.functions.sum
    
    val exprs = df.columns.map(sum(_))
    df.groupBy($"col1").agg(exprs.head, exprs.tail: _*)
    
    

同様の効果を得るための他の方法もありますが、ほとんどの場合、これで十分すぎるほどの効果が得られるはずです。

こちらもご覧ください。