微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

在 Spark mapPartitions 中使用迭代器进行优化

在使用 mapPartitions 时，往往会伴随着文件的创建或数据库连接的建立等操作。常见写法是再创建一个容器（如 ListBuffer），用于存储维表关联后的数据，但这有一个问题：整个分区关联后的结果都会先缓存在这个容器中，分区较大时会占用大量内存，甚至导致 OOM。这时我们可以改为返回一个迭代器进行优化，逐条处理数据而不缓存整个分区。

 

一、常规写法

package org.shydow

import java.sql.{Connection, PreparedStatement}

import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}
import org.shydow.DBPool.MysqLPoolManager

import scala.collection.mutable.ListBuffer

/**
 * @author shydow
 * @date 2021-12-13
 * @desc Typical mapPartitions usage: one DB connection per partition, with every
 *       enriched record buffered in a ListBuffer before the partition is returned.
 */
object TestMapPartition {

  case class Event(eventId: String, eventName: String, pv: Long, stayTime: String)

  def main(args: Array[String]): Unit = {

    val conf: SparkConf = new SparkConf().setAppName("test-mapPartition").setMaster("local[*]")
    val sc = new SparkContext(conf)

    val lines: RDD[String] = sc.textFile("/app/event_log.txt", 4)
    val events: RDD[Event] = lines.mapPartitions { it =>
      // One connection and one prepared statement per partition rather than per record.
      val conn: Connection = MysqLPoolManager.getMysqLManager.getConnection
      // The pool returns null when it cannot hand out a connection; fail with a clear
      // message instead of an NPE (which the cleanup below would otherwise mask).
      if (conn == null) throw new IllegalStateException("could not obtain a MySQL connection from the pool")
      try {
        val ps: PreparedStatement = conn.prepareStatement("select event_name from dim_event_info where event_id = ?")
        try {
          // The whole partition is materialized here before anything is returned;
          // this buffer is the memory cost the iterator-based version avoids.
          val list: ListBuffer[Event] = ListBuffer[Event]()
          while (it.hasNext) {
            val arr: Array[String] = it.next().split(",")
            ps.setString(1, arr(0))
            val res = ps.executeQuery()
            var eventName: String = null
            try {
              // Keeps the last matching row; stays null when no row matches.
              while (res.next()) {
                eventName = res.getString("event_name")
              }
            } finally {
              res.close()
            }
            list.append(Event(arr(0), eventName, arr(2).toLong, arr(3)))
          }
          // The list is fully built, so the statement/connection can be released afterwards.
          list.toIterator
        } finally {
          ps.close()
        }
      } finally {
        conn.close()
      }
    }

    // mapPartitions is lazy: without an action the lookup above never runs.
    // In local[*] mode the printed records appear on this process's console.
    events.foreach(println)

    sc.stop()
  }
}

 

 

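二、使用迭代器进行优化

思路是让 mapPartitions 返回一个自定义的 Iterator：每次调用 next() 时才读取一行并查询维表，返回一条结果，不再把整个分区缓存到容器中；当 hasNext 返回 false（分区数据已全部消费）时再关闭 PreparedStatement 和连接。注意：如果下游没有把分区完全消费（例如只 take 了部分数据），hasNext 不会返回 false，连接也就不会在这里被关闭。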

package org.shydow

import java.sql.{Connection, PreparedStatement, ResultSet}

import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}
import org.shydow.DBPool.MysqLPoolManager


/**
 * @author shydow
 * @date 2021-12-13
 * @desc Demonstrates a custom iterator for dimension-table lookups inside
 *       mapPartitions, so each partition's results are produced one record at a
 *       time instead of being buffered in a collection.
 */
object TestMapPartition {

  def main(args: Array[String]): Unit = {

    val conf: SparkConf = new SparkConf().setAppName("test-mapPartition").setMaster("local[*]")
    val sc = new SparkContext(conf)

    val lines: RDD[String] = sc.textFile("/app/event_log.txt", 4)
    // One LookupEventIter, and therefore one DB connection, per partition.
    val events: RDD[Event] = lines.mapPartitions(partition => new LookupEventIter(partition))

    // mapPartitions is lazy: an action is needed for the lookups to run at all.
    // foreach consumes every partition fully, which is also what lets
    // LookupEventIter release its connection (it does so once hasNext returns false).
    events.foreach(println)

    sc.stop()
  }
}

/**
 * One event record enriched with its dimension-table name.
 *
 * As built by LookupEventIter from a comma-separated input line:
 *
 * @param eventId   first field of the line, also the lookup key
 * @param eventName `event_name` from `dim_event_info`, or null when no row matched
 * @param pv        third field, parsed as a Long
 * @param stayTime  fourth field, kept as raw text
 */
case class Event(
  eventId: String,
  eventName: String,
  pv: Long,
  stayTime: String
)

/**
 * Lazily enriches each comma-separated line of a partition with its `event_name`
 * from `dim_event_info`, one record per `next()` call, so the partition is never
 * buffered in memory.
 *
 * One connection and one prepared statement are opened when the iterator is
 * constructed. They are released the first time `hasNext` reports that the input
 * is exhausted.
 *
 * NOTE(review): if the consumer stops early (e.g. `take`), or parsing a record
 * throws, `hasNext` never returns false, and the connection is not released here.
 * A Spark task-completion listener would cover that case.
 *
 * @param it raw partition lines; fields 0, 2 and 3 are read, and field 2 must parse
 *           as a Long (assumed layout `eventId,?,pv,stayTime` — confirm)
 */
class LookupEventIter(it: Iterator[String]) extends Iterator[Event] {

  private val conn: Connection = MysqLPoolManager.getMysqLManager.getConnection
  // The pool returns null when it cannot hand out a connection; fail clearly instead
  // of with an NPE from prepareStatement.
  if (conn == null) throw new IllegalStateException("could not obtain a MySQL connection from the pool")
  private val ps: PreparedStatement = conn.prepareStatement("select event_name from dim_event_info where event_id = ?")

  // Ensures the statement/connection are closed only once, even if hasNext is
  // called repeatedly after the input is exhausted.
  private var released: Boolean = false

  override def hasNext: Boolean = {
    val more = it.hasNext
    if (!more) release()
    more
  }

  override def next(): Event = {
    val arr: Array[String] = it.next().split(",")
    Event(arr(0), lookupEventName(arr(0)), arr(2).toLong, arr(3))
  }

  /** Returns the last matching `event_name` for `eventId`, or null when no row matches. */
  private def lookupEventName(eventId: String): String = {
    ps.setString(1, eventId)
    val res: ResultSet = ps.executeQuery()
    try {
      var eventName: String = null
      while (res.next()) {
        eventName = res.getString("event_name")
      }
      eventName
    } finally {
      res.close()
    }
  }

  /** Closes the statement and then the connection, at most once. */
  private def release(): Unit = {
    if (!released) {
      released = true
      try ps.close()
      finally conn.close()
    }
  }
}

 

三、数据库连接池（使用 c3p0）

package org.shydow.DBPool

import java.sql.Connection

/**
 * @author shydow
 * @date 2021-10-09
 *
 * Thin wrapper around a `ComboPooledDataSource` (presumably c3p0's), configured from `Constants`.
 * Construction and the public methods are best-effort: failures are printed to
 * stderr rather than thrown.
 *
 * NOTE(review): no import for `ComboPooledDataSource` is visible next to this class
 * (c3p0 provides `com.mchange.v2.c3p0.ComboPooledDataSource`); confirm the file imports it.
 */
class MysqLPool extends Serializable {

  import scala.util.control.NonFatal

  private val cpd = new ComboPooledDataSource(true)
  // Misconfiguration is only reported, not rethrown; a bad setting may surface
  // later as a failure (null result) in getConnection.
  try {
    cpd.setJdbcUrl(Constants.MysqL_URL)
    cpd.setDriverClass(Constants.MysqL_DRIVER)
    cpd.setUser(Constants.MysqL_USER)
    cpd.setPassword(Constants.MysqL_PASSWORD)
    cpd.setAcquireIncrement(Constants.MysqL_AC)
    cpd.setMinPoolSize(Constants.MysqL_MINPS)
    cpd.setMaxPoolSize(Constants.MysqL_MAXPS)
    cpd.setMaxStatements(Constants.MysqL_MAXST)
  } catch {
    case NonFatal(e) => e.printStackTrace()
  }

  /**
   * Borrows a connection from the data source.
   *
   * @return a connection, or null if acquisition failed (the error is printed);
   *         callers must check for null before use
   */
  def getConnection: Connection = {
    try {
      cpd.getConnection()
    } catch {
      case NonFatal(e) =>
        e.printStackTrace()
        null
    }
  }

  /** Closes the underlying data source; failures are printed, not thrown. */
  def close(): Unit = {
    try {
      cpd.close()
    } catch {
      case NonFatal(e) => e.printStackTrace()
    }
  }
}
package org.shydow.DBPool

/**
 * @author shydow
 * @date 2021-10-09
 *
 * Lazily creates a single shared [[MysqLPool]] on the first call to
 * `getMysqLManager` and returns that same instance on every later call.
 */
object MysqLPoolManager {
  // Cached pool. It is kept as a public var for backward compatibility;
  // getMysqLManager reads and writes it only while holding this object's monitor.
  var mm: MysqLPool = _

  /** Returns the shared pool, creating it on first use. */
  def getMysqLManager: MysqLPool = synchronized {
    if (mm == null) {
      mm = new MysqLPool
    }
    // Return the instance checked/created under the lock, not a later unguarded re-read.
    mm
  }
}

 

 

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐