package is.hail.expr.types.virtual

import is.hail.annotations._
import is.hail.check.{Arbitrary, Gen}
import is.hail.expr.ir.IRParser
import is.hail.expr.types._
import is.hail.expr.types.physical.PType
import is.hail.expr.{JSONAnnotationImpex, SparkAnnotationImpex}
import is.hail.utils
import is.hail.utils._
import is.hail.variant.ReferenceGenome
import org.apache.spark.sql.types.DataType
import org.json4s.JsonAST.JString
import org.json4s.{CustomSerializer, JValue}

import scala.reflect.ClassTag

/**
 * json4s serializer for virtual [[Type]]s.
 *
 * A type is written as a JSON string holding its parsable rendering, and read
 * back by running that string through the IR type parser.
 */
class TypeSerializer extends CustomSerializer[Type](_ => {
  // JSON string -> Type
  val deserialize: PartialFunction[JValue, Type] = {
    case JString(s) => IRParser.parseType(s)
  }
  // Type -> JSON string
  val serialize: PartialFunction[Any, JValue] = {
    case t: Type => JString(t.parsableString())
  }
  (deserialize, serialize)
})

/** Random generators (for property-based tests) over virtual [[Type]]s and their values. */
object Type {
  /** Chooses uniformly among the non-parameterized virtual types. */
  def genScalar(): Gen[Type] =
    Gen.oneOf(TBoolean, TInt32, TInt64, TFloat32,
      TFloat64, TString, TCall)

  /** Chooses uniformly among a `TLocus` for each registered reference genome, plus `TCall`. */
  def genComplexType(): Gen[ComplexType] = {
    val rgDependents = ReferenceGenome.references.values.toArray.map(rg =>
      TLocus(rg))
    val others = Array(TCall)
    Gen.oneOfSeq(rgDependents ++ others)
  }

  /**
   * Generates struct/tuple fields: (identifier, type) pairs whose names are
   * pairwise distinct, numbered by their position in the array.
   *
   * @param genFieldType generator for each field's type
   */
  def genFields(genFieldType: Gen[Type]): Gen[Array[Field]] = {
    Gen.buildableOf[Array](
      Gen.zip(Gen.identifier, genFieldType))
      // candidates with duplicate field names are rejected
      .filter(fields => fields.map(_._1).areDistinct())
      .map(fields => fields
        .iterator
        .zipWithIndex
        .map { case ((k, t), i) => Field(k, t, i) }
        .toArray)
  }

  /** Generates a `TStruct` whose field types come from `genFieldType`. */
  def preGenStruct(genFieldType: Gen[Type]): Gen[TStruct] = {
    for (fields <- genFields(genFieldType)) yield {
      TStruct(fields)
    }
  }

  /** Generates a `TTuple` whose element types come from `genFieldType` (the generated names are discarded). */
  def preGenTuple(genFieldType: Gen[Type]): Gen[TTuple] = {
    for (fields <- genFields(genFieldType)) yield {
      TTuple(fields.map(_.typ): _*)
    }
  }

  private val defaultRequiredGenRatio = 0.2

  /**
   * Generates a struct with arbitrary field types.
   *
   * The leading coin flip's outcome is not used. NOTE(review): this looks vestigial
   * (likely from when types carried a required flag). It is kept so that seeded
   * generators keep producing the same sequence.
   */
  def genStruct: Gen[TStruct] = Gen.coin(defaultRequiredGenRatio).flatMap(_ => preGenStruct(genArb))

  /**
   * Generates an arbitrary type for a given size budget.
   *
   *  - `size < 1`: the empty struct
   *  - `size < 2`: a scalar type
   *  - otherwise: a weighted choice between scalars (weight 4) and composite types
   *    (weight 1 each).
   *
   * Nested element types come from [[genArb]], which reads the size again from the
   * generator context. Only the top-level struct branch uses `genTStruct`.
   *
   * @param size       size budget for this type
   * @param genTStruct generator used for the struct branch
   */
  def genSized(size: Int, genTStruct: Gen[TStruct]): Gen[Type] =
    if (size < 1)
      Gen.const(TStruct.empty)
    else if (size < 2)
      genScalar()
    else {
      Gen.frequency(
        (4, genScalar()),
        (1, genComplexType()),
        (1, genArb.map {
          TArray(_)
        }),
        (1, genArb.map {
          TSet(_)
        }),
        (1, genArb.map {
          TInterval(_)
        }),
        (1, preGenTuple(genArb)),
        (1, Gen.zip(genRequired, genArb).map { case (k, v) => TDict(k, v) }),
        (1, genTStruct.resize(size)))
    }

  /**
   * Builds a size-driven type generator.
   *
   * The default expression `genStruct` refers to the object member [[Type.genStruct]],
   * not to this parameter: a default cannot see parameters from its own clause.
   */
  def preGenArb(genStruct: Gen[TStruct] = genStruct): Gen[Type] =
    Gen.sized(genSized(_, genStruct))

  /** Generates an arbitrary virtual type. */
  def genArb: Gen[Type] = preGenArb()

  // genOptional and genRequired are currently the same generator as genArb; the
  // separate names are kept for callers.
  val genOptional: Gen[Type] = preGenArb()

  val genRequired: Gen[Type] = preGenArb()

  /**
   * Generates a struct type together with a value of that type.
   *
   * Only 10–30% of the size budget goes to the type, and the remainder goes to
   * the value.
   */
  def genWithValue: Gen[(Type, Annotation)] = for {
    s <- Gen.size
    // prefer smaller type and bigger values
    fraction <- Gen.choose(0.1, 0.3)
    x = (fraction * s).toInt
    y = s - x
    t <- Type.genStruct.resize(x)
    v <- t.genValue.resize(y)
  } yield (t, v)

  // Implicit definitions should declare their result type explicitly (required in
  // Scala 3; avoids inference-order surprises during implicit search in Scala 2).
  implicit def arbType: Arbitrary[Type] = Arbitrary(genArb)
}

/**
 * Base class of the virtual (logical) Hail types.
 *
 * The defaults here describe a leaf type:
 *  - no children
 *  - unification and casting by equality
 *  - no addressable sub-fields
 *
 * Subclasses override the members that differ.
 */
abstract class Type extends BaseType with Serializable {
  self =>

  /** Directly nested types; empty by default. */
  def children: Seq[Type] = FastSeq()

  /** Calls `clear()` on every child. NOTE(review): presumably resets state bound during `unify` in subclasses — confirm. */
  def clear(): Unit = children.foreach(_.clear())

  /** Attempts to unify this type with `concrete`; by default succeeds only when they are equal. */
  def unify(concrete: Type): Boolean =
    this == concrete

  /** Canonicality of this node alone; see [[isCanonical]]. */
  def _isCanonical: Boolean = true

  /** True when this type and every descendant are canonical. */
  final def isCanonical: Boolean = _isCanonical && children.forall(_.isCanonical)

  /** True when every child is bound. */
  def isBound: Boolean = children.forall(_.isBound)

  /** Substitutes bound types into this type; returns `this` by default. */
  def subst(): Type = this

  def insert(signature: Type, fields: String*): (Type, Inserter) = insert(signature, fields.toList)

  /**
   * Computes the type that results from inserting a value of type `signature` at `path`,
   * together with an [[Inserter]] that performs the insertion on values.
   *
   * A non-empty path is delegated to `TStruct.empty.insert`, so the current type is not
   * preserved. An empty path replaces the whole value with the inserted one.
   */
  def insert(signature: Type, path: List[String]): (Type, Inserter) = {
    if (path.nonEmpty)
      TStruct.empty.insert(signature, path)
    else
      (signature, (a, toIns) => toIns)
  }

  def query(fields: String*): Querier = query(fields.toList)

  /** Returns a [[Querier]] extracting the value at `path`; see [[queryTyped]]. */
  def query(path: List[String]): Querier = queryTyped(path)._2

  def queryTyped(fields: String*): (Type, Querier) = queryTyped(fields.toList)

  /**
   * Returns the type at `path` and a [[Querier]] for it.
   *
   * A leaf type only supports the empty path, which yields itself and `identity`.
   *
   * @throws AnnotationPathException if `path` is non-empty
   */
  def queryTyped(path: List[String]): (Type, Querier) = {
    if (path.nonEmpty)
      throw new AnnotationPathException(s"invalid path ${ path.mkString(".") } from type ${ this }")
    else
      (this, identity[Annotation])
  }

  /** Appends this type's pretty-printed form to `sb`. */
  final def pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit = {
    _pretty(sb, indent, compact)
  }

  def _toPretty: String

  /** Default rendering: appends [[_toPretty]] and ignores `indent` and `compact`. */
  def _pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit = {
    sb.append(_toPretty)
  }

  def fieldOption(fields: String*): Option[Field] = fieldOption(fields.toList)

  /** Looks up the field at `path`; a leaf type has none. */
  def fieldOption(path: List[String]): Option[Field] =
    None

  /** Spark SQL schema corresponding to this type. */
  def schema: DataType = SparkAnnotationImpex.exportType(this)

  /** Renders a value with `toString`, or "NA" for a missing (null) value. */
  def str(a: Annotation): String = if (a == null) "NA" else a.toString

  /** Display form of a non-missing value; defaults to [[str]]. */
  def _showStr(a: Annotation): String = str(a)

  /** Display form of a value, with "NA" for missing (null). */
  def showStr(a: Annotation): String = if (a == null) "NA" else _showStr(a)

  /**
   * Display form of a value, truncated to at most `trunc` characters.
   *
   * If the result is truncated, its last three characters are replaced by "...".
   * When `trunc < 3` there is no room for any content: the result is the first
   * `trunc` characters of "...", or "" if `trunc` is negative. (Previously this case
   * threw StringIndexOutOfBoundsException.)
   */
  def showStr(a: Annotation, trunc: Int): String = {
    val s = showStr(a)
    if (s.length <= trunc)
      s
    else if (trunc >= 3)
      s.substring(0, trunc - 3) + "..."
    else
      "...".substring(0, math.max(trunc, 0))
  }

  def toJSON(a: Annotation): JValue = JSONAnnotationImpex.exportAnnotation(a, this)

  def genNonmissingValue: Gen[Annotation] = ???

  /** Generates a value of this type; it is missing (null) with probability 0.05. */
  def genValue: Gen[Annotation] =
    Gen.nextCoin(0.05).flatMap(isEmpty => if (isEmpty) Gen.const(null) else genNonmissingValue)

  /** True when every child is realizable. */
  def isRealizable: Boolean = children.forall(_.isRealizable)

  /**
   * Compares two values for equality.
   *
   * The default is plain `==`. Float and Double values are meant to be compared more
   * loosely: either the absolute difference must be within `tolerance`, or they are
   * compared with `D_==`. Types that need this behavior override this method.
   */
  def valuesSimilar(a1: Annotation, a2: Annotation, tolerance: Double = utils.defaultTolerance, absolute: Boolean = false): Boolean = a1 == a2

  def scalaClassTag: ClassTag[_ <: AnyRef]

  /** Whether values of this type can be compared with values of `other`; by default only when the types are equal. */
  def canCompare(other: Type): Boolean = this == other

  def ordering: ExtendedOrdering

  def jsonReader: JSONReader[Annotation] = new JSONReader[Annotation] {
    def fromJSON(a: JValue): Annotation = JSONAnnotationImpex.importAnnotation(a, self)
  }

  def jsonWriter: JSONWriter[Annotation] = new JSONWriter[Annotation] {
    def toJSON(pk: Annotation): JValue = JSONAnnotationImpex.exportAnnotation(pk, self)
  }

  /*  Fundamental types are types that can be handled natively by RegionValueBuilder: primitive
      types, Array and Struct. */
  def fundamentalType: Type = this

  def _typeCheck(a: Any): Boolean

  /** True when `a` is a valid value of this type; null (missing) always type-checks. */
  final def typeCheck(a: Any): Boolean = a == null || _typeCheck(a)

  /**
   * Whether values of this type can be cast to `t`.
   *
   * Container types must have the same constructor, and their element types must
   * recursively be castable. Struct and tuple fields are matched by position only,
   * so field names are ignored. Any other type is castable only to an equal type.
   */
  def canCastTo(t: Type): Boolean = this match {
    case TInterval(tt1) => t match {
      case TInterval(tt2) => tt1.canCastTo(tt2)
      case _ => false
    }
    case TStruct(f1) => t match {
      case TStruct(f2) => f1.size == f2.size && f1.indices.forall(i => f1(i).typ.canCastTo(f2(i).typ))
      case _ => false
    }
    case TTuple(f1) => t match {
      case TTuple(f2) => f1.size == f2.size && f1.indices.forall(i => f1(i).typ.canCastTo(f2(i).typ))
      case _ => false
    }
    case TArray(t1) => t match {
      case TArray(t2) => t1.canCastTo(t2)
      case _ => false
    }
    case TSet(t1) => t match {
      case TSet(t2) => t1.canCastTo(t2)
      case _ => false
    }
    case TDict(k1, v1) => t match {
      case TDict(k2, v2) => k1.canCastTo(k2) && v1.canCastTo(v2)
      case _ => false
    }
    case _ => this == t
  }
}
