package org.example.assertmessageanalysis.actions;

import com.intellij.execution.PsiLocation;
import com.intellij.execution.junit.JUnitUtil;
import com.intellij.ide.highlighter.JavaFileType;
import com.intellij.openapi.actionSystem.AnAction;
import com.intellij.openapi.actionSystem.AnActionEvent;
import com.intellij.openapi.editor.Document;
import com.intellij.openapi.fileEditor.FileDocumentManager;
import com.intellij.openapi.project.Project;
import com.intellij.openapi.project.ProjectManager;
import com.intellij.openapi.project.ProjectUtil;
import com.intellij.openapi.ui.Messages;
import com.intellij.openapi.vfs.VirtualFile;
import com.intellij.psi.*;
import com.intellij.psi.search.FileTypeIndex;
import com.intellij.psi.search.GlobalSearchScope;
import com.intellij.util.DocumentUtil;
import com.intellij.util.indexing.FileBasedIndex;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.sql.SQLException;
import java.util.*;

/**
 * Action that scans the JUnit test classes of every open project for calls to
 * {@code org.junit.Assert} methods and records, for each call, the assertion message argument
 * (its source text) or {@code ""} when the call carries none.
 *
 * <p>Rows are written through {@link AssertMethodEntityService}, which is opened for the project
 * the action was invoked from and closed when the scan ends (also when it fails).
 */
public class ScanProject extends AnAction {

    /** Fully-qualified name of the JUnit 4 assertion class whose calls are analysed. */
    static final String ASSERT_CLASS = "org.junit.Assert";

    /** Returned by {@link #expectedArgumentCount}: the method is not one of the tracked asserts. */
    private static final int NOT_TRACKED = -2;

    /** Returned by {@link #expectedArgumentCount}: a message is accepted for any argument count. */
    private static final int ANY_ARGUMENT_COUNT = -1;

    /** Persistence service for the current run; assigned at the start of each {@link #actionPerformed}. */
    AssertMethodEntityService assertMethodEntityService;

    @Override
    public void actionPerformed(@NotNull AnActionEvent e) {
        Project project = e.getProject();
        if (project == null) {
            // The service is keyed by project name and base path; without a project there is nothing to open.
            return;
        }
        try {
            assertMethodEntityService = AssertMethodEntityService.getInstance(project.getName(), project.getBasePath());
        } catch (SQLException exception) {
            throw new RuntimeException(exception);
        }

        try {
            analyzeTestClasses(getTestFiles().values());
        } catch (RuntimeException | Error scanFailure) {
            // Release the connection even on failure, without letting a close error hide the original one.
            closeService(scanFailure);
            throw scanFailure;
        }
        closeService(null);

        Messages.showMessageDialog("Analysis Completed!", "Assert Message Analysis", Messages.getInformationIcon());
        System.out.println("Completed!");
    }

    /**
     * Closes the service's connection.
     *
     * @param pendingFailure failure already propagating from the scan, or {@code null}; when non-null a
     *                       close error is attached to it as suppressed instead of being thrown
     * @throws RuntimeException wrapping the {@link SQLException} when closing fails and there is no pending failure
     */
    private void closeService(@Nullable Throwable pendingFailure) {
        try {
            assertMethodEntityService.closeConnection();
        } catch (SQLException ex) {
            if (pendingFailure != null) {
                pendingFailure.addSuppressed(ex);
            } else {
                throw new RuntimeException(ex);
            }
        }
    }

    /** Runs {@link #analyzeMethod} on every method reached by a recursive PSI visit of each test class. */
    private void analyzeTestClasses(Collection<PsiClass> testClasses) {
        for (PsiClass psiClass : testClasses) {
            psiClass.accept(new JavaRecursiveElementVisitor() {
                @Override
                public void visitMethod(PsiMethod method) {
                    // No super call: the visitor does not descend into the body; getMethodCallExpressions scans it.
                    analyzeMethod(method);
                }
            });
        }
    }

    /** Saves one row for every tracked {@code org.junit.Assert} call found in {@code method}'s body. */
    private void analyzeMethod(PsiMethod method) {
        System.out.println(method.getName());
        VirtualFile parentVirtualFile = method.getContainingFile().getOriginalFile().getVirtualFile();
        if (parentVirtualFile == null) {
            // No backing file: there is neither a path to record nor a document to compute line numbers from.
            return;
        }
        String parentFilePath = parentVirtualFile.getPath();
        PsiClass parentClass = method.getContainingClass();
        String parentClassName = parentClass == null ? null : parentClass.getQualifiedName();
        boolean isTestMethod = JUnitUtil.isTestMethod(new PsiLocation<>(method));

        for (PsiMethodCallExpression call : getMethodCallExpressions(method)) {
            // Resolve once; null when the reference cannot be resolved.
            PsiMethod calledMethod = call.resolveMethod();
            if (calledMethod == null) {
                continue;
            }
            PsiClass calledClass = calledMethod.getContainingClass();
            // Constant-first equals: getQualifiedName() is null for local/anonymous classes.
            if (calledClass == null || !ASSERT_CLASS.equals(calledClass.getQualifiedName())) {
                continue;
            }
            // Use the resolved name rather than the reference text, which may include a qualifier such as "Assert.".
            String calledMethodName = calledMethod.getName();
            int expectedArgs = expectedArgumentCount(calledMethodName);
            if (expectedArgs == NOT_TRACKED) {
                continue;
            }
            String message = extractMessage(call, calledMethod, expectedArgs);
            saveData(calledMethodName, message, call.getText(), getLineNumber(parentVirtualFile, call),
                    method.getName(), isTestMethod, parentClassName, parentFilePath, call.getProject().getName());
        }
    }

    /**
     * Maps a tracked assert method name to the argument count at which its first argument is the message.
     *
     * @return 1, 2 or 3 for the fixed-arity groups, {@link #ANY_ARGUMENT_COUNT} for the overloaded
     *         equality asserts, or {@link #NOT_TRACKED} when the name is in none of the groups
     */
    private int expectedArgumentCount(String methodName) {
        if (assertsWithoutParameter.contains(methodName)) {
            return 1;
        }
        if (assertsWithOneParameter.contains(methodName)) {
            return 2;
        }
        if (assertsWithTwoParameters.contains(methodName)) {
            return 3;
        }
        if (assertsWithMultipleParameters.contains(methodName)) {
            return ANY_ARGUMENT_COUNT;
        }
        return NOT_TRACKED;
    }

    /**
     * Returns the source text of the call's message argument, or {@code ""} when it has none.
     *
     * <p>A message is reported only when the argument count equals {@code expectedArgs} (or
     * {@code expectedArgs} is {@link #ANY_ARGUMENT_COUNT}), the first argument's type is
     * {@code java.lang.String}, and the resolved overload's first parameter is declared as
     * {@code java.lang.String}. The parameter check keeps e.g. {@code assertEquals(Object, Object)}
     * called with two strings from being mistaken for a message-carrying overload.
     */
    private static String extractMessage(PsiMethodCallExpression call, PsiMethod calledMethod, int expectedArgs) {
        PsiExpressionList argumentList = call.getArgumentList();
        PsiExpression[] arguments = argumentList.getExpressions();
        if (arguments.length == 0 || (expectedArgs != ANY_ARGUMENT_COUNT && arguments.length != expectedArgs)) {
            return "";
        }
        PsiParameter[] parameters = calledMethod.getParameterList().getParameters();
        if (parameters.length == 0) {
            return "";
        }
        PsiType stringType = PsiType.getJavaLangString(argumentList.getManager(), argumentList.getResolveScope());
        // Compare from the non-null side: PsiExpression#getType() may be null for unresolved code.
        if (!stringType.equals(arguments[0].getType()) || !stringType.equals(parameters[0].getType())) {
            return "";
        }
        return arguments[0].getText();
    }

    /** Builds an {@link AssertMethodEntity} from the given values, prints it and writes it through the service. */
    private void saveData(String assertMethodName, String assertMessage, String assertMethodText, int assertMethodLineNumber, String parentMethodName, boolean isTestMethod, String parentClassName, String parentFilePath, String projectName) {
        AssertMethodEntity assertMethodEntity = new AssertMethodEntity();
        assertMethodEntity.setAssertMethodName(assertMethodName);
        assertMethodEntity.setAssertMessage(assertMessage);
        assertMethodEntity.setAssertMethodText(assertMethodText);
        assertMethodEntity.setAssertMethodLineNumber(assertMethodLineNumber);
        assertMethodEntity.setParentMethodName(parentMethodName);
        assertMethodEntity.setTestMethod(isTestMethod);
        assertMethodEntity.setParentClassName(parentClassName);
        assertMethodEntity.setParentFilePath(parentFilePath);
        assertMethodEntity.setProjectName(projectName);
        System.out.println(assertMethodEntity.toString());
        try {
            assertMethodEntityService.writeAsertMethodData(assertMethodEntity);
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
    }

    /** Asserts whose message, when present, is the only argument. */
    HashSet<String> assertsWithoutParameter = new HashSet<>(
            List.of(
                    "fail"
            ));
    /** Asserts taking one checked value, i.e. two arguments when a message is given. */
    HashSet<String> assertsWithOneParameter = new HashSet<>(
            Arrays.asList(
                    "assertTrue",
                    "assertFalse",
                    "assertNotNull",
                    "assertNull"
            ));
    /** Asserts taking two checked values, i.e. three arguments when a message is given. */
    HashSet<String> assertsWithTwoParameters = new HashSet<>(
            Arrays.asList(
                    "assertNotSame",
                    "assertSame",
                    "assertThat"
                    //junit5-> "assertThrows",
                    //?"assertNotEquals"
            ));
    /** Overloaded asserts (e.g. with a delta) whose argument count varies; only the first parameter is checked. */
    HashSet<String> assertsWithMultipleParameters = new HashSet<>(
            Arrays.asList(
                    "assertArrayEquals",
                    "assertEquals"
            ));

    /**
     * Returns the 1-based line on which {@code call} starts, or -1 when the file has no document or
     * the call's offset lies outside it (PSI and document out of sync).
     */
    private int getLineNumber(VirtualFile virtualFile, PsiMethodCallExpression call) {
        @Nullable Document document = FileDocumentManager.getInstance().getDocument(virtualFile);
        if (document == null) {
            return -1;
        }
        int offset = call.getTextOffset();
        if (offset < 0 || offset > document.getTextLength()) {
            return -1;
        }
        return document.getLineNumber(offset) + 1;
    }

    /**
     * Collects, in source order, every method call used as a statement anywhere in {@code method}'s
     * body, including nested blocks such as if/loop/try bodies and block lambdas. Bodies of local
     * and anonymous classes are skipped because their statements belong to other methods.
     *
     * @return the calls found; empty when the method has no body (e.g. abstract or native)
     */
    private List<PsiMethodCallExpression> getMethodCallExpressions(PsiMethod method) {
        List<PsiMethodCallExpression> calls = new ArrayList<>();
        PsiCodeBlock body = method.getBody();
        if (body == null) {
            return calls;
        }
        body.accept(new JavaRecursiveElementWalkingVisitor() {
            @Override
            public void visitExpressionStatement(@NotNull PsiExpressionStatement statement) {
                PsiExpression expression = statement.getExpression();
                if (expression instanceof PsiMethodCallExpression) {
                    calls.add((PsiMethodCallExpression) expression);
                }
                // Keep walking: lambda arguments of this call may contain further statements.
                super.visitExpressionStatement(statement);
            }

            @Override
            public void visitClass(@NotNull PsiClass aClass) {
                // Intentionally empty: do not descend into local/anonymous classes.
            }
        });
        return calls;
    }

    /**
     * Maps each Java source file of every open project to the test class that
     * {@link JUnitUtil#getTestClass} reports for it; files without one are omitted.
     */
    private HashMap<VirtualFile, PsiClass> getTestFiles() {
        // LinkedHashMap: files are later processed in the order the index returned them.
        HashMap<VirtualFile, PsiClass> testClassesMap = new LinkedHashMap<>();

        // NOTE(review): scans all open projects, while the service is opened for the invoking project only.
        for (Project project : ProjectManager.getInstance().getOpenProjects()) {
            PsiManager psiManager = PsiManager.getInstance(project);
            @NotNull Collection<VirtualFile> javaFiles = FileBasedIndex.getInstance()
                    .getContainingFiles(FileTypeIndex.NAME, JavaFileType.INSTANCE, GlobalSearchScope.projectScope(project));

            for (VirtualFile file : javaFiles) {
                PsiFile psiFile = psiManager.findFile(file);
                if (psiFile == null) {
                    continue;
                }
                PsiClass psiClass = JUnitUtil.getTestClass(psiFile);
                if (psiClass != null) {
                    testClassesMap.put(file, psiClass);
                }
            }
        }

        return testClassesMap;
    }
}
