diff --git a/build.sc b/build.sc
index cda2106..d430e0d 100644
--- a/build.sc
+++ b/build.sc
@@ -277,6 +277,7 @@ class AlmondSpark(val crossScalaVersion: String) extends CrossSbtModule with Amm
     Deps.scalaKernelApi
       .exclude(("com.lihaoyi", s"ammonite-compiler_$crossScalaVersion"))
       .exclude(("com.lihaoyi", s"ammonite-repl-api_$crossScalaVersion")),
+    Deps.scalatags,
     Deps.sparkSql(scalaVersion())
   )
   def repositoriesTask = T.task {
diff --git a/modules/almond-spark/src/main/scala/almond/spark/DataFrameRenderers.scala b/modules/almond-spark/src/main/scala/almond/spark/DataFrameRenderers.scala
new file mode 100644
index 0000000..e8911be
--- /dev/null
+++ b/modules/almond-spark/src/main/scala/almond/spark/DataFrameRenderers.scala
@@ -0,0 +1,49 @@
+package almond.spark
+
+import org.apache.spark.sql.{Dataset, Row}
+
+object DataFrameRenderer {
+
+  /** Renders the first `limit` rows of `df` as an HTML table.
+    *
+    * Rows are fetched with `Dataset.take` and turned into HTML afterwards, so no
+    * scalatags values are created inside a Spark task or shipped back from one.
+    *
+    * Inspired by
+    * https://github.com/apache/incubator-toree/blob/5b19aac2e56a56d35c888acc4ed5e549b1f4ed7c/kernel/src/main/scala/org/apache/toree/utils/DataFrameConverter.scala#L47-L59
+    *
+    * @param df    the data frame to render
+    * @param limit maximum number of rows to render; negative values render no rows
+    * @return the HTML table: one header row of column names, then one row per record
+    */
+  def render(df: Dataset[Row], limit: Int = 10): String = {
+    import scalatags.Text.all._
+    val header = tr(df.schema.fieldNames.toSeq.map(th(_)))
+    // Clamp the limit: the former RDD-based take returned no rows for negative values,
+    // while a negative Dataset limit is expected to be rejected by Spark.
+    val rows = df
+      .take(math.max(limit, 0))
+      .toSeq
+      .map(row => tr(row.toSeq.map(fieldToString).map(td(_))))
+    table(header, rows).render
+  }
+
+  /** Converts one field value to the text shown in a table cell.
+    *
+    * `null` becomes `"null"`. Sequences and arrays are rendered as `[a, b, c]`, formatting
+    * each element with this same method so nested collections and nulls look consistent.
+    * Any other value falls back to its `toString`.
+    *
+    * Adapted from
+    * https://github.com/apache/incubator-toree/blob/5b19aac2e56a56d35c888acc4ed5e549b1f4ed7c/kernel/src/main/scala/org/apache/toree/utils/DataFrameConverter.scala#L84-L89
+    */
+  def fieldToString(any: Any): String =
+    any match {
+      case null        => "null"
+      case seq: Seq[_] => seq.iterator.map(fieldToString).mkString("[", ", ", "]")
+      // An array's own toString is just an identity hash, so format its elements instead.
+      case arr: Array[_] => arr.iterator.map(fieldToString).mkString("[", ", ", "]")
+      case other       => other.toString
+    }
+
+}
diff --git a/project/deps.sc b/project/deps.sc
index 091aa99..6d2b1dc 100644
--- a/project/deps.sc
+++ b/project/deps.sc
@@ -23,6 +23,7 @@ object Deps {
     ivy"com.github.plokhotnyuk.jsoniter-scala::jsoniter-scala-macros:${Versions.jsoniterScala}"
   def log4j2 = ivy"org.apache.logging.log4j:log4j-core:2.17.2"
   def scalaKernelApi = ivy"sh.almond:::scala-kernel-api:0.14.0-RC6"
+  def scalatags = ivy"com.lihaoyi::scalatags:0.12.0"
   def sparkSql(sv: String) = {
     val ver =
       if (sv.startsWith("2.12.")) "2.4.0"