# plot_by_name.rb

class Plot_by_Name_Data

    include Math
    include Tioga

    # Upper bound on the number of column names read from a names file.
    MAX_NAMES = 500

    attr_accessor :info, :values, :mesh, :names, :num_values, :dict

    # Load a table of named numeric columns.
    #
    # names_filename         - file listing one column name per line; reading
    #                          stops at the first blank line or end of file
    # data_filename          - numeric table read with Dvector.read into one
    #                          Dvector per name
    # lines_before_real_data - Dvector.read is given 1 + this value as its
    #                          starting-row argument (NOTE(review): confirm
    #                          Dvector.read's row numbering)
    # strictly_increasing    - optional 0-based column index; when given,
    #                          rows are removed from every column so that
    #                          this column becomes strictly increasing
    #
    # Sets @names, @values (one Dvector per column), @num_values, @mesh
    # (row numbers 1..rows) and @dict (name => column Dvector).
    # Raises RuntimeError when the names file yields no names.
    def initialize(names_filename, data_filename, lines_before_real_data=0, strictly_increasing=nil)
        read_names(names_filename)
        # Without any names there are no columns and @values[0] below is nil.
        raise "no column names found in #{names_filename}" if @names.empty?
        @num_values = @names.size
        @values = Array.new(@num_values) { Dvector.new }
        Dvector.read(data_filename, @values, 1+lines_before_real_data)
        if strictly_increasing
            rows = rows_to_prune(@values[strictly_increasing])
            @values.each { |vec| vec.prune!(rows) } unless rows.empty?
        end
        @mesh = Dvector.new(@values[0].size) { |i| i+1 }
        @dict = {}
        @names.each_with_index { |nm, i| @dict[nm] = @values[i] }
    end

    # Read column names, one per line with surrounding whitespace stripped,
    # into @names.  Stops at the first blank line, at end of file, or after
    # MAX_NAMES names.  The file is closed even if reading raises.
    def read_names(names_filename)
        @names = []
        File.open(names_filename) do |f|
            MAX_NAMES.times do
                line = f.gets
                break if line == nil # end of file
                namestr = line.strip # remove whitespace
                break if namestr.empty?
                @names << namestr
            end
        end
    end

    private

    # Ascending list of row indices k with col[k] >= min(col[k+1..-1]), i.e.
    # rows that are not strictly below every later value.  Removing them
    # leaves col strictly increasing (the last row is always kept).
    # Scans backwards with a running minimum of the suffix, so it is O(n)
    # rather than recomputing the suffix minimum for every row.
    def rows_to_prune(col)
        rows = []
        n = col.size
        return rows if n < 2
        suffix_min = col[n-1]
        (n-2).downto(0) do |k|
            rows << k if col[k] >= suffix_min
            suffix_min = col[k] if col[k] < suffix_min
        end
        rows.reverse
    end

end # Plot_by_Name_Data


class Plot_by_Name

    include Math
    include Tioga
    include FigureConstants

    # Create a names file (one column name per line) from the
    # whitespace-separated header line of a log file.
    #
    # logfilename   - log file, relative to data_dir, holding the header line
    # namesfilename - names file to create/overwrite, relative to data_dir
    # data_dir      - directory containing both files
    # skip          - number of lines to skip before the header line
    def self.make_names(logfilename,namesfilename,data_dir='.',skip=0)
        # Block forms close each file even if reading or writing raises.
        names = File.open(File.join(data_dir, logfilename)) do |tracefile|
            skip.times { tracefile.readline } # skip a few lines before names
            tracefile.readline.split
        end
        File.open(File.join(data_dir, namesfilename), 'w') do |namesfile|
            names.each { |nm| namesfile.puts nm }
        end
    end

    # Shorthand for the Tioga FigureMaker.
    def t
        @figure_maker
    end

    # Shorthand for the primary (test) data set.
    def d
        @data
    end

    # Shorthand for the comparison data set, or nil when none was given.
    def cd
        @compare_data
    end

    # Read the data described by the options hash and register the figures
    # "burn", "ye", "lg_eps_nuc", "lg_ergs" and "xsum_sub_1".
    #
    # dict - options Hash with String keys, or nil for all defaults.  Each
    #        option and its default appears where it is read below.
    #        'xaxis_column' is 0 for mesh point (row number), a 1-based
    #        column number, or a column name.
    def initialize(dict=nil)
        # The direct dict[...] reads below used to raise NoMethodError for
        # the nil default.
        dict = {} if dict == nil

        @title = dict['title']
        @log_mass_frac_ymax = dict['log_mass_frac_ymax']
        @log_mass_frac_ymin = dict['log_mass_frac_ymin']

        @grid_min = get_if_given_else_default(dict, 'grid_min', 0)
        @grid_max = get_if_given_else_default(dict, 'grid_max', -1)

        @xmin = dict['xmin']
        @xmax = dict['xmax']

        @num_abundance_line_labels = get_if_given_else_default(dict, 'num_abundance_line_labels', 9)
        @show_points = get_if_given_else_default(dict, 'show_points', nil)
        # y range of the right-hand "log total ergs" panel of plot_burn
        @ergs_ymin = get_if_given_else_default(dict, 'ergs_ymin', 15.5)
        @ergs_ymax = get_if_given_else_default(dict, 'ergs_ymax', 19)

        @xaxis_column = get_if_given_else_default(dict, 'xaxis_column', 0)
        @reverse_xaxis = get_if_given_else_default(dict, 'reverse_xaxis', false)
        @xlabel = get_if_given_else_default(dict, 'xlabel', "x")
        @xaxis_label = get_if_given_else_default(dict, 'xaxis_label', nil)
        data_dir = get_if_given_else_default(dict, 'data_dir', "plot_data")
        @rtol = get_if_given_else_default(dict, 'rtol', 1e-8)
        @first_pt = get_if_given_else_default(dict, 'first', 0)
        @last_pt = get_if_given_else_default(dict, 'last', -1)
        test_names_file = get_if_given_else_default(dict, 'names_file', "names.data")
        test_data_file = get_if_given_else_default(dict, 'test_file', "test.data")
        compare_data_file = get_if_given_else_default(dict, 'compare_file', nil)
        # The comparison run's names default to the test run's names file
        # (this used to default to the test *data* file).
        compare_names_file = get_if_given_else_default(dict, 'compare_names_file', test_names_file)
        lines_before_real_data = get_if_given_else_default(dict, 'lines_before_real_data', 0)
        log_vals = get_if_given_else_default(dict, 'log_vals', nil)
        strictly_increasing = get_if_given_else_default(dict, 'strictly_increasing', nil)
        @ymin = get_if_given_else_default(dict, 'ymin', nil)
        @ymax = get_if_given_else_default(dict, 'ymax', nil)
        # For the log_xaxis* flags, false means the same as "not given" (nil).
        @log_xaxis = get_if_given_else_default(dict, 'log_xaxis', nil) || nil
        @log_xaxis_below_max = get_if_given_else_default(dict, 'log_xaxis_below_max', nil) || nil
        @log_xaxis_shift_by_last = get_if_given_else_default(dict, 'log_xaxis_shift_by_last', nil) || nil

        @atol = 1e-30

        @figure_maker = FigureMaker.default
        t.save_dir = 'plot_out'
        # NOTE(review): strings handed to this hook are eval'ed in this
        # object's context -- only use with trusted input.
        t.def_eval_function { |str| eval(str) }
        t.tex_preview_preamble += "\n\\include{color_names}\n"

        t.def_enter_page_function { enter_page }

        @data = Plot_by_Name_Data.new(
            File.join(data_dir, test_names_file),
            File.join(data_dir, test_data_file),
            lines_before_real_data,
            strictly_increasing)

        @show_net_burn = get_if_given_else_default(dict, 'show_net_burn', false)
        if @show_net_burn
            @net_burn_data = Plot_by_Name_Data.new(
                File.join(data_dir, 'net_burn.names'),
                File.join(data_dir, 'net_burn.data'),
                lines_before_real_data,
                strictly_increasing)
        end

        # Out-of-range first/last points fall back to the whole range.
        @first_pt = 0 if d.mesh.length <= @first_pt
        @last_pt = -1 if d.mesh.length <= @last_pt

        if compare_data_file == nil
            @compare_data = nil
        else
            @compare_data = Plot_by_Name_Data.new(
                File.join(data_dir, compare_names_file),
                File.join(data_dir, compare_data_file),
                lines_before_real_data,
                strictly_increasing)
        end

        convert_to_log10(log_vals) unless log_vals == nil

        # Resolve a column name to its 1-based column number (0 = mesh point).
        if @xaxis_column.kind_of? String
            xname = @xaxis_column
            idx = d.names.index(xname)
            if idx == nil
                puts "xaxis_column #{xname} not found; plotting against mesh point"
                @xaxis_column = 0
            else
                @xaxis_column = idx + 1
            end
        end

        set_plotting_info
        # Unless an x label was given, use the x-axis column's name, or
        # 'mesh point' when plotting by row.  (The 'mesh point' default used
        # to be unreachable because @xlabel already defaulted to "x".)
        if dict['xlabel'] == nil
            @xlabel = (@xaxis_column == 0)? 'mesh point' : d.names[@xaxis_column-1]
        end
        @xaxis_label = @xlabel if @xaxis_label == nil

        t.def_figure("burn") { plot_burn }
        t.def_figure("ye") { plot_ye }
        t.def_figure("lg_eps_nuc") { plot_lg_eps_nuc }
        t.def_figure("lg_ergs") { plot_lg_ergs }
        t.def_figure("xsum_sub_1") { plot_xsum_sub_1 }

        puts "@first_pt #{@first_pt}" unless @first_pt == 0
        puts "@last_pt #{@last_pt}" unless @last_pt == -1

    end

    # Page hook: an 11 x 8.5 page at half scale (presumably in points, 72 per
    # inch -- NOTE(review): confirm page_setup units) with the plot frame
    # inset 15% on every side.
    def enter_page
        t.page_setup(11*72/2,8.5*72/2)
        t.set_frame_sides(0.15,0.85,0.85,0.15) # left, right, top, bottom in page coords
    end

    # Return dict[name], or default when dict is nil or holds no non-nil
    # value for name.  A stored false is returned as false.
    def get_if_given_else_default(dict, name, default)
        return default if dict == nil
        val = dict[name]
        (val == nil)? default : val
    end

    # Column `name` of the test data restricted to rows @first_pt..@last_pt.
    # When comparison data is loaded and has the same column, returns the
    # difference test - comparison (Dvector#sub) instead.  Returns nil when
    # the column is missing or the row range is invalid.
    def get_ys_for_plot(name)
        vals = d.dict[name]
        return nil if vals == nil
        ys = vals[@first_pt..@last_pt]
        return ys if ys == nil || cd == nil
        other = cd.dict[name]
        return ys if other == nil
        ys.sub(other[@first_pt..@last_pt])
    end

    # Compute the y range for column n into @plot_ymin/@plot_ymax.  This is
    # the data range widened by 2% of its height, with the height floored at
    # max(|ymin|, |ymax|) * @rtol + @atol so flat data still gets a nonzero
    # range.  Both are set to nil (with a message) when the column cannot be
    # fetched.
    def get_min_max_for_plot(n)
        ys = get_ys_for_plot(d.names[n])
        if ys == nil
            puts "invalid range @first_pt=#{@first_pt}, @last_pt=#{@last_pt}, d.values[n].size=#{d.values[n].size}"
            @plot_ymin = @plot_ymax = nil
            return
        end
        ymin = ys.min
        # -99 is treated as a floor value (NOTE(review): presumably the
        # log10 floor for zeros -- confirm); use the smallest value above it.
        ymin = ys.min_gt(-99) if ymin == -99
        ymin = -99 if ymin == nil # nothing above the floor
        ymax = ys.max
        dy = ymax - ymin
        ytol = [ymin.abs, ymax.abs].max
        dy = ytol * @rtol + @atol if dy < ytol * @rtol + @atol
        ymargin = 0.02
        @plot_ymax = ymax + ymargin * dy
        @plot_ymin = ymin - ymargin * dy
    end

    # Fill @plot_info with [name, ymin, ymax] for every column of the test data.
    def set_plotting_info # for each of the results
        @num_plots = d.num_values
        @plot_info = Array.new(@num_plots) do |i|
            get_min_max_for_plot(i)
            [d.names[i], @plot_ymin, @plot_ymax]
        end
    end

    # Figure "ye".
    def plot_ye
        plot_named('ye')
    end

    # Figure "lg_eps_nuc".
    def plot_lg_eps_nuc
        plot_named('lg_eps_nuc')
    end

    # Figure "lg_ergs".
    def plot_lg_ergs
        plot_named('lg_ergs')
    end

    # Figure "xsum_sub_1".
    def plot_xsum_sub_1
        plot_named('xsum_sub_1')
    end

    # x values for rows @first_pt..@last_pt: mesh point numbers when
    # @xaxis_column is 0, otherwise that (1-based) column, optionally
    # transformed:
    #   @log_xaxis               - log10(x)
    #   @log_xaxis_below_max     - log10(max(x) + 1 - x)
    #   @log_xaxis_shift_by_last - log10(last + shift - x), where shift is the
    #                              gap between the last x and the nearest
    #                              earlier x below it
    # The transforms are applied to a slice, not to the stored data.
    def get_xs_for_plot
        return d.mesh[@first_pt..@last_pt] if @xaxis_column == 0
        xs = d.values[@xaxis_column-1][@first_pt..@last_pt]
        if @log_xaxis != nil
            xs = xs.safe_log10
        elsif @log_xaxis_below_max != nil
            xs.sub!(xs.max+1).neg!.safe_log10!
        elsif @log_xaxis_shift_by_last != nil
            maxx = xs[-1]
            shift = 0
            k = -1
            loop do
                shift = maxx-xs[k]
                break if shift > 0
                k = k-1
                break if k < -xs.length
            end
            xs.sub!(maxx+shift).neg!.safe_log10!
        end
        xs
    end

    # Plot column n against the configured x axis.
    def plot(n)
        xs = get_xs_for_plot
        ys = get_ys_for_plot(d.names[n])
        plot_results(n, xs, ys, @xaxis_label, @reverse_xaxis)
    end

    # Draw xs/ys for column n as one polyline titled by the column name.
    # The y range is @ymin/@ymax when given, otherwise the range precomputed
    # in @plot_info.  Both axes get a margin; xreversed puts max x on the left.
    def plot_results(n, xs, ys, xlabel, xreversed)
        t.rescale(0.8)
        t.set_subframe('right_margin' => 0.12, 'left_margin' => 0.0)
        background
        plt = @plot_info[n]
        t.show_title(plt[0].tr("_", " "))
        t.show_xlabel(xlabel.tr("_", " "))
        ymin = (@ymin != nil)? @ymin : plt[1]
        ymax = (@ymax != nil)? @ymax : plt[2]
        xmargin = 0.07
        ymargin = 0.1
        xmin = xs.min; xmax = xs.max
        width = (xmax == xmin)? 1 : xmax - xmin
        if xreversed
            left_boundary = xmax + xmargin * width
            right_boundary = xmin - xmargin * width
        else
            left_boundary = xmin - xmargin * width
            right_boundary = xmax + xmargin * width
        end
        ythresh = @atol
        height = (ymax < ymin+ythresh)? ythresh : ymax - ymin
        top_boundary = ymax + ymargin * height
        bottom_boundary = ymin - ymargin * height
        boundaries = [ left_boundary, right_boundary, top_boundary, bottom_boundary ]
        t.show_plot(boundaries) {
            t.show_polyline(xs,ys,BrightBlue)
        }
        show_footer
    end

    # Hook called after each plot_results figure; empty here.
    def show_footer
    end

    # Hook called before each plot_results figure is drawn; empty here.
    def background
    end

    # Draw the "burn" figure.  The left axis shows log mass fraction for each
    # abundance column (index >= @first_abundance_column, excluding 'ye' and
    # 'xsum_sub_1').  When do_eps is set, the right axis shows column 3 as
    # log total ergs per gram over @ergs_ymin..@ergs_ymax.
    def plot_burn

        do_eps = true
        do_eps_nuc = false

        t.rescale(0.75)
        t.ylabel_shift = 1.4
        right_margin = do_eps ? 0.3 : 0.1 # leave room for the right-hand axis
        t.set_subframe('right_margin' => right_margin, 'top_margin' => 0.1, 'bottom_margin' => 0.1)
        xs = get_xs_for_plot
        # A title read from a file may carry a trailing newline; strip it so
        # the 'HE4' special case can match.
        title = @title.respond_to?(:strip) ? @title.strip : @title
        t.show_title(title) unless title == 'HE4'
        xlabel = case @xaxis_label
                 when 'lg_age' then 'log age (years)'
                 when 'lg_time' then 'log age (seconds)'
                 else @xaxis_label.tr("_", " ")
                 end
        t.show_xlabel(xlabel)
        t.legend_text_dy = 1
        ymin = @log_mass_frac_ymin
        ymin = -8.1 if ymin == nil
        ymax = @log_mass_frac_ymax
        if ymax == nil
            ymax = (ymin < -8)? 0.8 : 0.25
        end

        # Explicit xmin/xmax override the grid_min/grid_max end points.
        xleft = (@xmin == nil)? xs[@grid_min] : @xmin
        xright = (@xmax == nil)? xs[@grid_max] : @xmax
        # plot_abundance_line reads these to place its labels.
        @xleft = xleft
        @xright = xright

        # Columns with a smaller index are not plotted as abundances.  This is
        # set by setup; an unset value used to make `i < first_abundance` raise.
        first_abundance = @first_abundance_column
        first_abundance = 0 if first_abundance == nil
        puts "first_abundance #{first_abundance}"

        t.subplot {
            t.yaxis_loc = t.ylabel_side = LEFT
            t.right_edge_type = AXIS_HIDDEN if do_eps
            t.show_plot('left_boundary' => xleft, 'right_boundary' => xright,
                'top_boundary' => ymax, 'bottom_boundary' => ymin) do
                t.show_ylabel('log mass fraction')
                cnt = 1 # reserve 0 for eps nuc
                d.num_values.times do |i|
                    colname = d.names[i]
                    next if colname == 'ye' || colname == 'xsum_sub_1' || i < first_abundance
                    puts "plot column #{i} #{colname}"
                    cnt = plot_abundance_line(cnt, xs, d.values[i], colname)
                end
                if title == 'HE4'
                    xpos = 1.7; ypos = -0.5; dypos = -0.3; scale = 0.7
                    ['0.98 He4, 0.02 N14', 'logRho = 4.0', 'logT = 8.1'].each do |text|
                        t.show_label(
                            'x' => xpos, 'y' => ypos, 'justification' => LEFT_JUSTIFIED,
                            'text' => text, 'scale' => scale, 'alignment' => ALIGNED_AT_BASELINE)
                        ypos += dypos
                    end
                end
            end
        }

        return unless do_eps

        if do_eps_nuc
            t.context do # eps_nuc
                i = 2 # NOTE(review): column 2 is plotted as log eps_nuc (see label)
                ys = d.values[i]
                ymax = ys.max + 0.2
                ymin = ys.max - 1.8

                left_boundary = xleft
                right_boundary = xright
                top_boundary = ymax
                bottom_boundary = ymin
                dx = right_boundary - left_boundary
                bounds = [ left_boundary, right_boundary, top_boundary, bottom_boundary ]
                t.set_bounds(bounds)
                t.context {
                    t.clip_to_frame
                    plot_abundance_line(0, xs, ys, 'eps')
                }
                # Extra axis drawn just to the right of the frame.
                spec = {
                    'ticks_outside' => false,
                    'ticks_inside' => true,
                    'from' => [right_boundary + 0.175*dx, top_boundary],
                    'to' => [right_boundary + 0.175*dx, bottom_boundary],
                }
                t.show_axis(spec)
                t.yaxis_loc = t.ylabel_side = RIGHT
                t.ylabel_shift = 5.3
                ylabel = 'log eps nuc (ergs gm^{-1} \mathrm{s}^{-1})'
                t.show_ylabel(ylabel)
            end
        end

        t.subplot do  # ergs per gm
            i = 3 # NOTE(review): column 3 is plotted as log total ergs per gram (see label)
            t.yaxis_loc = t.ylabel_side = RIGHT
            t.left_edge_type = AXIS_HIDDEN
            ys = d.values[i]
            name = d.names[i]
            ymax = @ergs_ymax
            ymin = @ergs_ymin

            t.show_plot('left_boundary' => xleft, 'right_boundary' => xright,
                'top_boundary' => ymax, 'bottom_boundary' => ymin) do
                t.show_ylabel('log total ergs gm^{-1}')
                plot_abundance_line(-1, xs, ys, name)
            end
        end
    end

    # Show whichever of title/xlabel/ylabel are given, then suppress the
    # corresponding default Tioga label.
    def show_box_labels(title, xlabel=nil, ylabel=nil)
        if title != nil
            t.show_title(title); t.no_title
        end
        if xlabel != nil
            t.show_xlabel(xlabel); t.no_xlabel
        end
        if ylabel != nil
            t.show_ylabel(ylabel); t.no_ylabel
        end
    end

    # Draw one line of plot_burn and optionally label it along its length.
    #
    # cnt   - running line counter that picks the color; a negative value
    #         selects Crimson (used for the ergs line)
    # xs/ys - full x and y Dvectors; only @grid_min..@grid_max is drawn
    # label - column name, used for the legend and the in-plot labels
    #
    # Returns cnt+1 when the line is drawn, or cnt unchanged when ys is nil or
    # never rises above the display threshold.
    def plot_abundance_line(cnt, xs, ys, label)
        show_lim = -50 # lines that never rise above this are not drawn
        return cnt if ys == nil
        return cnt unless ys.max > show_lim
        xs_plot = xs[@grid_min..@grid_max]
        ys_plot = ys[@grid_min..@grid_max]
        # Neither flag is set anywhere visible here, so txt is normally the label.
        txt = (@legend_abundance_lines == true || @doing_multiple != true)? label : nil

        colors = [ BrightBlue, Goldenrod, Coral, FireBrick, RoyalPurple, Lilac ]
        color = (cnt < 0)? Crimson : colors[cnt % colors.length]
        pattern = Line_Type_Solid

        t.line_width = 0.7
        t.show_polyline(xs_plot, ys_plot, color, txt, pattern)
        t.show_marker('xs' => xs_plot, 'ys' => ys_plot, 'marker' => Bullet,
            'color' => color, 'scale' => 0.4) if @show_points

        show_net_burn_line(label, txt) if @show_net_burn

        label_abundance_line(label, xs_plot, ys_plot) if @num_abundance_line_labels > 0
        cnt+1
    end

    private

    # Plot the column called name, or report that it is missing.  A missing
    # name used to raise from d.names[nil].
    def plot_named(name)
        n = d.names.index(name)
        if n == nil
            puts "no column named #{name}"
        else
            plot(n)
        end
    end

    # Replace each named column of the test data with its log10 via
    # safe_log10!, and rename it 'log <name>'.  The comparison data is
    # converted the same way when present.  The converted column is also
    # stored in dict under its new name, so that get_ys_for_plot (which looks
    # columns up by their displayed name) still finds it.
    def convert_to_log10(names)
        names.each do |nm|
            i = d.names.index(nm) # index of nm in names
            next if i == nil
            d.names[i] = 'log ' + d.names[i]
            puts "converting to log10 " + d.names[i]
            d.values[i].safe_log10!
            d.dict[d.names[i]] = d.values[i]
            next if cd == nil
            # The comparison data may order its columns differently, so look
            # the column up by name rather than reusing i.
            j = cd.names.index(nm)
            next if j == nil
            cd.names[j] = 'log ' + cd.names[j]
            cd.values[j].safe_log10!
            cd.dict[cd.names[j]] = cd.values[j]
        end
    end

    # Overlay the same-named column from @net_burn_data as a dotted black
    # line against that data's copy of the x-axis column.  Skipped when
    # either column is missing from the net-burn data.
    def show_net_burn_line(label, txt)
        d1 = @net_burn_data
        i = d1.names.index(label)
        return if i == nil
        xname = d.names[@xaxis_column-1]
        puts "xname #{xname}"
        iname = d1.names.index(xname)
        puts "iname #{iname}"
        return if iname == nil # x column missing from the net-burn data
        t.line_width = 1.5
        t.show_polyline(d1.values[iname], d1.values[i], Black, txt, Line_Type_Dot)
    end

    # Write @num_abundance_line_labels copies of the line's label, evenly
    # spaced across @xleft..@xright and lifted slightly above the curve.
    def label_abundance_line(label, xs_plot, ys_plot)
        n = @num_abundance_line_labels
        # to_f avoids integer division when xmin/xmax were given as integers.
        dx = (@xright - @xleft).to_f / n
        xs_reversed = xs_plot.reverse # need to reverse for use with linear_interpolate
        ys_reversed = ys_plot.reverse
        dy = t.default_text_height_dy*0.2
        txt = abundance_label_text(label)
        n.times do |j|
            x = @xleft + dx*(j+0.5)
            y = Dvector.linear_interpolate(x, xs_reversed, ys_reversed) + dy
            t.show_label(
                'x' => x, 'y' => y,
                'text' => txt, 'scale' => 0.9, 'alignment' => ALIGNED_AT_BASELINE)
        end
    end

    # Text drawn on a line: 'ye' as is, and eps_nuc / abs_1_sub_sum_x as TeX.
    # Other labels lose their first three characters (NOTE(review):
    # presumably an "lg_"-style prefix -- confirm against the names file).
    def abundance_label_text(label)
        return label if label == 'ye'
        return '$\epsilon_{nuc}$' if label == 'eps'
        txt = label[3,label.length]
        return '$\epsilon_{nuc}$' if txt == 'eps_nuc'
        return '$error$' if txt == 'abs_1_sub_sum_x'
        txt
    end

end

data_dir = '.'

logfile = 'one_zone_burn.data'
# The log file's first line is its title.  Strip the trailing newline: left
# in, it ends up in the plot title and keeps plot_burn's title == 'HE4'
# check from ever matching.  The block form closes the file.
title = File.open(File.join(data_dir, logfile)) { |f| f.readline.strip }
puts "title #{title}"


namesfile = 'burn.names'
# Create the names file from the column names on the log file's second line
# (skip=1 skips the title line).
Plot_by_Name.make_names(logfile,namesfile,data_dir,1)
# NOTE: `p` shadows Kernel#p for the rest of this script.
p = Plot_by_Name.new(
    'xaxis_column' => 'lg_time',
    #'xaxis_column' => 'lg_age',
    'log_mass_frac_ymin' => -6,
    'log_mass_frac_ymax' => 0.5,
    'num_abundance_line_labels' => 3,
    'xmin' => nil,
    'xmax' => nil,
    'title' => title,
    #'reverse_xaxis' => false,
    #'reverse_xaxis' => true,
    'data_dir' => data_dir,
    'lines_before_real_data' => 2,
    'test_file' => logfile,
    'names_file' => namesfile,
    'first' => 0, 'last' => -1)
    #'first' => 1200, 'last' => 1250)

class Plot_by_Name
    # Set the 0-based column index at which plot_burn starts treating data
    # columns as abundances; earlier columns are not plotted there.
    # first_abundance_column defaults to 6, the value this script used.
    def setup(first_abundance_column=6)
        @first_abundance_column = first_abundance_column
    end
end

# Set @first_abundance_column, which plot_burn reads when drawing "burn".
p.setup

