# 2020.02.19, Yaohua Xie
# This is the source code of paper "Identifying patients from a large number of people by a small number of tests (2020)"
# Author: Yaohua Xie
# Email: Yaohua.Xie@hotmail.com
# ORCID: 0000-0001-6780-3156 or 0000-0002-7345-4316
# This study is performed mainly for pandemic prevention, but the proposed approach is versatile for various diseases.
#
# This program is originally used to verify the proposed approach.
# But it can also be used as a tool to find the optimal value of M for a certain pair of P and Q, where:
# P is the number of people to be screened, which is known before screening.
# Q is the number of patients among the above people, which can only be estimated before screening.
# M is the "division parameter", and each large group is divided into M smaller groups recursively.
# First, input the value of P and Q, respectively.
# Then, the optimal M value can be found in matrix TotalChk after running.
# It is the index (i.e., the M value) that corresponds to the smallest value in TotalChk.


from numpy import *
from pylab import plot, stem, title, show


# 2020.02.19, Yaohua Xie
#
# pInfo - the information of people in this group.
# chkResult - (global, not a parameter) the checked results of all people; this function writes into it.
# M - each group will be divided into M smaller groups.
def HandleThisGroup(pInfo, M):
    """Screen one group of people, splitting it recursively into M sub-groups.

    pInfo is a 2-D array with one row per person (column 1: serial number,
    column 2: true illness status). Results are written into the global
    array chkResult, in the row whose index is the person's serial number
    minus one.

    A group with at most M people is checked person by person. A larger
    group is cut into M consecutive sub-groups; each sub-group gets one
    group check, and only the sub-groups reported as ill are screened
    further. The others are marked healthy as a whole.
    """
    global chkResult
    groupSize = pInfo.shape[0]

    if groupSize <= M:
        # Small enough: no further division, check every person individually.
        for person in pInfo:
            chkResult[int(person[0]) - 1, 1] = CheckOnePerson(person)
        return

    # Every sub-group except the last one gets this many people.
    subSize = groupSize // M
    for k in range(M):
        start = k * subSize
        # The last sub-group also takes whatever remains after even division.
        stop = start + subSize if k < M - 1 else groupSize
        subGroup = pInfo[start:stop, :]

        if CheckOneGroup(subGroup) == 1:
            # At least one ill person inside: narrow it down recursively.
            HandleThisGroup(subGroup, M)
        else:
            ExcludeOneGroup(subGroup)


# 2020.02.19, Yaohua Xie
# This function represents the procedure of checking whether a person is
# ill or not (0: has no illness, 1: has illness). In this function,
# this procedure is simulated by returning the value of the 2nd element of pInfo.
#
# pInfo - the information of one person (1st element: serial number, 2nd element: illness status).
def CheckOnePerson(pInfo):
    """Simulate testing a single person and count the test.

    pInfo is one person's record; its 2nd element is the illness status
    (0: has no illness, 1: has illness), which is returned unchanged.
    Each call adds one to the global counter nPeopleChecked.
    """
    global nPeopleChecked
    status = pInfo[1]
    nPeopleChecked += 1
    return status


# 2020.02.19, Yaohua Xie
# This function represents the procedure of checking whether a group
# includes any ill people or not. In this function, this procedure is
# simulated by checking the sum of all the elements in this group.
# (sum=0: has no ill person, sum>0: has at least one ill person).
#
# pInfo - the information of people in this group.
def CheckOneGroup(pInfo):
    """Simulate one merged test on a whole group and count the test.

    pInfo is a 2-D array with one row per person; the illness statuses in
    its 2nd column are summed to stand in for the merged test.

    Returns 0 when the sum is zero (nobody is ill), 1 when it is positive
    (at least one person is ill), and -1 for any other sum (negative or
    NaN). The global counter nGroupChecked is incremented only for the
    0 and 1 outcomes.
    """
    global nGroupChecked

    total = sum(pInfo[:, 1])
    if not (total >= 0):
        # Neither zero nor positive: not a valid test, so it is not counted.
        return -1

    nGroupChecked += 1
    return 1 if total > 0 else 0


# 2020.02.19, Yaohua Xie
# This function marks all the people's checked results in a group to zeros (have no illness).
#
# pInfo - the information of people in this group.
def ExcludeOneGroup(pInfo):
    """Record every person of a group as healthy in the global chkResult.

    pInfo is a 2-D array with one row per person; the 1st column holds the
    1-based serial number, which selects the chkResult row to set to 0.
    """
    global chkResult
    for serial in pInfo[:, 0]:
        chkResult[int(serial) - 1, 1] = 0


## Simulate the information of people, including a small percentage of patients:

# Sometimes we do not need to identify each patient, but just want to know
# whether there are any patients among a large group of people.
# In that case, we only need to perform a "merged test" on the whole group,
# and do not need to further divide it into smaller groups.
#
# The following code is used for identifying each patient.

P = int(input('Please input the number of people to be screened (greater than 1): '))
Q = int(input('Please input the estimate number of patients (between 2 and the above number): '))

# Reject inputs the simulation cannot handle:
# - with fewer than 2 people there is no M value to try (M ranges over 2..P);
# - more patients than people can never be placed among P people.
if P < 2:
    raise ValueError('The number of people to be screened must be greater than 1.')
if Q < 0 or Q > P:
    raise ValueError('The number of patients must be between 0 and the number of people.')

print('Processing...')

RPT = 30  # The whole simulation is repeated RPT times with fresh random patients.

# One row per repetition. Column 1: best cost ratio, column 2: the M value that achieved it.
OptCheckingAll = full((RPT, 2), inf)
OptResourceAll = full((RPT, 2), inf)

# Weights of the synthesized evaluation: checking a group is assumed to
# require w1 resources, and checking a person is assumed to require w2 resources.
w1 = 0.6
w2 = 0.4

for t in range(0, RPT):
    # Column 1: each element represents a person's serial number.
    # Column 2: each element represents a person's true status of illness.
    # (0: has no illness, 1: has illness):
    pplInfo = zeros((P, 2))
    pplInfo[:, 0] = range(1, P + 1)

    # Simulate the patients: pick Q distinct people at random and mark them as ill.
    pplInfo[random.choice(P, size=Q, replace=False), 1] = 1

    ## Test and compare different M values:

    # These arrays use the M value as index. Entries that are never filled
    # (indices 0 and 1, or an M whose screening gave wrong results) stay inf.
    TotalGrpChk = full(P + 1, inf)
    TotalPplChk = full(P + 1, inf)
    TotalChk = full(P + 1, inf)
    SynEvaluation = full(P + 1, inf)

    for M in range(2, P + 1):  # The number of sub-groups in each division (between 2 and P).
        ## Screen the people recursively:

        # Column 1: each element represents a person's serial number.
        # Column 2: each element represents a person's checked result of illness.
        # (inf: not determined yet, 0: has no illness, 1: has illness):
        chkResult = full((P, 2), inf)
        chkResult[:, 0] = range(1, P + 1)
        nGroupChecked = 0
        nPeopleChecked = 0

        HandleThisGroup(pplInfo, M)

        ## Analysis of performance:

        # After screening, chkResult should be the same as pplInfo.
        # Only a correct screening contributes to the statistics.
        if (chkResult == pplInfo).all():
            TotalGrpChk[M] = nGroupChecked  # The number of checked groups.
            TotalPplChk[M] = nPeopleChecked  # The number of checked people.
            TotalChk[M] = nGroupChecked + nPeopleChecked  # Both types of checking together.
            # Synthesized evaluation in case the two types require different resources:
            SynEvaluation[M] = w1 * nGroupChecked + w2 * nPeopleChecked
        else:
            print('Incorrect results!')

    ## Compare all the results of different M values:

    # nan_to_num turns the inf placeholders into very large finite numbers,
    # so unfilled entries can never be selected as the minimum.
    # M = P means regular checking (one by one), so that case is also included.
    TotalChkFinite = nan_to_num(TotalChk)
    TotalChk_min_idx = argmin(TotalChkFinite)
    TotalChk_min = TotalChkFinite[TotalChk_min_idx]

    SynEvalFinite = nan_to_num(SynEvaluation)
    SynEval_min_idx = argmin(SynEvalFinite)
    SynEval_min = SynEvalFinite[SynEval_min_idx]

    # In optimal situation, the required checking is OptChecking times of the usual one:
    OptChecking = TotalChk_min / TotalChk[P]

    # In optimal situation, the required resource is OptResource times of the usual one:
    OptResource = SynEval_min / SynEvaluation[P]

    # Optional display of the results (may look different each time because pplInfo is random):
    # stem(TotalGrpChk); title('TotalGrpChk'); show()
    # stem(TotalPplChk); title('TotalPplChk'); show()
    # stem(TotalChk); title('TotalChk'); show()
    # stem(SynEvaluation); title('SynEvaluation'); show()

    OptCheckingAll[t, 0] = OptChecking  # Record all the OptChecking values in an array for comparison.
    OptCheckingAll[t, 1] = TotalChk_min_idx  # Record the M value corresponding to each OptChecking value.
    OptResourceAll[t, 0] = OptResource  # Record all the OptResource values in an array for comparison.
    OptResourceAll[t, 1] = SynEval_min_idx  # Record the M value corresponding to each OptResource value.

# Determine the statistically best M value, i.e., the one that was optimal
# in the largest number of repetitions:
M_values = OptCheckingAll[:, 1].astype(int)
M_best = argmax(bincount(M_values))

print('Each group is recursively divided into M smaller groups, and recommended value of M is:', M_best)
