#!/bin/lua

-- Copyright (c) 2017-present, Facebook, Inc.
-- All rights reserved.

-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree.

-- Command-line state shared by the rest of the script.
local yt = 'dev' -- name (or part of the name) of the header field to plot
local filename -- filename for output (parsed; not read in the visible code)
local xtime = false -- time axis instead of iteration
local dumb = false -- dumb terminal (parsed; not read in the visible code)
local legend = false -- is legend provided?

-- Build the usage text raised when an unknown option is given.
-- Every default is passed through tostring() and formatted with %s: <yt> is
-- a string (so %d would raise inside string.format), and <filename>/<xtime>
-- may be nil/boolean, which plain %s does not accept on every Lua version.
local function usage()
   return string.format([[
usage: %s [options] <dir1> <dir2> ...
options:
-y <string>  -- target to plot (%s)
-o <string>  -- output filename (%s)
-xtime       -- use training time for x axis (%s)
-d           -- dumb terminal
-l           -- specify legends with <dir1> <legend1> <dir2> <legend2> ...
]], tostring(arg[0]), tostring(yt), tostring(filename), tostring(xtime))
end

-- Walk the arguments backwards so that removing an option never disturbs the
-- indices still to be visited. Once an option is removed its value (if any)
-- slides into slot i, which is why the same index is removed again.
-- Whatever remains in <arg> afterwards are the positional arguments.
for i=#arg,1,-1 do
   if arg[i]:match('^%-') then
      local opt = table.remove(arg, i)
      if opt == '-y' then
         yt = assert(table.remove(arg, i), '-y <string> expected')
      elseif opt == '-o' then
         filename = assert(table.remove(arg, i), '-o <filename> expected')
      elseif opt == '-xtime' then
         xtime = true
      elseif opt == '-d' then
         dumb = true
      elseif opt == '-l' then
         legend = true
      else
         -- level 0: show the usage text without a source-position prefix
         error(usage(), 0)
      end
   end
end

-- Break <str> into its whitespace-separated tokens.
-- Returns a sequence of the non-blank runs found in <str>, in order; the
-- sequence is empty when <str> holds nothing but whitespace.
local function split(str)
   local tokens, count = {}, 0
   for token in str:gmatch('(%S+)') do
      count = count + 1
      tokens[count] = token
   end
   return tokens
end

-- Convert an "h:m:s" duration string into a number of seconds.
-- The first three colon-separated fields of <str> are read as hours, minutes
-- and seconds; each must be convertible with tonumber().
-- Raises an error naming the offending value when <str> is not a string or
-- does not hold three numeric fields (instead of failing later with an
-- opaque arithmetic-on-nil error).
local function totime(str)
   local h, m, s
   if type(str) == 'string' then
      h, m, s = str:match('([^:]+):([^:]+):([^:]+)')
   end
   if h then
      h, m, s = tonumber(h), tonumber(m), tonumber(s)
   end
   if not (h and m and s) then
      error(string.format('invalid time <%s> (h:m:s expected)', tostring(str)))
   end
   return h*3600+m*60+s
end

-- Collect the curve stored in directory <dir> and append it to <plots>.
--
-- Files <dir>/001_perf, <dir>/002_perf, ... are read in order until one
-- cannot be opened. Each file must start with a '#' header line naming its
-- whitespace-separated columns; every following line not starting with '#'
-- yields one point. The y value is taken from the first column whose name
-- contains <yt>. Without xtime the x value is the running point index; with
-- xtime it is the cumulated 'runtime' column, converted to hours.
--
-- plots:  sequence receiving {legend, x, y} on success
-- dir:    directory holding the NNN_perf files
-- legend: label stored with the curve
--
-- If a file has no header, a required column is missing, or no point was
-- found, a warning is written to stderr and <plots> is left untouched.
-- Raises an error when a selected y value is not numeric or when totime()
-- rejects a runtime value.
local function addplot(plots, dir, legend)
   local x = {}
   local y = {}
   for n=1,999 do
      local path = string.format("%s/%03d_perf", dir, n)
      local f = io.open(path)
      if not f then
         break
      end
      io.stderr:write(string.format('# processing: %s\n', path))

      -- Iterate through the handle we already own (rather than reopening
      -- the file with io.lines) so it can be closed on every return path.
      local lines = f:lines()
      local header = lines()
      -- header is nil for an empty file
      if not header or not header:match('^#') then
         f:close()
         io.stderr:write(string.format('! warning: skipping directory <%s> (no header)\n', dir))
         return
      end
      header = header:gsub('^#+', '')

      local col = 0
      local yc
      local rc
      for head in header:gmatch('(%S+)') do
         col = col + 1
         -- Keep the first matching column but keep scanning: the runtime
         -- column may come after the target column. The name is tried
         -- literally first, so that characters such as '-' are not read as
         -- pattern operators, then as a Lua pattern as before.
         if not yc and (head:find(yt, 1, true) or head:find(yt)) then
            yc = col
         end
         if xtime and head == 'runtime' then
            rc = col
         end
      end

      local noruntime = xtime and not rc
      if noruntime or not yc then
         f:close()
         if noruntime then
            io.stderr:write(string.format('! warning: skipping directory <%s> (no field <runtime>)\n', dir))
         end
         if not yc then
            io.stderr:write(string.format('! warning: skipping directory <%s> (no field <%s>)\n', dir, yt))
         end
         -- actually skip, as the warning announces
         return
      end

      for line in lines do
         if not line:match('^#') then
            local tokens = split(line)
            local value = tonumber(tokens[yc])
            if not value then
               f:close()
               error(string.format('%s: non-numeric value <%s> for field <%s>', path, tostring(tokens[yc]), yt))
            end
            table.insert(y, value)
            if xtime then
               table.insert(x, totime(tokens[rc]))
            else
               table.insert(x, #x+1)
            end
         end
      end
      f:close()
   end
   if #y == 0 then
      io.stderr:write(string.format('! warning: skipping directory <%s> (empty)\n', dir))
   else
      if xtime then
         -- runtime values are accumulated, then expressed in hours
         local cumsum = 0
         for k=1,#x do
            cumsum = cumsum + x[k]
            x[k] = cumsum/3600
         end
      end
      table.insert(plots, {legend, x, y})
   end
end

-- Pair every directory given on the command line with the legend of its
-- curve. Each entry of <list> is {dir = ..., legend = ...}.
local list = {}
if not legend then
   -- no legends supplied: each directory name doubles as its own legend
   for k = 1, #arg do
      local dir = arg[k]
      list[#list+1] = {dir = dir, legend = dir}
   end
else
   -- legends supplied: arguments come as <dir> <legend> pairs
   if #arg % 2 ~= 0 then
      error('<dir1> <legend1> <dir2> <legend2> ... expected')
   end
   for k = 1, #arg, 2 do
      list[#list+1] = {dir = arg[k], legend = arg[k+1]}
   end
end

-- Load one curve per requested directory; each entry that addplot appends
-- to <plots> is {legend, x, y}.
local plots = {}
for k = 1, #list do
   local entry = list[k]
   addplot(plots, entry.dir, entry.legend)
end

-- The longest curve fixes the number of rows printed.
local lmax = 0
for _, plot in ipairs(plots) do
   local npoints = #plot[2]
   if npoints > lmax then
      lmax = npoints
   end
end

-- Header row: one <x label, legend> column pair per curve, tab-separated.
local xlabel = xtime and 'time(h)\t' or 'iter\t'
local titles = {}
for k, plot in ipairs(plots) do
   titles[k] = xlabel .. plot[1]
end
io.write(table.concat(titles, "\t"))
print()

-- Data rows: curves shorter than the current row are padded with dots.
for row = 1, lmax do
   local cells = {}
   for k, plot in ipairs(plots) do
      local xs, ys = plot[2], plot[3]
      if row > #xs then
         cells[k] = ".\t."
      else
         cells[k] = string.format("%g", xs[row]) .. "\t" .. string.format("%g", ys[row])
      end
   end
   io.write(table.concat(cells, "\t"))
   print()
end
