"""Balancing Methods"""
import sharpy.utils.settings as settings
import numpy as np
from abc import ABCMeta
import sharpy.utils.cout_utils as cout
import sharpy.utils.rom_interface as rom_interface
import sharpy.rom.utils.librom as librom
import sharpy.linear.src.libss as libss

# Registry of the available balancing algorithms, keyed by each class's ``_bal_rom_id``.
dict_of_balancing_roms = dict()


def bal_rom(arg):
    """Class decorator that registers a balancing ROM algorithm.

    The decorated class is stored in ``dict_of_balancing_roms`` under its ``_bal_rom_id`` so it
    can later be looked up by name (as ``Balanced.initialise`` does).

    Args:
        arg: class to register. It must define a ``_bal_rom_id`` attribute.

    Returns:
        The same class, unmodified.

    Raises:
        AttributeError: if ``arg`` has no ``_bal_rom_id`` attribute.
    """
    if not hasattr(arg, '_bal_rom_id'):
        raise AttributeError('Class defined as balanced rom has no _bal_rom_id attribute')
    rom_id = arg._bal_rom_id
    dict_of_balancing_roms[rom_id] = arg
    return arg


class BaseBalancedRom(metaclass=ABCMeta):
    """Common interface for the balancing algorithms registered through ``bal_rom``.

    Subclasses override :meth:`initialise` to parse their settings and :meth:`run` to produce the
    reduced order model (either a ``libss.ss`` instance or an ``(Ar, Br, Cr)`` tuple).
    """

    def initialise(self, in_settings=None):
        """Parse the algorithm settings. The base implementation does nothing."""
        return None

    def run(self, ss):
        """Reduce the state-space system ``ss``. The base implementation returns ``None``."""
        return None

@bal_rom
class Direct(BaseBalancedRom):
    # The class documentation is taken from the underlying librom routine and extended below with
    # the settings table; a class docstring here would be overwritten by this assignment.
    __doc__ = librom.balreal_direct_py.__doc__
    _bal_rom_id = 'Direct'

    settings_types = dict()
    settings_default = dict()
    settings_description = dict()

    settings_types['tune'] = 'bool'
    settings_default['tune'] = True
    settings_description['tune'] = 'Tune ROM to specified tolerance'

    settings_types['rom_tolerance'] = 'float'
    settings_default['rom_tolerance'] = 1e-2
    settings_description['rom_tolerance'] = 'Absolute accuracy with respect to full order frequency response'

    settings_types['rom_tune_freq_range'] = 'list(float)'
    settings_default['rom_tune_freq_range'] = [0, 1]
    settings_description['rom_tune_freq_range'] = 'Beginning and end of frequency range where to tune ROM'

    settings_types['convergence'] = 'str'
    settings_default['convergence'] = 'min'
    settings_description['convergence'] = 'ROM tuning convergence. If ``min`` attempts to find minimal number of states. ' \
                                          'If ``all`` it starts from larger size ROM until convergence to ' \
                                          'specified tolerance is found.'

    settings_types['reduction_method'] = 'str'
    settings_default['reduction_method'] = 'realisation'
    settings_description['reduction_method'] = 'Reduction method. ``realisation`` or ``truncation``'

    settings_table = settings.SettingsTable()
    __doc__ += settings_table.generate(settings_types, settings_default, settings_description)

    def __init__(self):
        self.settings = dict()

    def initialise(self, in_settings=None):
        """Store and type-check the algorithm settings.

        Args:
            in_settings (dict): user settings; passed to ``settings.to_custom_types`` together with
                the class-level ``settings_types`` and ``settings_default``.
        """
        if in_settings is not None:
            self.settings = in_settings

        settings.to_custom_types(self.settings, self.settings_types, self.settings_default)

    def run(self, ss):
        """Balance ``ss`` with ``librom.balreal_direct_py`` and optionally tune the result.

        Args:
            ss (libss.ss): full order state-space system.

        Returns:
            libss.ss: the balanced system if the ``tune`` setting is off, otherwise the ROM returned
            by ``librom.tune_rom``.
        """
        A, B, C, D = ss.get_mats()

        # A system with no ``dt`` attribute, or with ``dt=None``, is treated as continuous-time.
        dt = getattr(ss, 'dt', None)
        dtsystem = dt is not None

        S, T, Tinv = librom.balreal_direct_py(A, B, C, DLTI=dtsystem)

        Ar = T.dot(A.dot(Tinv))
        Br = T.dot(B)
        Cr = C.dot(Tinv)

        # Build the balanced system from the local properties of ``ss``: this class has no
        # ``self.dtsystem`` / ``self.ss`` attributes.
        if dtsystem:
            ss_bal = libss.ss(Ar, Br, Cr, D, dt=dt)
        else:
            ss_bal = libss.ss(Ar, Br, Cr, D)

        if not self.settings['tune']:
            return ss_bal

        freq_range = self.settings['rom_tune_freq_range']
        kv = np.linspace(freq_range[0], freq_range[1])
        # Float settings may be stored as ctypes objects (the sibling classes read ``.value``);
        # unwrap defensively so that tune_rom receives a plain number.
        tol = getattr(self.settings['rom_tolerance'], 'value', self.settings['rom_tolerance'])
        return librom.tune_rom(ss_bal,
                               kv=kv,
                               tol=tol,
                               gv=S,
                               convergence=self.settings['convergence'],
                               method=self.settings['reduction_method'])


@bal_rom
class FrequencyLimited(BaseBalancedRom):
    # The class documentation is taken from librom.balfreq and extended below with the settings
    # tables; a class docstring here would be overwritten by this assignment.
    __doc__ = librom.balfreq.__doc__

    _bal_rom_id = 'FrequencyLimited'

    settings_types = dict()
    settings_default = dict()
    settings_description = dict()

    settings_types['frequency'] = 'float'
    settings_default['frequency'] = 1.
    settings_description['frequency'] = 'defines limit frequencies for balancing. The balanced model will be accurate ' \
                                        'in the range ``[0,F]``, where ``F`` is the value of this key. Note that ``F`` ' \
                                        'units must be consistent with the units specified in ' \
                                        'the ``self.ScalingFacts`` dictionary.'

    settings_types['method_low'] = 'str'
    settings_default['method_low'] = 'trapz'
    settings_description['method_low'] = '``gauss`` or ``trapz`` specifies whether to use gauss quadrature or ' \
                                         'trapezoidal rule in the low-frequency range ``[0,F]``'

    settings_types['options_low'] = 'dict'
    settings_default['options_low'] = dict()
    settings_description['options_low'] = 'Settings for the low frequency integration. See Notes.'

    settings_types['method_high'] = 'str'
    settings_default['method_high'] = 'trapz'
    settings_description['method_high'] = '``gauss`` or ``trapz`` specifies whether to use gauss quadrature or ' \
                                         'trapezoidal rule in the high-frequency range ``[F,FN]``'

    settings_types['options_high'] = 'dict'
    settings_default['options_high'] = dict()
    settings_description['options_high'] = 'Settings for the high frequency integration. See Notes.'

    settings_types['check_stability'] = 'bool'
    settings_default['check_stability'] = True
    settings_description['check_stability'] = 'if True, the balanced model is truncated to eliminate ' \
                                              'unstable modes - if any is found. Note that very accurate ' \
                                              'balanced model can still be obtained, even if high order ' \
                                              'modes are unstable.'

    settings_types['get_frequency_response'] = 'bool'
    settings_default['get_frequency_response'] = False
    settings_description['get_frequency_response'] = 'if True, the function also returns the frequency ' \
                                                     'response evaluated at the low-frequency range integration' \
                                                     ' points. If True, this option also allows to automatically' \
                                                     ' tune the balanced model.'

    # Integrator options (apply to both ``options_low`` and ``options_high``)
    settings_options_types = dict()
    settings_options_default = dict()
    settings_options_description = dict()

    settings_options_types['points'] = 'int'
    settings_options_default['points'] = 12
    settings_options_description['points'] = 'Trapezoidal points of integration'

    settings_options_types['partitions'] = 'int'
    settings_options_default['partitions'] = 2
    settings_options_description['partitions'] = 'Number of Gauss-Lobatto quadratures'

    settings_options_types['order'] = 'int'
    settings_options_default['order'] = 2
    settings_options_description['order'] = 'Order of Gauss-Lobatto quadratures'

    settings_table = settings.SettingsTable()
    __doc__ += settings_table.generate(settings_types, settings_default, settings_description)

    options_table = settings.SettingsTable()
    __doc__ += options_table.generate(settings_options_types, settings_options_default, settings_options_description,
                                      'The parameters of integration take the following options:\n')

    def __init__(self):
        self.settings = dict()

    def initialise(self, in_settings=None):
        """Store and type-check the settings, including the integration option dictionaries.

        Args:
            in_settings (dict): user settings. The ``options_low`` and ``options_high`` entries are
                themselves checked against ``settings_options_types``.
        """
        if in_settings is not None:
            self.settings = in_settings

        settings.to_custom_types(self.settings, self.settings_types, self.settings_default)
        settings.to_custom_types(self.settings['options_low'], self.settings_options_types, self.settings_options_default)
        settings.to_custom_types(self.settings['options_high'], self.settings_options_types, self.settings_options_default)

        # Remove c-type wrappers so that librom.balfreq receives plain Python numbers.
        for key, key_type in self.settings_types.items():
            if key_type in ('float', 'int'):
                self.settings[key] = self._unwrap(self.settings[key])
            elif key_type == 'dict':
                opt_dict = self.settings[key]
                for opt_key in opt_dict:
                    # Options not declared in settings_options_types are passed through untouched
                    # instead of raising KeyError.
                    if self.settings_options_types.get(opt_key) == 'int':
                        opt_dict[opt_key] = self._unwrap(opt_dict[opt_key])

    @staticmethod
    def _unwrap(value):
        """Return ``value.value`` for ctypes-like objects, or ``value`` itself if already plain.

        Being idempotent, this is safe when the same options dictionary is processed twice
        (e.g. the same dict passed as both ``options_low`` and ``options_high``).
        """
        return getattr(value, 'value', value)

    def run(self, ss):
        """Run frequency-limited balancing on ``ss``.

        Args:
            ss (libss.ss): full order state-space system.

        Returns:
            The first element of the output of ``librom.balfreq`` (the reduced order model).
        """
        output_results = librom.balfreq(ss, self.settings)

        return output_results[0]


@bal_rom
class Iterative(BaseBalancedRom):
    # The class documentation is taken from librom.balreal_iter and extended below with the
    # settings table; a class docstring here would be overwritten by this assignment.
    __doc__ = librom.balreal_iter.__doc__
    _bal_rom_id = 'Iterative'

    settings_types = dict()
    settings_default = dict()
    settings_description = dict()

    settings_types['lowrank'] = 'bool'
    settings_default['lowrank'] = True
    settings_description['lowrank'] = 'Use low rank methods'

    settings_types['smith_tol'] = 'float'
    settings_default['smith_tol'] = 1e-10
    settings_description['smith_tol'] = 'Smith tolerance'

    settings_types['tolSVD'] = 'float'
    settings_default['tolSVD'] = 1e-6
    settings_description['tolSVD'] = 'SVD threshold'

    settings_table = settings.SettingsTable()
    __doc__ += settings_table.generate(settings_types, settings_default, settings_description)

    def __init__(self):
        self.settings = dict()

    def initialise(self, in_settings=None):
        """Store and type-check the algorithm settings.

        Args:
            in_settings (dict): user settings; passed to ``settings.to_custom_types`` together with
                the class-level ``settings_types`` and ``settings_default``.
        """
        if in_settings is not None:
            self.settings = in_settings

        settings.to_custom_types(self.settings, self.settings_types, self.settings_default)

    def run(self, ss):
        """Balance ``ss`` with the iterative ``librom.balreal_iter`` routine.

        Args:
            ss (libss.ss): full order state-space system.

        Returns:
            libss.ss: balanced system, sharing the feedthrough matrix and ``dt`` of ``ss``.
        """
        A, B, C, D = ss.get_mats()

        # Float settings may be stored as ctypes objects; unwrap them if so, otherwise use as-is.
        tol_smith = getattr(self.settings['smith_tol'], 'value', self.settings['smith_tol'])
        tol_svd = getattr(self.settings['tolSVD'], 'value', self.settings['tolSVD'])

        _s, T, Tinv, _rcmax, _romax = librom.balreal_iter(A, B, C,
                                                          lowrank=self.settings['lowrank'],
                                                          tolSmith=tol_smith,
                                                          tolSVD=tol_svd)

        Ar = Tinv.dot(A.dot(T))
        Br = Tinv.dot(B)
        Cr = C.dot(T)

        ssrom = libss.ss(Ar, Br, Cr, D, dt=ss.dt)
        return ssrom


@rom_interface.rom
class Balanced(rom_interface.BaseRom):
    """Balancing ROM methods

    Main class to load a balancing ROM. See below for the appropriate settings to be parsed in
    the ``algorithm_settings`` based on your selection.

    Supported algorithms:
        * Direct balancing :class:`.Direct`

        * Iterative balancing :class:`.Iterative`

        * Frequency limited balancing :class:`.FrequencyLimited`

    """
    rom_id = 'Balanced'

    settings_types = dict()
    settings_default = dict()
    settings_description = dict()

    settings_types['algorithm'] = 'str'
    settings_default['algorithm'] = ''
    settings_description['algorithm'] = 'Balanced realisation method'

    settings_types['algorithm_settings'] = 'dict'
    settings_default['algorithm_settings'] = dict()
    settings_description['algorithm_settings'] = 'Settings for the desired algorithm'

    settings_table = settings.SettingsTable()
    __doc__ += settings_table.generate(settings_types, settings_default, settings_description)

    def __init__(self):
        self.settings = dict()
        self.algorithm = None
        self.ssrom = None
        self.ss = None
        self.dtsystem = None

    def initialise(self, in_settings=None):
        """Parse the settings and instantiate the selected balancing algorithm.

        Args:
            in_settings (dict): user settings with the ``algorithm`` name and its
                ``algorithm_settings``.

        Raises:
            AttributeError: if ``algorithm`` is not a registered balancing algorithm.
        """
        if in_settings is not None:
            self.settings = in_settings

        settings.to_custom_types(self.settings, self.settings_types, self.settings_default)

        algorithm_name = self.settings['algorithm']
        if algorithm_name not in dict_of_balancing_roms:
            raise AttributeError('Balancing algorithm %s is not yet implemented' % algorithm_name)

        self.algorithm = dict_of_balancing_roms[algorithm_name]()
        self.algorithm.initialise(self.settings['algorithm_settings'])

    def run(self, ss):
        """Reduce ``ss`` with the selected balancing algorithm.

        Args:
            ss (libss.ss): full order state-space system.

        Returns:
            libss.ss: reduced order model (also stored in ``self.ssrom``).
        """
        self.ss = ss

        _, _, _, D = self.ss.get_mats()

        self.dtsystem = bool(self.ss.dt)

        out = self.algorithm.run(ss)

        if isinstance(out, libss.ss):
            self.ssrom = out
        else:
            # The algorithm returned the reduced (Ar, Br, Cr) matrices only: rebuild the system
            # reusing the feedthrough matrix and time step of the full order model.
            Ar, Br, Cr = out
            if self.dtsystem:
                self.ssrom = libss.ss(Ar, Br, Cr, D, dt=self.ss.dt)
            else:
                self.ssrom = libss.ss(Ar, Br, Cr, D)

        return self.ssrom


if __name__ == '__main__':
    # Build the documentation page for this module through sharpy.utils.docutils.
    import sharpy.utils.docutils as docutils
    import sharpy.utils.sharpydir as sharpydir
    module_path = sharpydir.SharpyDir + '/sharpy/rom/balanced'
    docs_folder = '/rom'
    docutils.output_documentation_module_page(module_path, docs_folder)
