import ivadomed.mixup as imed_mixup
import torch
import pytest
import logging
import os
from unit_tests.t_utils import remove_tmp_dir, create_tmp_dir,  __tmp_dir__
logger = logging.getLogger(__name__)


def setup_function() -> None:
    """Prepare the temporary directory before each test in this module.

    pytest invokes a module-level ``setup_function`` before every test
    function; this one simply delegates to ``create_tmp_dir``.
    """
    create_tmp_dir()


@pytest.mark.parametrize("debugging", [False, True])
@pytest.mark.parametrize(
    "ofolder",
    [
        os.path.join(__tmp_dir__, "test"),
        os.path.join(__tmp_dir__, "mixup_test"),
    ],
)
def test_mixup(debugging, ofolder):
    """Smoke-test ``imed_mixup.mixup`` on one synthetic 40x40 sample.

    The input tensor is all zeros; the target tensor is zero except for ones
    in its first 10 rows and first 10 columns. Both have shape
    ``(1, 1, 40, 40)`` and dtype float32. Nothing is asserted about the
    result: the test passes as long as the call does not raise for each
    combination of ``debugging`` and ``ofolder``.
    """
    shape = (1, 1, 40, 40)
    image = torch.zeros(shape, dtype=torch.float32)
    label = torch.zeros(shape, dtype=torch.float32)
    # Mark the leading 10x10 block of the target as foreground.
    label[0, 0, :10, :10] = 1.0

    # Only checks that mixup runs to completion with these arguments.
    imed_mixup.mixup(image, label, alpha=0.5, debugging=debugging, ofolder=ofolder)


def teardown_function() -> None:
    """Clean up the temporary directory after each test in this module.

    pytest invokes a module-level ``teardown_function`` after every test
    function; this one simply delegates to ``remove_tmp_dir``.
    """
    remove_tmp_dir()
