#
#  Copyright (C) 2007, 2016, 2018  Smithsonian Astrophysical Observatory
#
#
#  This program is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation; either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License along
#  with this program; if not, write to the Free Software Foundation, Inc.,
#  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
import pytest

from numpy import arange
import sherpa.astro.models as models
from sherpa.utils import SherpaFloat
from sherpa.utils.testing import SherpaTestCase
from sherpa.models.model import ArithmeticModel, RegriddableModel2D, RegriddableModel1D, boolean_to_byte
from sherpa.models.basic import Const


class test_models(SherpaTestCase):
    """Smoke test: instantiate and evaluate the models in sherpa.astro.models."""

    # Classes reachable through the sherpa.astro.models namespace that are
    # deliberately not exercised by test_create_and_evaluate.
    excluded_model_classes = (ArithmeticModel, RegriddableModel1D, RegriddableModel2D, Const)

    def test_create_and_evaluate(self):
        """Create each model with default arguments and evaluate it.

        Every ArithmeticModel subclass found in ``sherpa.astro.models``
        (other than the excluded ones) is instantiated, checked to be
        named after its lower-cased class name, and called with one and
        two sets of grid arguments. Both results must have the
        SherpaFloat dtype and the same shape as the input grid.

        The number of models exercised is pinned at the end of the test,
        so it has to be updated when the set of models changes.
        """
        x = arange(1.0, 5.0)
        count = 0

        for cls in dir(models):
            clsobj = getattr(models, cls)

            # Skip anything that is not a concrete model class of interest.
            if not isinstance(clsobj, type):
                continue
            if not issubclass(clsobj, ArithmeticModel):
                continue
            if clsobj in self.excluded_model_classes:
                continue

            # These have a very different interface than the others
            if cls in ('JDPileup', 'MultiResponseSumModel'):
                continue

            m = clsobj()
            self.assertEqual(type(m).__name__.lower(), m.name)
            count += 1

            # NOTE(review): linebroad is evaluated with an explicit vsini;
            # presumably the default value is not usable on this grid -
            # confirm against the model definition.
            if m.name == 'linebroad':
                m.vsini = 1e6

            # Models with '2d' in their name, and hubblereynolds, take an
            # extra set of coordinate arguments.
            is_2d = ('2d' in m.name) or (m.name == 'hubblereynolds')

            try:
                if is_2d:
                    pt_out = m(x, x)
                    int_out = m(x, x, x, x)
                else:
                    pt_out = m(x)
                    int_out = m(x, x)
            except ValueError as exc:
                # Keep the underlying error text so the failure is diagnosable.
                self.fail("evaluation of model '%s' failed: %s" % (cls, exc))

            # Name the model and the call in every message: this runs in a
            # loop, so a bare assertion would not say which model failed.
            for label, out in (('point', pt_out), ('integrated', int_out)):
                self.assertIs(
                    out.dtype.type, SherpaFloat,
                    "model '%s' (%s): dtype is %s, expected %s" %
                    (cls, label, out.dtype.type, SherpaFloat))
                self.assertEqual(
                    out.shape, x.shape,
                    "model '%s' (%s): shape is %s, expected %s" %
                    (cls, label, out.shape, x.shape))

        # Number of model classes expected to have been exercised above.
        self.assertEqual(count, 18)


@pytest.mark.parametrize(
    "test_input, expected",
    [
        (True, b'1'),
        (False, b'0'),
        (None, b'0'),
        ("foo", b'0'),
    ],
)
def test_boolean_to_byte(test_input, expected):
    """Check the byte returned by boolean_to_byte for each listed input.

    Of the inputs covered here only ``True`` is expected to give ``b'1'``;
    ``False``, ``None`` and a non-boolean string are expected to give
    ``b'0'``.
    """
    actual = boolean_to_byte(test_input)
    assert actual == expected
