Skip to content

Commit

Permalink
Go: introduce special GoOutputWriter and move generation of error che…
Browse files Browse the repository at this point in the history
…ck to it

This will allow to override this method and do not generate check in KST when
performing checks for returned errors using `asserts[i].exception` key
  • Loading branch information
Mingun committed Apr 13, 2024
1 parent 128e179 commit 2dadf6c
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -950,7 +950,7 @@ class TranslatorSpec extends AnyFunSpec {
eo = Some(Expressions.parse(src))
}

val goOutput = new StringLanguageOutputWriter(" ")
val goOutput = new GoOutputWriter(" ")

val langs = ListMap[LanguageCompilerStatic, AbstractTranslator with TypeDetector](
CppCompiler -> new CppTranslator(tp, new CppImportList(), new CppImportList(), RuntimeConfig()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import io.kaitai.struct.datatype._
import io.kaitai.struct.exprlang.Ast
import io.kaitai.struct.format._
import io.kaitai.struct.languages.components._
import io.kaitai.struct.translators.{GoTranslator, ResultString, TranslatorResult}
import io.kaitai.struct.translators.{GoOutputWriter, GoTranslator, ResultString, TranslatorResult}
import io.kaitai.struct.{ClassTypeProvider, RuntimeConfig, Utils}

class GoCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
Expand All @@ -19,6 +19,7 @@ class GoCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
with GoReads {
import GoCompiler._

override val out = new GoOutputWriter(indent)
override val translator = new GoTranslator(out, typeProvider, importList)

override def innerClasses = false
Expand Down Expand Up @@ -270,21 +271,21 @@ class GoCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)

override def pushPos(io: String): Unit = {
out.puts(s"_pos, err := $io.Pos()")
translator.outAddErrCheck()
out.putsErrCheck(translator.returnRes)
}

override def seek(io: String, pos: Ast.expr): Unit = {
importList.add("io")

out.puts(s"_, err = $io.Seek(int64(${expression(pos)}), io.SeekStart)")
translator.outAddErrCheck()
out.putsErrCheck(translator.returnRes)
}

override def popPos(io: String): Unit = {
importList.add("io")

out.puts(s"_, err = $io.Seek(_pos, io.SeekStart)")
translator.outAddErrCheck()
out.putsErrCheck(translator.returnRes)
}

override def alignToByte(io: String): Unit =
Expand All @@ -306,7 +307,7 @@ class GoCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)

val eofVar = translator.allocateLocalVar()
out.puts(s"${translator.localVarName(eofVar)}, err := this._io.EOF()")
translator.outAddErrCheck()
out.putsErrCheck(translator.returnRes)
out.puts(s"if ${translator.localVarName(eofVar)} {")
out.inc
out.puts("break")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,30 @@ sealed trait TranslatorResult
case class ResultString(s: String) extends TranslatorResult
case class ResultLocalVar(n: Int) extends TranslatorResult

class GoTranslator(out: StringLanguageOutputWriter, provider: TypeProvider, importList: ImportList)
class GoOutputWriter(indentStr: String) extends StringLanguageOutputWriter(indentStr) {
/**
* Puts to the output code to check variable `err` for error value and emit
* a premature return with value of error.
*
* @param result If not none this value will be returned as the first value of
* the returned tuple, otherwise only `err` is returned
*/
def putsErrCheck(result: Option[String]): Unit = {
puts("if err != nil {")
inc

val noValueAndErr = result match {
case None => "err"
case Some(r) => s"$r, err"
}

puts(s"return $noValueAndErr")
dec
puts("}")
}
}

class GoTranslator(out: GoOutputWriter, provider: TypeProvider, importList: ImportList)
extends TypeDetector(provider)
with AbstractTranslator
with CommonLiterals
Expand All @@ -26,6 +49,10 @@ class GoTranslator(out: StringLanguageOutputWriter, provider: TypeProvider, impo

import io.kaitai.struct.languages.GoCompiler._

/**
* Dummy return value that should be returned in case of error just because
* we cannot return nothing.
*/
var returnRes: Option[String] = None

override def translate(v: Ast.expr, extPrec: Int): String = resToStr(translateExpr(v, extPrec))
Expand Down Expand Up @@ -470,7 +497,7 @@ class GoTranslator(out: StringLanguageOutputWriter, provider: TypeProvider, impo
val addParams = t.args.map((a) => translate(a)).mkString(", ")
out.puts(s"${localVarName(v)} := New${GoCompiler.types2class(t.classSpec.get.name)}($addParams)")
out.puts(s"err = ${localVarName(v)}.Read($io, $parent, $root)")
outAddErrCheck()
out.putsErrCheck(returnRes)
ResultLocalVar(v)
}

Expand Down Expand Up @@ -502,7 +529,7 @@ class GoTranslator(out: StringLanguageOutputWriter, provider: TypeProvider, impo
def outVarCheckRes(expr: String): ResultLocalVar = {
val v1 = allocateLocalVar()
out.puts(s"${localVarName(v1)}, err := $expr")
outAddErrCheck()
out.putsErrCheck(returnRes)
ResultLocalVar(v1)
}

Expand All @@ -521,20 +548,6 @@ class GoTranslator(out: StringLanguageOutputWriter, provider: TypeProvider, impo

def localVarName(n: Int) = s"tmp$n"

def outAddErrCheck(): Unit = {
out.puts("if err != nil {")
out.inc

val noValueAndErr = returnRes match {
case None => "err"
case Some(r) => s"$r, err"
}

out.puts(s"return $noValueAndErr")
out.dec
out.puts("}")
}

override def byteSizeOfValue(attrName: String, valType: DataType): TranslatorResult =
trIntLiteral(CommonSizeOf.bitToByteSize(CommonSizeOf.getBitsSizeOfType(attrName, valType)))
}

0 comments on commit 2dadf6c

Please sign in to comment.