返回 导航

Spark

hangge.com

Spark - SparkSQL使用详解4(UDF、UDAF)

作者:hangge | 2023-11-24 08:50

一、UDF(用户自定义函数)

1,基本介绍

  • UDFUser-Defined Function)是用户自定义函数,它允许我们在 Spark SQL 中创建自定义函数,以对 DataFrame 中的每个元素进行处理,并返回一个新的元素。
  • UDF 可以用于单个数据项的转换,类似于对 DataFrame 的某一列进行自定义操作。
  • UDF 的创建可以通过 Spark 提供的编程 API 来实现。

2,使用样例

(1)下面代码通过 SQL 语法来处理 DataFrame,并使用 SELECT 语句调用我们注册的 UDF 函数 firstLetterToUpper 来实现将名字的首字母变成大写。
object Hello {
  def main(args: Array[String]): Unit = {
    // 创建 Spark 运行配置对象
    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("Hello")
    //创建 SparkSession 对象
    val spark: SparkSession = SparkSession.builder().config(sparkConf).getOrCreate()

    // 导入隐式转换
    import spark.implicits._

    // 读取原始数据,创建DataFrame
    val data = Seq(
      (1, "hangge", 30),
      (2, "guge", 25),
      (3, "baige", 15)
    )
    val df:DataFrame = data.toDF("id", "name", "age")

    // 注册UDF
    spark.udf.register("firstLetterToUpper",
      (input: String) => input.substring(0, 1).toUpperCase() + input.substring(1))

    // 创建临时表
    df.createOrReplaceTempView("employees")

    // 使用SQL语法处理DataFrame并显示结果
    val resultDF = spark.sql("SELECT id, firstLetterToUpper(name) AS name_upper FROM employees")
    resultDF.show()

    //关闭 Spark
    spark.stop()
  }
}

(2)运行结果如下:

(3)当然我们也可以使用 DSL 语法进行查询,最终结果是一样的:
object Hello {
  def main(args: Array[String]): Unit = {
    // 创建 Spark 运行配置对象
    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("Hello")
    //创建 SparkSession 对象
    val spark: SparkSession = SparkSession.builder().config(sparkConf).getOrCreate()

    // 导入隐式转换
    import spark.implicits._

    // 读取原始数据,创建DataFrame
    val data = Seq(
      (1, "hangge", 30),
      (2, "guge", 25),
      (3, "baige", 15)
    )
    val df:DataFrame = data.toDF("id", "name", "age")

    // 注册UDF
    spark.udf.register("firstLetterToUpper",
      (input: String) => input.substring(0, 1).toUpperCase() + input.substring(1))

    // 使用DSL语法处理DataFrame并显示结果
    val resultDF = df.select($"id", expr("firstLetterToUpper(name)").as("name_upper"))
    resultDF.show()

    //关闭 Spark
    spark.stop()
  }
}

二、UDAF(用户自定义聚合函数)

1,基本介绍

(1)UDAFUser-Defined Aggregate Functio)是用户自定义聚合函数,它允许我们在 Spark SQL 中定义自己的聚合函数,从而实现复杂的聚合逻辑,如计算平均值、拼接字符串、自定义统计等。

(2)UDAF 又可以分为自定义弱类型聚合函数(UserDefinedAggregateFunction)和强类型聚合函数(Aggregator):
  • 自定义弱类型聚合函数UserDefinedAggregateFunction)是一个抽象类,我们需要继承它并实现一些方法来定义自己的聚合函数。这些方法包括 inputSchemabufferSchemadataTypeinitializeupdatemergeevaluate 等。但是,这种方法对数据类型的约束较少,因为数据在处理过程中通常以 Any 类型传递。
  • 强类型聚合函数Aggregator)是一个泛型类,它允许我们定义输入和缓冲区的数据类型,并在编译时进行类型检查。我们需要提供两个函数:zero 函数用于初始化聚合缓冲区,reduce 函数用于更新聚合缓冲区。
注意:从 Spark3.0 版本后,UserDefinedAggregateFunction 已经不推荐使用了。可以统一采用强类型聚合函数 Aggregator

2,使用样例

(1)下面我们使用强类型聚合函数计算一组学生的平均分和及格率,首先我们定义如下内容:
  • 定义一个 Student 类来表示学生的数据
  • 定义一个 StudentStats 类来表示学生的统计信息,其中包括平均分和及格率。
  • 定义一个强类型聚合函数 StudentStatsAggregator,它使用一个三元组(Double,Long,Long)作为缓冲区,分别存储总分、及格人数和总人数。在 reduce 方法中,我们根据每个学生的成绩更新缓冲区的值。merge 方法用于合并两个缓冲区的值。在 finish 方法中,我们根据缓冲区的值计算平均分和及格率,并返回 StudentStats 对象。
// 表示学生的数据结构
case class Student(id: Int, name: String, score: Double)

// 表示学生统计信息的数据结构,包括平均分和及格率
case class StudentStats(averageScore: Double, passRate: Double)

// 定义一个强类型聚合函数(类继承 org.apache.spark.sql.expressions.Aggregator)
class StudentStatsAggregator extends Aggregator[Student, (Double, Long, Long), StudentStats] {
  // 初始缓冲区的值,由三元组 (总分,及格人数,总人数) 组成
  def zero: (Double, Long, Long) = (0.0, 0L, 0L)

  // 更新缓冲区的值,根据每个学生的成绩进行更新
  def reduce(buffer: (Double, Long, Long), student: Student): (Double, Long, Long) = {
    val (totalScore, passCount, totalCount) = buffer
    val newTotalScore = totalScore + student.score
    val newPassCount = if (student.score >= 60) passCount + 1 else passCount
    val newTotalCount = totalCount + 1
    (newTotalScore, newPassCount, newTotalCount)
  }

  // 合并两个缓冲区的值
  def merge(b1: (Double, Long, Long), b2: (Double, Long, Long)): (Double, Long, Long) = {
    (b1._1 + b2._1, b1._2 + b2._2, b1._3 + b2._3)
  }

  // 根据缓冲区的值计算学生的平均分和及格率,并返回 StudentStats 对象
  def finish(reduction: (Double, Long, Long)): StudentStats = {
    val (totalScore, passCount, totalCount) = reduction
    val averageScore = totalScore / totalCount
    val passRate = passCount.toDouble / totalCount
    StudentStats(averageScore, passRate)
  }

  // Spark 用于序列化缓冲区的编码器
  def bufferEncoder: Encoder[(Double, Long, Long)] = Encoders.product

  // Spark 用于序列化输出结果的编码器
  def outputEncoder: Encoder[StudentStats] = Encoders.product
}

(2)接着我们就可以使用这个自定义的强类型聚合函数来计算学生的平均分和及格率:
object Hello {
  def main(args: Array[String]): Unit = {
    // 创建 Spark 运行配置对象
    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("Hello")
    //创建 SparkSession 对象
    val spark: SparkSession = SparkSession.builder().config(sparkConf).getOrCreate()

    // 导入隐式转换
    import spark.implicits._

    // 读取原始数据,创建Dataset
    val ds = Seq(
      Student(1, "小刘", 85.0),
      Student(2, "小李", 62.5),
      Student(3, "老余", 90.0),
      Student(4, "老杨", 45.0)
    ).toDS()

    // 创建聚合函数
    val studentStatsAggregator = new StudentStatsAggregator

    //将聚合函数转换为查询的列进行查询
    val stats = ds.select(studentStatsAggregator.toColumn).as[StudentStats].first()
    println("平均分: " + stats.averageScore)
    println("及格率: " + stats.passRate)

    //关闭 Spark
    spark.stop()
  }
}

(3)当然我们也可以不使用 DSL 语法而是使用 SQL 语法,最终得到的结果是一样的:
object Hello {
  def main(args: Array[String]): Unit = {
    // 创建 Spark 运行配置对象
    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("Hello")
    //创建 SparkSession 对象
    val spark: SparkSession = SparkSession.builder().config(sparkConf).getOrCreate()

    // 导入隐式转换
    import spark.implicits._

    // 读取原始数据,创建Dataset
    val ds = Seq(
      Student(1, "小刘", 85.0),
      Student(2, "小李", 62.5),
      Student(3, "老余", 90.0),
      Student(4, "老杨", 45.0)
    ).toDS()

    // 创建聚合函数
    val studentStatsAggregator = new StudentStatsAggregator
    // 注册UDF
    spark.udf.register("studentStatsAggregator", functions.udaf(studentStatsAggregator))

    // 创建临时表
    ds.createOrReplaceTempView("students")

    // 使用SQL语法进行查询
    val resultDS:Dataset[Row] = spark.sql(
      "SELECT studentStatsAggregator(*) as stats, " +
        "stats.averageScore as averageScore, " +
        "stats.passRate as passRate " +
        "FROM students"
    )
    resultDS.show()

    // 将 Dataset[Row] 转换成 Dataset[StudentStats]
    val statsDS: Dataset[StudentStats] = resultDS.as[StudentStats]
    statsDS.show()

    // 获取 StudentStats
    val stats = statsDS.first()
    println("平均分: " + stats.averageScore)
    println("及格率: " + stats.passRate)

    //关闭 Spark
    spark.stop()
  }
}
评论

全部评论(0)

回到顶部