from synthesis import SimpleSynthesizer, Synthesizer
import unittest
from topology import *
from property import *
from problem import SynthesisProblem, FlowSpec
   
class GenerateTests(unittest.TestCase):
    """Tests for synthesizer transition relations and solution generators.

    The ``test_gen_transitions_*`` cases build a synthesizer from a single
    flow and check that every hop of both flow paths is present in
    ``base.trans[0]``.  The ``test_solution_generator_*`` cases compare
    ``solution_generator`` output against ``all()`` / ``satcount()`` and
    against hard-coded expected counts and degrees.
    """

    def _build_synthesizer(self, problem, restrict):
        """Return ``SimpleSynthesizer(problem)`` if *restrict*, else ``Synthesizer(problem)``."""
        if restrict:
            return SimpleSynthesizer(problem)
        return Synthesizer(problem)

    def _assert_path_transitions(self, base, path):
        """Assert that every consecutive hop of *path* is encoded in *base*.

        For each pair ``(src, dst)`` of adjacent nodes in *path*, the
        conjunction ``base.trans[0] & base.S[src] & base.T[dst]`` must differ
        from ``base.bdd.false``.
        """
        # Pair each node with its successor; a path of length n yields n-1 hops.
        for src, dst in zip(path, path[1:]):
            exists_trans = base.trans[0] & base.S[src] & base.T[dst]
            self.assertNotEqual(
                exists_trans,
                base.bdd.false,
                "no transition from {} to {}".format(src, dst),
            )

    def test_gen_transitions_simple_random(self):
        """Both synthesizers encode every hop of random simple path pairs."""
        for i in range(2):
            for restrict in (True, False):
                # SimpleSynthesizer is exercised on sizes 3..49, the general
                # Synthesizer only on sizes 3..17.
                sizes = range(3, 50) if restrict else range(3, 18)

                for size in sizes:
                    p1, p2 = gen_random_simple(
                        i, size, size, max_num=size + size * (i + 1)
                    )[0]

                    problem = SynthesisProblem([FlowSpec(p1, p2, TTrue())])
                    base = self._build_synthesizer(problem, restrict)

                    self._assert_path_transitions(base, p1)
                    self._assert_path_transitions(base, p2)

    def test_gen_transitions_complex_random(self):
        """Both synthesizers encode every hop of paths drawn from random topologies."""
        for i in range(2):
            for restrict in (True, False):
                # Both synthesizer kinds use the same size range here.
                for size in range(50, 100, 5):
                    G = gen_random_topology(i, size, 0.05)

                    # Restrict the topology to its largest strongly connected
                    # component before generating flow pairs.
                    max_connected = max(nx.strongly_connected_components(G), key=len)
                    G = G.subgraph(max_connected)

                    # Only the first generated pair is used.
                    p1, p2 = gen_flow_pairs(G)[0][:2]

                    problem = SynthesisProblem([FlowSpec(p1, p2, TTrue())])
                    base = self._build_synthesizer(problem, restrict)

                    self._assert_path_transitions(base, p1)
                    self._assert_path_transitions(base, p2)

    def test_solution_generator_dia(self):
        """Every generated solution for the diamond instances is valid and has degree 1."""
        for i in range(3, 7):
            p1, p2, _ = gen_diamondSeparateWP(i)[0]

            prop = Reachability(p1[-1])
            problem = SynthesisProblem([FlowSpec(p1, p2, prop)])
            base = SimpleSynthesizer(problem, collapse=False)

            sols = base.all()

            for sol in base.solution_generator():
                self.assertTrue(base.is_solution(sol, sols))
                self.assertEqual(sol.degree(), 1)

    def test_solution_generator_conf_nosimple(self):
        """Confluent instance with an extra (0, 2) edge: two solutions at bound 6."""
        p1, p2, _ = gen_confluent(1)[0]
        G, p1_path, p2_path = gen_topology(p1, p2, weighted=True)

        # Require passing waypoint 2 and reaching the last node of p1.
        prop = Conjunction(Waypoint(2, p1[-1]), Reachability(p1[-1]))
        flow = FlowSpec(p1_path, p2_path, prop)

        problem = SynthesisProblem([flow], top=G)
        problem.top.add_edges_from([(0, 2, {"weight": 1})])

        base = Synthesizer(problem)

        # The result is not inspected; the call is kept in its original
        # position in case it has side effects on the synthesizer
        # (TODO confirm whether it can be dropped).
        base.all()

        generator = base.solution_generator(6)

        self.assertEqual(base.satcount(6), 2)

        for sol in generator:
            self.assertGreaterEqual(sol.degree(), 2)

    def test_solution_generator_1(self):
        """``satcount`` and ``solution_generator`` agree on a two-path instance."""
        p1 = [1, 6]
        p2 = [1, 4, 6]

        flow = FlowSpec(p1, p2, Reachability(6))
        base = Synthesizer(SynthesisProblem([flow]))

        # Expected satcount for each argument value 2..5.
        for bound, expected in ((2, 1), (3, 0), (4, 2), (5, 0)):
            self.assertEqual(base.satcount(bound), expected)

        # The generator must yield exactly as many solutions as satcount reports.
        for bound in (2, 3, 4, 5):
            self.assertEqual(
                base.satcount(bound),
                len(list(base.solution_generator(bound))),
            )

        ls_gen = list(base.solution_generator(4))
        self.assertEqual(len(ls_gen), 2)
        self.assertEqual(ls_gen[0].degree(), 3)
        self.assertEqual(ls_gen[1].degree(), 3)

    def test_solution_generator_2(self):
        """Only argument 4 yields solutions (six of them) for the five-node path."""
        p1 = [0, 10]
        p2 = [0, 6, 3, 9, 10]

        flow = FlowSpec(p1, p2, Reachability(10))
        base = Synthesizer(SynthesisProblem([flow]))

        # All three generators are created before any of them is consumed.
        no_sol_gen = base.solution_generator(3)
        smallest_sol_gen = base.solution_generator(4)
        smallest_sol_gen_2 = base.solution_generator(5)

        self.assertEqual(len(list(no_sol_gen)), 0)
        self.assertEqual(len(list(smallest_sol_gen)), 6)
        self.assertEqual(len(list(smallest_sol_gen_2)), 0)

if __name__ == '__main__':
    # Use the default module ('__main__') so that running this file directly
    # always executes the tests defined in this file, independent of its
    # filename, and without importing the module a second time by name.
    unittest.main()