#include "cwe1314.h"
// Substrings that mark an array as relevant for CWE-1314: compareArrayProtection()
// only compares arrays whose base name contains one of these.
std::string keywords1314[11] = {"cipher","key","out","reg","start","sel","data","ctrl","reglk","acct","mem"};
// Substrings of guarding signals that are NOT counted as a protection
// (clock / reset / write); identifyNotProtected() prunes matching entries.
// NOTE(review): matching is by substring, so e.g. "rst" also matches "first".
std::string not_protection_signals[8] = {"clk","CLK","Clk","rst","RST","reset","RESET","write"};
// When true, VeriIdRef flags an assignment whose right-hand side mentions "wdata"
// as a write whose target is recorded.
// NOTE(review): these globals are non-const and externally linked; they may be
// declared extern in cwe1314.h — confirm before changing their type or constness.
bool protection_against_write  = true;
// Reports CWE-1314 when two elements of the same keyword-matching array
// (keys of the form "name[...]" in protection_map) are guarded by different
// protection-signal lists. Only the first mismatching pair per array is logged.
void Visitor1314::compareArrayProtection(){
    // Base names of arrays that have already been compared.
    std::set<std::string> visited;
    for(auto& i:protection_map){
        const std::string::size_type bracket = i.first.find('[');
        // only indexed entries (e.g. "mem[3]") belong to an array
        if(bracket == std::string::npos)
            continue;
        const std::string name = i.first.substr(0, bracket);
        //only check the name with specific keywords
        bool match = false;
        for(const auto& keyword : keywords1314){
            if(name.find(keyword) != std::string::npos){
                match = true;
                break;
            }
        }
        if(!match) continue;
        // insert() reports false when this array was already traversed
        if(!visited.insert(name).second)
            continue;
        for(auto &j :protection_map){
            // map keys are unique, so equal keys means the same entry
            if(i.first == j.first) continue;
            const std::string::size_type bracket2 = j.first.find('[');
            if(bracket2 == std::string::npos) continue;
            if(j.first.substr(0, bracket2) != name) continue;
            // NOTE(review): the comparison is order-sensitive; the same guards
            // collected in a different order are reported as a difference.
            if(j.second != i.second){
                log_detail<<"potential cwe 1314 detected\n";
                result.cwe_1314 = true;
                log_detail<<"the signal array "<<name<<" has different protection\n";
                log_detail<<i.first<<" is protected by ";
                for (const auto &k : i.second){
                    log_detail<<k<<" ";
                }
                log_detail<<"\n"<<j.first<<" is protected by ";
                for (const auto &k : j.second){
                    log_detail<<k<<" ";
                }
                log_detail<<"\n";
                break;
            }
        }
    }
}
// Removes clock/reset/write-like signals (see not_protection_signals) from
// every protection list in protection_map, then reports CWE-1314 for each
// recorded signal whose list ends up empty.
void Visitor1314::identifyNotProtected(){
    for(auto& entry : protection_map){
        auto& guards = entry.second;
        // prune out clock and reset (and write) signals: they are not
        // counted as protections
        for(auto it = guards.begin(); it != guards.end();){
            bool ignore_controlling_signal = false;
            // iterate the array directly so the bound always matches its size
            for(const auto& pattern : not_protection_signals){
                if(it->find(pattern) != std::string::npos){
                    ignore_controlling_signal = true;
                    break;
                }
            }
            if(ignore_controlling_signal)
                it = guards.erase(it);
            else
                ++it;
        }

        // Keyword filtering of unprotected signals is currently disabled:
        // every signal left without any protection is reported.
        if(guards.empty()){
            result.cwe_1314 = true;
            log_detail<<"potential cwe 1314 detected\n";
            log_detail<<"signal "<<entry.first<<" is unprotected \n";
        }
    }
}
void Visitor1314::VERI_VISIT(VeriAlwaysConstruct, node){
    result.cwe_1314_relevant_nodes++;

    // Each always block collects its own list of protection signals, so the
    // stack and its counter start empty before the body is walked.
    signal_count = 0;
    protection.clear();
    TraverseNode(node.GetStmt());
}
void Visitor1314::VERI_VISIT(VeriFor, node){
    result.cwe_1314_relevant_nodes++;

    // Only the loop body is visited; the initial assignments, the loop
    // condition and the repetition assignments are intentionally skipped.
    TraverseNode(node.GetStmt());
}
void Visitor1314::VERI_VISIT(VeriIntVal, node){
    result.cwe_1314_relevant_nodes++;

    // Outside a pending range, every integer literal overwrites `index`.
    if(!range_index){
        index = node.GetNum();
        return;
    }
    // Inside a range index, the first literal visited goes to `index2` and
    // closes the range state; later literals fall through to `index`.
    index2 = node.GetNum();
    range_index = false;
}
// Visits "cond ? a : b". The branches are walked first so that a "wdata"
// operand can set is_wdata (via VeriIdRef). The condition is walked with
// `control` set, so its identifiers are collected as control signals.
void Visitor1314::VERI_VISIT(VeriQuestionColon, node){
    result.cwe_1314_relevant_nodes++;

    TraverseNode(node.GetThenExpr());
    TraverseNode(node.GetElseExpr());

    // Restore the previous value instead of forcing false. Otherwise a ?:
    // nested inside an outer condition would stop control-signal collection
    // for the remainder of that outer condition.
    const bool outer_control = control;
    control = true;
    TraverseNode(node.GetIfExpr());
    control = outer_control;
}
// Visits every case item of a case statement.
void Visitor1314::VERI_VISIT(VeriCaseStatement, node){
    result.cwe_1314_relevant_nodes++;

    // The case selector (node.GetCondition()) is deliberately not traversed:
    // case statements are not counted as protection for now.
    unsigned i ;
    VeriCaseItem *item ;
    Array *items = node.GetCaseItems() ;
    FOREACH_ARRAY_ITEM(items, i, item) {
        TraverseNode(item) ;
    }
}
void Visitor1314::VERI_VISIT(VeriConditionalStatement, node){
    result.cwe_1314_relevant_nodes++;

    // Identifiers pushed onto `protection` while walking the condition guard
    // the assignments in both branches; count them so they can be removed
    // again when this if/else scope ends.
    signal_count = 0;
    TraverseNode(node.GetIfExpr());
    const int pushed = signal_count;

    TraverseNode(node.GetThenStmt());
    TraverseNode(node.GetElseStmt());

    // Leave the if/else scope: drop this condition's signals from the stack.
    for(int remaining = pushed; remaining > 0; --remaining)
        protection.pop_back();
    signal_count -= pushed;
}

// Handles one identifier reference.
//   protection      : stack of identifiers from enclosing if-conditions
//   protectedSignal : identifiers collected from ?: conditions (control set),
//                     attached to the next left-hand side that is recorded
//   protection_map  : left-hand-side key -> list of guarding identifiers
void Visitor1314::VERI_VISIT(VeriIdRef, node){
    result.cwe_1314_relevant_nodes++;

    // Build the key: "name", "name[idx]" or "name[idx2:idx]", depending on
    // what VeriIntVal recorded while the enclosing index was traversed.
    std::string ss = node.GetName();
    if(index2 != INT_MAX){
        ss += "[" + std::to_string(index2) + ":" + std::to_string(index) + "]";
    }
    else if(index != INT_MAX){
        ss += "[" + std::to_string(index) + "]";
    }
    // The recorded index belongs to this identifier only. Reset it before the
    // early return below, so it cannot be attached to the next identifier.
    // Otherwise the lhs of "r <= wdata[7:0]" would become "r[7:0]".
    index = INT_MAX;
    index2 = INT_MAX;

    // A right-hand side mentioning "wdata" marks this assignment as a write.
    if(protection_against_write && is_rhs && ss.find("wdata") != std::string::npos){
        is_wdata = true;
        return;
    }
    // Identifiers in an if-condition (no assignment side, not a ?: condition)
    // are pushed onto the protection stack.
    if(!is_rhs && !is_lhs && !control){
        protection.emplace_back(ss);
        signal_count++;
    }
    if(is_lhs){
        // insert() keeps an existing entry untouched; either way, the pending
        // ?: control signals are appended to this target's list.
        auto& guards = protection_map.insert(std::make_pair(ss, protection)).first->second;
        for(const auto& sig : protectedSignal){
            guards.emplace_back(sig);
        }
        protectedSignal.clear();
    }
    if(control){
        protectedSignal.emplace_back(ss);
    }
}

void Visitor1314::VERI_VISIT(VeriIndexedId, node)
{
    result.cwe_1314_relevant_nodes++;

    // The index IS traversed, and before the prefix. This lets VeriIntVal
    // record constant index values, which the prefix VeriIdRef then uses to
    // build its "name[...]" key.
    auto *idx = node.GetIndexExpr();
    range_index = (idx != nullptr) && idx->IsRange();
    TraverseNode(idx);
    // A range with no integer literal never clears the flag in VeriIntVal.
    // Clear it here so it cannot affect a later, unrelated literal.
    range_index = false;
    TraverseNode(node.GetPrefix()) ;
}

// Handles a multi-index identifier (VeriIndexedMemoryId): only the prefix
// identifier is visited, so no index values are recorded for it.
void
Visitor1314::VERI_VISIT(VeriIndexedMemoryId, node)
{
    result.cwe_1314_relevant_nodes++;

    // Call of Base class Visit
    VeriVisitor::VERI_VISIT_NODE(VeriName, node) ;
    TraverseNode(node.GetPrefix()) ;
    // Do not traverse the index!
    // NOTE(review): unlike VeriIndexedId, index/index2 stay unset here, so the
    // prefix is keyed by its bare name — confirm this is intended.
}


void Visitor1314::VERI_VISIT(VeriBlockingAssign, node){
    result.cwe_1314_relevant_nodes++;

    // Walk the value first; VeriIdRef sets is_wdata if it mentions "wdata".
    is_wdata = false;
    is_rhs = true;
    TraverseNode(node.GetValue());
    is_rhs = false;

    if(is_wdata){
        // Write-data assignment: record the target(s) with their protections.
        is_lhs = true;
        TraverseNode(node.GetLVal());
        is_lhs = false;
    }
    else{
        // Not a write-data assignment: drop any pending ?: control signals.
        protectedSignal.clear();
    }
}
void Visitor1314::VERI_VISIT(VeriNonBlockingAssign, node){
    result.cwe_1314_relevant_nodes++;

    // The right-hand side decides whether this is a write-data assignment:
    // VeriIdRef sets is_wdata when an operand mentions "wdata".
    is_wdata = false;
    is_rhs = true;
    TraverseNode(node.GetValue());
    is_rhs = false;

    if(!is_wdata){
        // Nothing to record; discard the pending ?: control signals.
        protectedSignal.clear();
        return;
    }

    // Record the left-hand side target(s) in protection_map.
    is_lhs = true;
    TraverseNode(node.GetLVal());
    is_lhs = false;
}
