#!/usr/bin/env ruby

require "bundler/setup"
require "absynthe"
require "absynthe/python"
require "json"
require "rdl"
require "fc"
require "pp"

# Empty marker classes. They define no behavior; the names are the ones used
# by the RDL signatures declared in this file and by PyLang.unparse when it
# renders constants.

class DataFrame
end

class Series
end

# Base class for callable constants passed as arguments.
class Lambda
end

# Rendered by PyLang.unparse as 'pd.Series.nunique'.
class NUnique < Lambda
end

# Base class for constants that denote a Python type.
class Type
end

# Rendered by PyLang.unparse as 'int'.
class PyInt < Type
end

class DataFrameGroupBy
end

class NdArray
end

# Give Array a single type parameter `t` (checked with :all?), so signatures
# below can say Array<String>, Array<Integer>, ... and refer to `t`.
RDL.type_params :Array, [:t], :all?

# Method signatures made available to the synthesizer, as
# [class, method, RDL signature] triples. They are registered one by one in
# exactly the order listed; several methods appear more than once with
# different signatures, so do not reorder or de-duplicate entries.
[
  [:DataFrame, :loc_getitem, "(Array<Integer>) -> DataFrame"],
  [:DataFrame, :loc_getitem, "(Lambda) -> DataFrame"],
  [:DataFrame, :xs, "(String, level: Integer) -> DataFrame"],
  [:DataFrame, :pivot, "(columns: String, values: String, index: String) -> DataFrame"],
  [:DataFrame, :sort_values, "(by: Array<String>, ascending: Array<true or false>) -> DataFrame"],
  [:DataFrame, :sort_values, "(Array<String>, ascending: Array<true or false>) -> DataFrame"],
  [:DataFrame, :combine_first, "(DataFrame) -> DataFrame"],
  [:DataFrame, :pivot_table, "(values: String, index: String, columns: String, aggfunc: Lambda) -> DataFrame"],
  [:DataFrame, :pivot_table, "(values: String, index: String, columns: String, aggfunc: String) -> DataFrame"],
  [:DataFrame, :pivot_table, "(values: String, index: String, columns: Array<String>, aggfunc: Lambda) -> DataFrame"],
  [:DataFrame, :pivot_table, "(values: String, index: Array<String>, columns: Array<String>) -> DataFrame"],
  [:DataFrame, :merge, "(DataFrame, on: Integer) -> DataFrame"],
  [:DataFrame, :merge, "(DataFrame, on: Array<String>) -> DataFrame"],
  [:DataFrame, :__getitem__, "(Series) -> DataFrame"],
  [:DataFrame, :apply, "(Lambda, axis: 0 or 1) -> Series"],
  [:DataFrame, :query, "(String) -> DataFrame"],
  [:DataFrame, :fillna, "(method: 'ffill' or 'bfill') -> DataFrame"],
  [:DataFrame, :astype, "(Type) -> DataFrame"],
  [:DataFrame, :stack, "() -> DataFrame"],
  [:DataFrame, :unstack, "() -> DataFrame"],
  [:DataFrame, :groupby, "(Array<String>) -> DataFrameGroupBy"],
  [:DataFrameGroupBy, :size, "() -> Series"],
  [:DataFrame, :dropna, "() -> DataFrame"],
  [:DataFrame, :reset_index, "(drop: true or false) -> DataFrame"],
  # NOTE(review): pandas documents DataFrame.T as returning a DataFrame; the
  # DataFrameGroupBy return type here may be deliberate - confirm.
  [:DataFrame, :T, "() -> DataFrameGroupBy"],
  [:DataFrameGroupBy, :reset_index, "() -> DataFrame"],
  [:DataFrame, :reset_index, "() -> DataFrame"],
  [:DataFrame, :values, "() -> NdArray"],
  [:DataFrame, :sum, "(0 or 1, level: Integer) -> DataFrame"],
  [:DataFrameGroupBy, :sum, "() -> DataFrame"],
  [:DataFrame, :cumsum, "() -> DataFrame"],
  [:DataFrame, :div, "(DataFrame, 0 or 1, Integer) -> DataFrame"],
  [:DataFrame, :xs, "(String, 0 or 1, Integer) -> DataFrame"],
  [:DataFrame, :__getitem__, "(String) -> DataFrame"],
  [:DataFrame, :isin, "(Array<Integer>) -> Series"],
  [:Array, :__getitem__, "(Integer) -> t"],
  [:DataFrame, :merge, "(DataFrame, how: 'left' or 'right' or 'inner' or 'outer' or 'cross') -> DataFrame"],
  [:DataFrame, :set_index, "(Array<String>) -> DataFrame"],
  [:DataFrame, :melt, "(value_vars: Array<String>, var_name: String, value_name: String) -> DataFrame"],
  [:DataFrame, :groupby, "(String, as_index: true or false) -> DataFrameGroupBy"],
  [:DataFrameGroupBy, :mean, "() -> DataFrame"]
].each { |klass, meth, signature| RDL.type klass, meth, signature }

# Disable output buffering on stdout so each line written by Protocol.write
# is flushed immediately instead of waiting in Ruby's buffer.
$stdout.sync = true

# Line-oriented JSON message channel over the process's standard streams:
# one JSON document per line, read from stdin and written to stdout.
class Protocol
  # Reads one line from standard input and parses it as JSON.
  #
  # The read goes through $stdin explicitly. A bare `gets` resolves to
  # Kernel#gets, which reads from the files named in ARGV whenever the script
  # is started with arguments, so the channel would silently stop being stdin.
  #
  # @return [Object, nil] the parsed JSON value, or nil at end of input
  # @raise [JSON::ParserError] if the line is not valid JSON
  def self.read
    line = $stdin.gets
    return if line.nil? # end of input: the peer closed the stream

    # TODO: additional parsing
    JSON.parse(line)
  end

  # Serializes +data+ with #to_json and writes it to standard output as a
  # single line.
  #
  # @param data [#to_json] the message to send
  # @return [nil]
  def self.write(data)
    $stdout.puts data.to_json
  end
end

class PythonSpec
  # Asks the peer process to test a candidate program.
  #
  # The program is rendered to Python text with PyLang.unparse and sent as a
  # 'test' message; the next incoming message is then read and passed to
  # handle_action, whose result is returned.
  #
  # @param prog the candidate program AST
  # @return whatever handle_action returns for the peer's reply
  def test_prog(prog)
    request = { action: 'test', prog: PyLang.unparse(prog) }
    Protocol.write(request)
    reply = Protocol.read
    handle_action(reply)
  end
end

# Pretty-printer that renders program AST nodes as Python source text.
class PyLang
  # Property names whose Python spelling is not simply ".<name>".
  PROP_SPELLINGS = {
    "loc_getitem" => ".loc", # recv.loc[...]
    "__getitem__" => ""      # recv[...]
  }.freeze

  # Renders +node+ as Python source.
  #
  # @param node an AST node responding to #type and #children
  # @param in_arg [Boolean] true when +node+ is a direct argument of a method
  #   call; :hash nodes (keyword arguments) are only accepted in that position
  # @return [String] the Python text for +node+
  # @raise [AbsyntheError] for an unknown node type, an unsupported constant,
  #   or a :hash node that is not a direct call argument
  def self.unparse(node, in_arg: false)
    case node.type
    when :const
      unparse_const(node.children[0])
    when :prop
      unparse_prop(node)
    when :send
      recv = unparse(node.children[0])
      methd_name = node.children[1].to_s
      args = node.children[2..].map { |n| unparse(n, in_arg: true) }.join(", ")
      "#{recv}.#{methd_name}(#{args})"
    when :key
      # One keyword argument: children are [name, value node].
      "#{node.children[0]}=#{unparse(node.children[1])}"
    when :hash
      raise AbsyntheError, "only expected in_arg" unless in_arg

      node.children.map { |kv| unparse(kv) }.join(", ")
    when :array
      elems = node.children.map { |n| unparse(n) }.join(", ")
      "[#{elems}]"
    when :hole
      # "(□: #{node.children[1]})"
      "□"
    when :dephole
      # "(□: #{node.children[1]})"
      "◐"
    else
      raise AbsyntheError, "unexpected AST node #{node.type}"
    end
  end

  # Renders the payload of a :const node.
  def self.unparse_const(konst)
    case konst
    when Integer, Symbol
      konst.to_s
    when true
      'True'
    when false
      'False'
    when NUnique
      'pd.Series.nunique'
    when PyInt
      'int'
    when String
      # Emit a JSON string literal: its escapes (\", \\, \n, \uXXXX, ...) are
      # all valid in Python string literals. String#inspect agrees for plain
      # text but also produces Ruby-only escapes such as \#{ and \e, which
      # Python would read as different characters.
      konst.to_json
    else
      raise AbsyntheError, "unexpected constant type"
    end
  end
  private_class_method :unparse_const

  # Renders a :prop node: attribute access, optionally followed by a
  # subscript built from the remaining children.
  def self.unparse_prop(node)
    recv = unparse(node.children[0])
    name = node.children[1].to_s
    # Names without a special spelling become ordinary attribute access.
    # Previously such names were appended to the receiver with no "." at all.
    accessor = PROP_SPELLINGS.fetch(name) { ".#{name}" }
    args = node.children[2..].map { |n| unparse(n) }.join(", ")
    subscript = args.empty? ? "" : "[#{args}]"
    "#{recv}#{accessor}#{subscript}"
  end
  private_class_method :unparse_prop
end

# Builds the column abstractions for a synthesis request.
#
# Each entry of data['colin'] becomes PandasCols.bot when it is the string
# "bot", and PandasCols.val(entry) otherwise; data['colout'] is wrapped with
# PandasCols.var.
#
# @param data [Hash] request payload; reads the 'colin' and 'colout' keys
# @return [Array] a two-element array: [input column values, output column variable]
def parse_cols(data)
  inputs = data['colin'].map do |entry|
    entry == "bot" ? PandasCols.bot : PandasCols.val(entry)
  end
  output = PandasCols.var(data['colout'])
  [inputs, output]
end

# Handles a 'start' request: builds a synthesis problem from +data+, runs the
# synthesizer, and reports the outcome through Protocol.write.
#
# Keys read from +data+: 'consts', 'argsty', 'outputty' and 'seqs' here, plus
# 'colin' and 'colout' via parse_cols.
#
# On success a 'done' message is written carrying the elapsed time, the number
# of tested programs, the program size and the program's Python text. If
# synthesis does not finish within 20 minutes a 'timeout' message is written
# instead.
#
# NOTE(review): Timeout is used below but this file has no `require "timeout"`;
# presumably one of the required libraries loads it - confirm.
#
# @param data [Hash] the parsed 'start' message
# @return the return value of the Protocol.write call that reported the outcome
def syn_start(data)
  started_at = Process.clock_gettime(Process::CLOCK_MONOTONIC)

  constants = data['consts']
  arg_types = data['argsty']
  signature = "(#{arg_types.join(', ')}) -> #{data['outputty']}"
  fnty = RDL::Globals.parser.scan_str(signature)
  col_in, col_out = parse_cols(data)
  spec = PythonSpec.new

  # Every argument is abstracted as a (type, columns) product value, keyed
  # :arg0, :arg1, ... by position.
  # Type-only alternative kept for reference:
  #   abs_env = tyargs.each_with_index.map { |arg, i| ["arg#{i}".to_sym, arg]}.to_h
  #   ctx = Context.new(abs_env, PyType.val(fnty.ret))
  #   ctx.domain = PyType
  tyargs = fnty.args.map { |ty| PyType.val(ty) }
  abs_env = {}
  tyargs.zip(col_in).each_with_index do |(ty, cols), idx|
    abs_env[:"arg#{idx}"] = ProductDomain.val(ty, cols)
  end

  goal = ProductDomain.val(PyType.val(fnty.ret), col_out)
  ctx = Context.new(abs_env, goal)
  ctx.lang = :py
  ctx.domain = ProductDomain
  ctx.score = proc { |prog| WeightedSizePass.prog_size(prog) }
  ctx.consts[:int] = constants.select { |c| c.is_a?(Integer) }
  ctx.consts[:str] = constants.select { |c| c.is_a?(String) }
  ctx.max_size = data['seqs']
  # puts "==> #{ctx.max_size}"

  # The search starts from a single hole annotated with the goal.
  seed = s(:hole, nil, ctx.goal)
  queue = FastContainers::PriorityQueue.new(:min)
  queue.push(seed, ProgSizePass.prog_size(seed))

  time_limit = 20 * 60 # seconds
  begin
    Timeout::timeout(time_limit) do
      prog = synthesize(ctx, spec, queue)
      finished_at = Process.clock_gettime(Process::CLOCK_MONOTONIC)
      report = {
        action: 'done',
        time: finished_at - started_at,
        specs: 1,
        tested_progs: Instrumentation.tested_progs,
        size: PyProgSizePass.prog_size(prog),
        prog: PyLang.unparse(prog)
      }
      Protocol.write(report)
    end
  rescue Timeout::Error
    Protocol.write({ action: 'timeout' })
  end
end

# Dispatches one incoming message on its 'action' field.
#
# 'start' runs syn_start with the message and returns its result; 'test_res'
# returns the message's 'res' value; anything else is an error.
#
# @param data [Hash] a parsed message
# @return the result of syn_start, or data['res']
# @raise [AbsyntheError] if the action is neither 'start' nor 'test_res'
def handle_action(data)
  action = data['action']
  return syn_start(data) if 'start' == action
  return data['res'] if 'test_res' == action

  raise AbsyntheError, "unexpected action #{data}"
end

# Script entry point: reset Instrumentation (presumably the tested_progs
# counter that syn_start reports - confirm), then read the first message from
# stdin and dispatch it.
Instrumentation.reset!
handle_action(Protocol.read)
# Sample invocation without a driving process, kept for manual debugging.
# NOTE(review): the keys below ('args', 'output') are not the ones syn_start
# reads ('argsty', 'outputty', 'consts', ...), so this payload looks stale -
# confirm before using it.
# handle_action({'args': ['DataFrame'], 'output': 'DataFrame', 'action': 'start'}.transform_keys(&:to_s))
