# plot_by_name.rb

# Column data read from a names file (one column name per line) and a
# numeric data file whose columns follow the same order.
#
#   names      - column names, in file order
#   values     - one Dvector per column
#   num_values - number of columns (names.size)
#   mesh       - Dvector of row numbers 1, 2, ..., number of rows
#   dict       - column name => Dvector (if a name repeats, the last column wins)
#   info       - accessor only; not assigned by this class
class Plot_by_Name_Data

    include Math
    include Tioga

    attr_accessor :info, :values, :mesh, :names, :num_values, :dict

    # names_filename         - file of column names (see read_names)
    # data_filename          - numeric data, read with Dvector.read into one
    #                          Dvector per name
    # lines_before_real_data - passed to Dvector.read as 1 + lines_before_real_data
    #                          (NOTE(review): presumably the first line to read;
    #                          confirm against Tioga's Dvector.read)
    # strictly_increasing    - nil, or a 0-based column index; rows are removed
    #                          from every column so that this column becomes
    #                          strictly increasing
    def initialize(names_filename, data_filename, lines_before_real_data=0, strictly_increasing=nil)
        read_names(names_filename)
        @num_values = @names.size
        @values = Array.new(@num_values) { Dvector.new }
        Dvector.read(data_filename, @values, 1+lines_before_real_data)
        unless strictly_increasing.nil?
            drop = rows_breaking_strict_increase(@values[strictly_increasing])
            @values.each { |vec| vec.prune!(drop) } unless drop.empty?
        end
        # guard against an empty names file, which leaves no column to size from
        num_rows = @values.empty? ? 0 : @values[0].size
        @mesh = Dvector.new(num_rows) { |i| i+1 }
        @dict = Hash[@names.zip(@values)]
    end

    # Reads column names from names_filename, one per line, with surrounding
    # whitespace stripped. Reading stops at the first blank line, at end of
    # file, or once max_names names have been read. Sets and returns @names.
    def read_names(names_filename, max_names=500)
        @names = []
        # the block form closes the file even if reading raises
        File.open(names_filename) do |f|
            f.each_line do |line|
                break if @names.size >= max_names
                name = line.strip
                break if name.empty?
                @names << name
            end
        end
        @names
    end

    private

    # Returns, in ascending order, every index k < col.size-1 with
    # col[k] >= min(col[k+1..-1]). Removing those rows leaves col strictly
    # increasing, and the last row is always kept. A single backward pass keeps
    # a running minimum of the later entries instead of re-scanning the tail
    # for each k.
    def rows_breaking_strict_increase(col)
        drop = []
        tail_min = nil
        (col.size - 2).downto(0) do |k|
            later = col[k+1]
            tail_min = later if tail_min.nil? || later < tail_min
            drop << k if col[k] >= tail_min
        end
        drop.reverse
    end

end # Plot_by_Name_Data


# Plots every named column of a data file against a chosen x axis (another
# column, or the row number) and registers one Tioga figure per column.
# Construct with Plot_by_Name.new(dict); see #initialize for the settings.
class Plot_by_Name

    include Math
    include Tioga
    include FigureConstants

    # Creates a names file (one name per line) from the whitespace-separated
    # words on one line of a log file.
    #
    # logfilename   - log file, relative to data_dir
    # namesfilename - names file to write, relative to data_dir
    # data_dir      - directory containing both files
    # skip          - number of lines to skip before the line of names
    #
    # Raises EOFError if the log file ends before the line of names.
    def self.make_names(logfilename, namesfilename, data_dir='.', skip=0)
        # block forms close each file even when reading or writing raises
        names = File.open(File.join(data_dir, logfilename)) do |tracefile|
            skip.times { tracefile.readline }
            tracefile.readline.split
        end
        File.open(File.join(data_dir, namesfilename), 'w') do |namesfile|
            names.each { |nm| namesfile.puts nm }
        end
    end

    # The Tioga FigureMaker used for all drawing.
    def t
        @figure_maker
    end

    # The primary data set (a Plot_by_Name_Data).
    def d
        @data
    end

    # The comparison data set, or nil when no 'compare_file' was given.
    def cd
        @compare_data
    end

    # Reads the data (and optional comparison data), applies the settings,
    # computes per-column y ranges and registers the figures.
    #
    # dict - optional Hash with String keys; a missing key or a nil value
    #        selects the default shown in parentheses:
    #   'xaxis_column' (0)            1-based column number or column name for
    #                                 the x axis; 0 plots against row number
    #   'reverse_xaxis' (false)       draw the x axis from right to left
    #   'xlabel' ("x")                x label; when not given it becomes the
    #                                 x column's name, or 'mesh point' for 0
    #   'xaxis_label' (xlabel)        label shown under each plot
    #   'data_dir' ("plot_data")      directory holding the files below
    #   'names_file' ("names.data")   column names, one per line
    #   'test_file' ("test.data")     numeric data for those names
    #   'compare_file' (nil)          optional second data file
    #   'compare_names_file' (names_file)
    #   'lines_before_real_data' (0)  forwarded to Plot_by_Name_Data.new
    #   'strictly_increasing' (nil)   0-based column index forwarded to
    #                                 Plot_by_Name_Data.new
    #   'first', 'last' (0, -1)       range of row indices plotted
    #   'xmin', 'xmax' (nil)          replace first/last by the rows whose x
    #                                 is closest to these values
    #   'ymin', 'ymax' (nil)          fixed y range for every plot
    #   'rtol' (1e-8)                 relative floor on the y span
    #   'log_xaxis', 'log_xaxis_below_max', 'log_xaxis_shift_by_last' (nil)
    #                                 x-axis transforms, see get1_xs_for_plot
    #   'log_vals' (nil)              column names replaced by safe_log10
    #   'define_plots_for_all_columns' (true)
    #   'define_plots_for_these_columns' (nil)  Array of column names
    #   'show_points' (false)         draw markers instead of a polyline
    #
    # A subclass may define define_figures(dict). It is called before the
    # per-column figures are registered.
    def initialize(dict=nil)
        @xaxis_column = get_if_given_else_default(dict, 'xaxis_column', 0)
        @reverse_xaxis = get_if_given_else_default(dict, 'reverse_xaxis', false)
        @xlabel = get_if_given_else_default(dict, 'xlabel', "x")
        # remembered so that an explicit label is never replaced further down
        xlabel_given = !get_if_given_else_default(dict, 'xlabel', nil).nil?
        @xaxis_label = get_if_given_else_default(dict, 'xaxis_label', nil)
        data_dir = get_if_given_else_default(dict, 'data_dir', "plot_data")
        @rtol = get_if_given_else_default(dict, 'rtol', 1e-8)
        @first_pt = get_if_given_else_default(dict, 'first', 0)
        @last_pt = get_if_given_else_default(dict, 'last', -1)
        test_names_file = get_if_given_else_default(dict, 'names_file', "names.data")
        test_data_file = get_if_given_else_default(dict, 'test_file', "test.data")
        compare_data_file = get_if_given_else_default(dict, 'compare_file', nil)
        # by default the comparison data uses the same column NAMES file
        # (previously this defaulted to the test DATA file)
        compare_names_file = get_if_given_else_default(dict, 'compare_names_file', test_names_file)
        lines_before_real_data = get_if_given_else_default(dict, 'lines_before_real_data', 0)
        strictly_increasing = get_if_given_else_default(dict, 'strictly_increasing', nil)
        @xmin = get_if_given_else_default(dict, 'xmin', nil)
        @xmax = get_if_given_else_default(dict, 'xmax', nil)
        @ymin = get_if_given_else_default(dict, 'ymin', nil)
        @ymax = get_if_given_else_default(dict, 'ymax', nil)
        # the three log-x-axis flags treat false the same as "not given"
        @log_xaxis = get_if_given_else_default(dict, 'log_xaxis', nil)
        @log_xaxis = nil if @log_xaxis == false
        @log_xaxis_below_max = get_if_given_else_default(dict, 'log_xaxis_below_max', nil)
        @log_xaxis_below_max = nil if @log_xaxis_below_max == false
        @log_xaxis_shift_by_last = get_if_given_else_default(dict, 'log_xaxis_shift_by_last', nil)
        @log_xaxis_shift_by_last = nil if @log_xaxis_shift_by_last == false
        @define_plots_for_all_columns = get_if_given_else_default(dict, 'define_plots_for_all_columns', true)
        @define_plots_for_these_columns = get_if_given_else_default(dict, 'define_plots_for_these_columns', nil)
        @show_points = get_if_given_else_default(dict, 'show_points', false)
        @log_vals = get_if_given_else_default(dict, 'log_vals', nil)

        @atol = 1e-30  # absolute floor on the y span used when sizing plots

        @figure_maker = FigureMaker.default
        t.save_dir = 'plot_out'
        t.def_eval_function { |str| eval(str) }
        t.tex_preview_preamble += "\n\\include{color_names}\n"

        t.def_enter_page_function { enter_page }

        @data = Plot_by_Name_Data.new(
            File.join(data_dir, test_names_file),
            File.join(data_dir, test_data_file),
            lines_before_real_data,
            strictly_increasing)

        # fall back to the full range when first/last lie past the data
        @first_pt = 0 if d.mesh.length <= @first_pt
        @last_pt = -1 if d.mesh.length <= @last_pt

        @compare_data = nil
        unless compare_data_file.nil?
            @compare_data = Plot_by_Name_Data.new(
                File.join(data_dir, compare_names_file),
                File.join(data_dir, compare_data_file),
                lines_before_real_data,
                strictly_increasing)
        end

        # replace the listed columns of the test data (not the comparison data)
        # by their safe_log10, in place
        unless @log_vals.nil?
            @log_vals.each do |nm|
                i = d.names.index(nm)
                next if i.nil?
                puts "converting to log10 " + d.names[i]
                d.values[i].safe_log10!
            end
        end

        # a column name becomes its 1-based column number, or 0 (mesh) if unknown
        if @xaxis_column.kind_of? String
            idx = d.names.index(@xaxis_column)
            @xaxis_column = idx.nil? ? 0 : idx+1
        end

        # xmin/xmax pick the nearest rows on the x axis; for xaxis_column 0 that
        # axis is the mesh (d.values[-1] would be the last data column instead)
        xs = (@xaxis_column == 0) ? d.mesh : d.values[@xaxis_column-1]
        @first_pt = xs.where_closest(@xmin) unless @xmin.nil?
        @last_pt = xs.where_closest(@xmax) unless @xmax.nil?

        set_plotting_info

        unless xlabel_given
            @xlabel = (@xaxis_column == 0) ? 'mesh point' : d.names[@xaxis_column-1]
        end
        @xaxis_label = @xlabel if @xaxis_label.nil?

        define_figures(dict) if defined?(define_figures)

        if @define_plots_for_all_columns
            @num_plots.times do |i|
                define_plot_figure(i) unless (i+1) == @xaxis_column
            end
        end

        unless @define_plots_for_these_columns.nil?
            @define_plots_for_these_columns.each do |nm|
                i = d.names.index(nm)
                if i.nil?
                    puts "cannot find data for define_plots_for_these_columns item <#{nm}>"
                else
                    define_plot_figure(i)
                end
            end
        end
    end

    # Page setup hook registered with Tioga: a half-letter landscape page
    # (in points) and the frame's left, right, top, bottom in page coordinates.
    def enter_page
        t.page_setup(11*72/2,8.5*72/2)
        t.set_frame_sides(0.15,0.85,0.85,0.15) # left, right, top, bottom in page coords
    end

    # Returns dict[name], or default when dict is nil or holds no non-nil
    # value for name (an explicit false is returned as false).
    def get_if_given_else_default(dict, name, default)
        return default if dict.nil?
        val = dict[name]
        val.nil? ? default : val
    end

    # The named column of the primary data, restricted to rows
    # @first_pt..@last_pt. get_min_max_for_plot treats a nil result as an
    # invalid range.
    def get_ys_for_plot(name)
        d.dict[name][@first_pt..@last_pt]
    end

    # Sets @plot_ymin/@plot_ymax to the y range of column n over the plotted
    # rows, widened by 2% of the span. The span is at least
    # max(|ymin|, |ymax|) * @rtol + @atol, so constant columns still get a
    # range. A minimum of -99 (presumably a missing-data marker) is replaced
    # by the smallest value above -99, if there is one.
    # For a nil or empty range it prints a message and sets both to nil.
    def get_min_max_for_plot(n)
        ys = get_ys_for_plot(d.names[n])
        if ys.nil? || ys.length == 0
            puts "invalid range @first_pt=#{@first_pt}, @last_pt=#{@last_pt}, d.values[n].size=#{d.values[n].size}"
            # clear the previous column's range instead of silently reusing it
            @plot_ymin = @plot_ymax = nil
            return
        end
        ymin = ys.min
        ymin = ys.min_gt(-99) if ymin == -99
        ymin = -99 if ymin.nil?
        ymax = ys.max
        dy = ymax - ymin
        min_dy = [ymin.abs, ymax.abs].max * @rtol + @atol
        dy = min_dy if dy < min_dy
        ymargin = 0.02
        @plot_ymax = ymax + ymargin * dy
        @plot_ymin = ymin - ymargin * dy
    end

    # Builds @plot_info: for every column, [name, plot_ymin, plot_ymax].
    def set_plotting_info
        @num_plots = d.num_values
        @plot_info = Array.new(@num_plots) do |i|
            get_min_max_for_plot(i)
            [d.names[i], @plot_ymin, @plot_ymax]
        end
    end

    # x values of the primary data for the plotted rows.
    def get_xs_for_plot
        get1_xs_for_plot(d)
    end

    # x values of data set d for rows @first_pt..@last_pt. Exactly one of
    # these applies, in this order:
    #   xaxis_column 0           - the mesh (row numbers)
    #   log_xaxis                - log10(x)
    #   log_xaxis_below_max      - log10(max(x) + 1 - x)
    #   log_xaxis_shift_by_last  - log10(x_last + shift - x), see below
    #   otherwise                - x unchanged
    # The slice is a copy, so the in-place transforms leave d untouched.
    def get1_xs_for_plot(d)
        if @xaxis_column == 0
            xs = d.mesh[@first_pt..@last_pt]
        elsif @log_xaxis != nil
            xs = d.values[@xaxis_column-1][@first_pt..@last_pt].safe_log10
        elsif @log_xaxis_below_max != nil
            xs = d.values[@xaxis_column-1][@first_pt..@last_pt]
            maxx = xs.max
            xs.sub!(maxx+1).neg!.safe_log10!
        elsif @log_xaxis_shift_by_last != nil
            xs = d.values[@xaxis_column-1][@first_pt..@last_pt]
            maxx = xs[-1]
            # shift = x_last - x_k for the nearest earlier x_k below x_last, so
            # the last point maps to log10(shift) rather than log10(0); if no
            # earlier x is smaller, shift ends as x_last - xs[0]
            shift = nil
            k = -1
            loop do
                shift = maxx - xs[k]
                break if shift > 0
                k -= 1
                break if k < -xs.length
            end
            xs.sub!(maxx+shift).neg!.safe_log10!
        else
            xs = d.values[@xaxis_column-1][@first_pt..@last_pt]
        end
        xs
    end

    # Figure body for column n of the primary data.
    def plot(n)
        xs = get_xs_for_plot
        ys = get_ys_for_plot(d.names[n])
        plot_results(n, xs, ys, @xaxis_label, @reverse_xaxis)
    end

    # Draws ys against xs in the current frame. The title comes from
    # @plot_info[n] and the x label from xlabel, with underscores shown as
    # spaces. Bounds come from the data (7% x margin, 10% y margin).
    # @ymin/@ymax, when set, override the y range, and xreversed flips the
    # x axis.
    def plot_results(n, xs, ys, xlabel, xreversed)
        t.set_subframe('right_margin' => 0.12, 'left_margin' => 0.0)
        background
        plt = @plot_info[n]
        title = plt[0].tr("_", " ")
        t.show_title(title)
        t.show_xlabel(xlabel.tr("_", " "))
        ymin = (@ymin != nil) ? @ymin : plt[1]
        ymax = (@ymax != nil) ? @ymax : plt[2]
        xmargin = 0.07
        ymargin = 0.1
        xmin = xs.min
        xmax = xs.max
        width = (xmax == xmin) ? 1 : xmax - xmin
        if xreversed
            left_boundary = xmax + xmargin * width
            right_boundary = xmin - xmargin * width
        else
            left_boundary = xmin - xmargin * width
            right_boundary = xmax + xmargin * width
        end
        ythresh = @atol
        height = (ymax < ymin+ythresh) ? ythresh : ymax - ymin
        top_boundary = ymax + ymargin * height
        bottom_boundary = ymin - ymargin * height
        boundaries = [ left_boundary, right_boundary, top_boundary, bottom_boundary ]
        t.show_plot(boundaries) {
            if @show_points
                t.show_marker(
                    'Xs' => xs, 'Ys' => ys, 'color' => BrightBlue,
                    'marker' => Bullet, 'scale' => 0.2)
            else
                t.show_polyline(xs,ys,BrightBlue)
            end
        }
        show_footer
    end

    # Shows whichever of title/xlabel/ylabel are non-nil, then turns that
    # label off so later calls do not draw it again.
    def show_box_labels(title, xlabel=nil, ylabel=nil)
        if title != nil
            t.show_title(title); t.no_title
        end
        if xlabel != nil
            t.show_xlabel(xlabel); t.no_xlabel
        end
        if ylabel != nil
            t.show_ylabel(ylabel); t.no_ylabel
        end
    end

    # Hook called at the end of plot_results; empty here.
    def show_footer
    end

    # Currently disabled: returns immediately. The translucent FloralWhite
    # frame fill below is kept for re-enabling.
    def background
        return

        t.fill_color = FloralWhite
        t.fill_opacity = 0.5
        t.fill_frame
    end

    private

    # Registers a Tioga figure named after column n that draws plot(n).
    # A closure replaces eval of generated source, so column names containing
    # quotes, backslashes or #{...} are passed through verbatim and never
    # executed.
    def define_plot_figure(n)
        t.def_figure(@plot_info[n][0]) { plot(n) }
    end

end

# Usage: Plot_by_Name.new(dict), where dict is an optional Hash of settings (see Plot_by_Name#initialize).