IT博客汇
  • 首页
  • 精华
  • 技术
  • 设计
  • 资讯
  • 扯淡
  • 权利声明
  • 登录 注册

    Spark 实现 MySQL upsert

    hpkaiq发表于 2023-04-18 14:23:06
    love 0

    让 Spark DataFrame/Dataset 写入 MySQL 时根据表的唯一键（或主键）实现 upsert：记录已存在则更新，不存在则插入。

    基于 Spark 2.4.3、Scala 2.11.8。注意：工具类通过反射读取 Spark 内部的私有字段（如 DataFrameWriter 的 extraOptions），其他 Spark 版本中字段名可能不同，需要自行验证。

    工具类 DataFrameWriterEnhance

      1
      2
      3
      4
      5
      6
      7
      8
      9
     10
     11
     12
     13
     14
     15
     16
     17
     18
     19
     20
     21
     22
     23
     24
     25
     26
     27
     28
     29
     30
     31
     32
     33
     34
     35
     36
     37
     38
     39
     40
     41
     42
     43
     44
     45
     46
     47
     48
     49
     50
     51
     52
     53
     54
     55
     56
     57
     58
     59
     60
     61
     62
     63
     64
     65
     66
     67
     68
     69
     70
     71
     72
     73
     74
     75
     76
     77
     78
     79
     80
     81
     82
     83
     84
     85
     86
     87
     88
     89
     90
     91
     92
     93
     94
     95
     96
     97
     98
     99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    
    package com.xxx.utils
    
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.execution.SQLExecution
    import org.apache.spark.sql.execution.datasources.DataSource
    import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcOptionsInWrite, JdbcRelationProvider, JdbcUtils}
    import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
    import org.apache.spark.sql.sources.BaseRelation
    import org.apache.spark.sql.types.StructType
    import org.apache.spark.sql._
    
    import java.sql.Connection
    
    /**
     * Adds a MySQL upsert action (`update()`) to `DataFrameWriter[Row]`.
     *
     * The write goes through Spark's regular data-source write path, but with
     * [[DataFrameWriterEnhance.MysqlUpdateRelationProvider]] in place of the stock JDBC provider, so rows
     * are written with `INSERT ... ON DUPLICATE KEY UPDATE` instead of a plain `INSERT`.
     *
     * NOTE: this reads private internals of `DataFrameWriter` and `Dataset` via reflection; the field
     * names were taken from Spark 2.4.x and may differ in other versions.
     */
    object DataFrameWriterEnhance {

      /** Writer option that, when set to "true", prints the generated upsert statement on the driver. */
      private val ShowSqlOption = "showSql"

      /** Prefix scalac prepends to DataFrameWriter private fields that are referenced from closures. */
      private val MangledWriterFieldPrefix = "org$apache$spark$sql$DataFrameWriter$$"

      implicit class DataFrameWriterMysqlUpdateEnhance(writer: DataFrameWriter[Row]) {

        /**
         * Executes the write as a MySQL upsert into an existing (or newly created) table.
         *
         * Equivalent to `update(SaveMode.Append)`. The format set with `.format(...)` and the mode set
         * with `.mode(...)` on the writer are not consulted; all `.option(...)` values are forwarded.
         */
        def update(): Unit = update(SaveMode.Append)

        /**
         * Executes the write through [[MysqlUpdateRelationProvider]] with an explicit save mode.
         *
         * @param saveMode how an existing table is treated: Append upserts into it, Overwrite truncates
         *                 or recreates it before upserting, ErrorIfExists throws, Ignore writes nothing.
         */
        def update(saveMode: SaveMode): Unit = {
          val extraOptions = writerField[scala.collection.Map[String, String]]("extraOptions")
          val df = writerField[DataFrame]("df")
          val partitioningColumns = writerField[Option[Seq[String]]]("partitioningColumns")

          // Dataset.logicalPlan is not accessible from outside org.apache.spark.sql, so read it reflectively.
          val logicalPlanField = df.getClass.getDeclaredField("logicalPlan")
          logicalPlanField.setAccessible(true)
          val inputPlan = logicalPlanField.get(df).asInstanceOf[LogicalPlan]

          val session = df.sparkSession
          val dataSource = DataSource(
            sparkSession = session,
            className = classOf[MysqlUpdateRelationProvider].getName,
            partitionColumns = partitioningColumns.getOrElse(Nil),
            options = extraOptions.toMap)
          val writePlan = dataSource.planForWriting(saveMode, inputPlan)
          val qe = session.sessionState.executePlan(writePlan)
          // Materialising toRdd is what actually runs the write command.
          SQLExecution.withNewExecutionId(session, qe)(qe.toRdd)
        }

        /**
         * Reads a private field of the wrapped writer.
         *
         * Tries the plain field name first and falls back to the scalac-mangled name, which is used when
         * the field is referenced from a closure inside DataFrameWriter.
         *
         * @throws NoSuchFieldException if neither name exists on the writer's class
         */
        private def writerField[T](name: String): T = {
          val cls = writer.getClass
          val field =
            try cls.getDeclaredField(name)
            catch {
              case _: NoSuchFieldException => cls.getDeclaredField(MangledWriterFieldPrefix + name)
            }
          field.setAccessible(true)
          field.get(writer).asInstanceOf[T]
        }
      }

      /**
       * JDBC relation provider whose write path emits MySQL upserts
       * (`INSERT ... ON DUPLICATE KEY UPDATE`) instead of plain inserts.
       *
       * Table existence checks, truncation, drop/create and the read-side relation are delegated to
       * Spark's JdbcUtils and JdbcRelationProvider.
       */
      class MysqlUpdateRelationProvider extends JdbcRelationProvider {

        /**
         * Writes `df` according to `mode`, then returns a relation for reading the table back.
         *
         *  - table missing: create it from `df.schema`, then upsert
         *  - Overwrite: truncate (when truncation is requested and does not cascade) or drop and
         *    recreate the table, then upsert
         *  - Append: upsert into the existing table
         *  - ErrorIfExists: throw
         *  - Ignore: leave the table untouched
         */
        override def createRelation(sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String], df: DataFrame): BaseRelation = {
          val options = new JdbcOptionsInWrite(parameters)
          val isCaseSensitive = sqlContext.sparkSession.sessionState.conf.caseSensitiveAnalysis
          val conn = JdbcUtils.createConnectionFactory(options)()
          try {
            val tableExists = JdbcUtils.tableExists(conn, options)
            if (tableExists) {
              mode match {
                case SaveMode.Overwrite =>
                  if (options.isTruncate && JdbcUtils.isCascadingTruncateTable(options.url).contains(false)) {
                    // In this case, we should truncate table and then load.
                    JdbcUtils.truncateTable(conn, options)
                    val tableSchema = JdbcUtils.getSchemaOption(conn, options)
                    updateTable(df, tableSchema, isCaseSensitive, options)
                  } else {
                    // Otherwise, do not truncate the table, instead drop and recreate it
                    JdbcUtils.dropTable(conn, options.table, options)
                    JdbcUtils.createTable(conn, df, options)
                    updateTable(df, Some(df.schema), isCaseSensitive, options)
                  }

                case SaveMode.Append =>
                  val tableSchema = JdbcUtils.getSchemaOption(conn, options)
                  updateTable(df, tableSchema, isCaseSensitive, options)

                case SaveMode.ErrorIfExists =>
                  throw new Exception(
                    s"Table or view '${options.table}' already exists. " +
                      s"SaveMode: ErrorIfExists.")

                case SaveMode.Ignore =>
                // With `SaveMode.Ignore` mode, if table already exists, the save operation is expected
                // to not save the contents of the DataFrame and to not change the existing data.
                // Therefore, it is okay to do nothing here and then just return the relation below.
              }
            } else {
              JdbcUtils.createTable(conn, df, options)
              updateTable(df, Some(df.schema), isCaseSensitive, options)
            }
          } finally {
            conn.close()
          }

          createRelation(sqlContext, parameters)
        }

        /**
         * Upserts every row of `df` into `options.table`, one JDBC connection per (possibly coalesced)
         * partition.
         *
         * @param tableSchema schema of the existing table, used to map DataFrame columns onto the table's
         *                    column spelling; `None` uses the DataFrame's column names as-is
         * @throws IllegalArgumentException if the `numPartitions` option is <= 0
         */
        def updateTable(df: DataFrame,
                        tableSchema: Option[StructType],
                        isCaseSensitive: Boolean,
                        options: JdbcOptionsInWrite): Unit = {
          val table = options.table
          val dialect = JdbcDialects.get(options.url)
          val rddSchema = df.schema
          val getConnection: () => Connection = JdbcUtils.createConnectionFactory(options)
          val batchSize = options.batchSize
          val isolationLevel = options.isolationLevel

          val updateStmt = getUpdateStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect)
          // Only echo the statement when the caller asked for it via the showSql option.
          if (options.parameters.get(ShowSqlOption).exists(_.trim.equalsIgnoreCase("true"))) {
            println(updateStmt)
          }

          val repartitionedDF = options.numPartitions match {
            case Some(n) if n <= 0 => throw new IllegalArgumentException(
              s"Invalid value `$n` for parameter `${JDBCOptions.JDBC_NUM_PARTITIONS}` in table writing " +
                "via JDBC. The minimum value is 1.")
            case Some(n) if n < df.rdd.partitions.length => df.coalesce(n)
            case _ => df
          }
          // The closure captures only the serializable locals above, never this provider.
          repartitionedDF.rdd.foreachPartition(iterator => JdbcUtils.savePartition(
            getConnection, table, iterator, rddSchema, updateStmt, batchSize, dialect, isolationLevel,
            options)
          )
        }

        /**
         * Builds the parameterised upsert statement
         * `INSERT INTO table (c1,...) VALUES (?,...) ON DUPLICATE KEY UPDATE c1=VALUES(c1),...`
         * covering every column of `rddSchema`, in `rddSchema` order.
         *
         * When `tableSchema` is known, each column name is replaced by the matching name from the
         * table (case-sensitive or not per `isCaseSensitive`) so the table's own spelling is used.
         * NOTE(review): MySQL 8.0.20+ deprecates VALUES() in ON DUPLICATE KEY UPDATE; it is kept here
         * for compatibility with older servers.
         *
         * @throws IllegalArgumentException if `rddSchema` has no columns
         * @throws Exception if a DataFrame column has no counterpart in `tableSchema`
         */
        def getUpdateStatement(table: String,
                               rddSchema: StructType,
                               tableSchema: Option[StructType],
                               isCaseSensitive: Boolean,
                               dialect: JdbcDialect): String = {
          require(rddSchema.nonEmpty, s"Cannot build an upsert statement for table $table with no columns")

          val columnNames: Seq[String] = tableSchema match {
            case None => rddSchema.fieldNames.toSeq
            case Some(schema) =>
              val columnNameEquality = if (isCaseSensitive) {
                org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
              } else {
                org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
              }
              // Follow rddSchema's column order but the table's column names, so case-sensitive
              // identifiers in the existing table are respected.
              val tableColumnNames = schema.fieldNames
              rddSchema.fieldNames.toSeq.map { name =>
                tableColumnNames.find(f => columnNameEquality(f, name)).getOrElse {
                  throw new Exception(s"""Column "$name" not found in schema $tableSchema""")
                }
              }
          }

          // Keep the quoted identifiers as a sequence: re-splitting a joined string would break on
          // identifiers that themselves contain commas.
          val quotedColumns = columnNames.map(dialect.quoteIdentifier)
          val placeholders = Seq.fill(quotedColumns.size)("?").mkString(",")
          val assignments = quotedColumns.map(c => s"$c=VALUES($c)").mkString(",")
          s"""INSERT INTO $table (${quotedColumns.mkString(",")}) VALUES ($placeholders)
             |ON DUPLICATE KEY UPDATE
             |$assignments
             |""".stripMargin
        }
      }
    }
    

    工具类 MysqlUtils

     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    
    package com.xxx.utils
    
    import com.xxx.utils.DataFrameWriterEnhance.DataFrameWriterMysqlUpdateEnhance
    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.types.{NullType, ShortType}
    import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
    
    /**
     * Convenience entry point for upserting a DataFrame into MySQL via DataFrameWriterEnhance's
     * `update()`.
     *
     * Connection settings are read from the Spark runtime conf under
     * `spark.job.mysql.<database>.url`, `.username` and `.password`.
     */
    object MysqlUtils {

      /**
       * Upserts `rawDF` into `tableName` of the MySQL instance configured for `database`.
       *
       * Top-level columns of NullType are cast to ShortType first (presumably so the JDBC writer can
       * map them to a SQL column type).
       *
       * @param rawDF     rows to write
       * @param database  key selecting the `spark.job.mysql.<database>.*` conf entries
       * @param tableName target table, passed as the `dbtable` option
       * @param spark     session whose runtime conf holds the connection settings
       */
      def upsert(rawDF: DataFrame, database: String, tableName: String)(implicit spark: SparkSession): Unit = {
        val df = rawDF.schema.fields
          .filter(_.dataType == NullType)
          .foldLeft(rawDF) { (acc, field) =>
            // Backtick-quote the name so a column such as "a.b" is not parsed as a nested field access.
            val quotedName = s"`${field.name.replace("`", "``")}`"
            acc.withColumn(field.name, col(quotedName).cast(ShortType))
          }

        val confPrefix = s"spark.job.mysql.$database"
        df.write
          .format("jdbc")
          .mode(SaveMode.Append)
          .option("driver", "com.mysql.jdbc.Driver")
          .option("url", spark.conf.get(s"$confPrefix.url"))
          .option("user", spark.conf.get(s"$confPrefix.username"))
          .option("password", spark.conf.get(s"$confPrefix.password"))
          .option("dbtable", tableName)
          // NOTE(review): SSL is disabled unconditionally; consider making this configurable.
          .option("useSSL", "false")
          .option("showSql", "false")
          // A single partition means all rows are written through one JDBC connection.
          .option("numPartitions", "1")
          .update()
      }
    }
    

    使用

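    在 spark-submit 启动脚本中加入 MySQL 连接配置。配置键格式为 spark.job.mysql.<database>.url / username / password，与 MysqlUtils 中读取的键一致（下例中 database 为 test）：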

     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    
    # jdbc_url, jdbc_username, jdbc_password and app_jar (path to the application jar) must be set beforehand.
    # Values are quoted so spaces or glob characters such as '?' in the JDBC URL are passed through literally.
    # NOTE(review): a password passed on the command line is visible in process listings (ps).
    spark-submit \
    --master yarn \
    --deploy-mode cluster \
    --executor-memory 3G \
    --num-executors 5 \
    --executor-cores 4 \
    --driver-memory 3G \
    --conf "spark.job.mysql.test.url=${jdbc_url}" \
    --conf "spark.job.mysql.test.username=${jdbc_username}" \
    --conf "spark.job.mysql.test.password=${jdbc_password}" \
    --class TestMysqlUpsert \
    "${app_jar}"
    

    使用范例

     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    
    import com.xxx.utils.MysqlUtils
    import org.apache.spark.sql.SparkSession

    /** Example job: upserts two rows into `test_unique_key` of the MySQL database configured as "test". */
    object TestMysqlUpsert {
      def main(args: Array[String]): Unit = {
        implicit val spark: SparkSession = SparkSession.builder().enableHiveSupport().getOrCreate()
        import spark.implicits._

        try {
          val database = "test"
          // (key_one, key_two) forms the table's unique key: existing pairs are updated, new pairs inserted.
          val arr = Array((1, 11, "name1", 11111), (2, 22, "name2", 22222))
          val df = spark.sparkContext.parallelize(arr)
            .toDF("key_one", "key_two", "val_one", "val_two")

          MysqlUtils.upsert(df, database, "test_unique_key")
        } finally {
          // Release the session even if the write fails.
          spark.close()
        }
      }
    }
    

    test_unique_key 表结构（联合唯一键 uk(key_one, key_two) 决定一条记录是更新还是插入）

    1
    2
    3
    4
    5
    6
    7
    
    -- Target table for the TestMysqlUpsert example.
    -- NOTE(review): MySQL's `utf8` charset is the 3-byte utf8mb3; consider utf8mb4 if val_one may hold 4-byte characters.
    CREATE TABLE `test_unique_key` (
      `key_one` int(11) NOT NULL DEFAULT '0',
      `key_two` int(11) NOT NULL DEFAULT '0',
      `val_one` varchar(50) DEFAULT NULL,
      `val_two` int(11) NOT NULL DEFAULT '0',
      -- Composite unique key: a row whose (key_one, key_two) already exists triggers the
      -- ON DUPLICATE KEY UPDATE branch of the upsert instead of a new insert.
      UNIQUE KEY `uk` (`key_one`,`key_two`)
    ) ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='test';
    

    参考

    CSDN：《Spark Upsert 写入 MySQL（Scala 增强，无需额外依赖）》



沪ICP备19023445号-2号
友情链接