import math;
import os;
import random;
import re;
import shutil;
import subprocess;
import sys;

# Directory in which the temporary files are created
workingDirectory = os.path.join(".", "xorSampleTemp")
# Working copy of the input program (with #show statements appended if needed)
programCopyPath = os.path.join(workingDirectory, "program.lp")
# Working copy plus the current XOR constraints; this is the file that gets solved
programWithConstraintsPath = os.path.join(workingDirectory, "programWithConstraints.lp")
# Extra arguments passed on to clasp (one entry per --a= option)
settingArgs = []
# Extra arguments passed on to gringo (one entry per --g= option)
settingArgsGringo = []
# Maximum number of answer sets to pick per sampling iteration (--c=)
settingC = 1
# Number of answer sets to compute (--n=)
settingN = 1
# Probability for an atom to be included in an XOR constraint (--q=);
# multiplied by 100 (i.e. turned into a percentage) after argument parsing
settingQ = 0.5
# Initial number of XOR constraints (--s=); -1 means "derive it from the number of atoms"
settingS = -1
# Time limit for clasp in seconds (--t=); 0 means no time limit
settingT = 0
# Path of the ASP input file
settingFile = ""
# Whether #show statements should be added to a program that has none (cleared by --dontAddShows)
settingAddShows = True
# The atoms that occur in the grounded program
atoms = []
# "name/arity" signatures of the predicates that occur in the grounded program
predicatesWithArities = []
# Index used to build the next auxiliary XOR predicate name (xorPredicateName + index)
xorPredicateIndex = 1
# Base name of the auxiliary predicates that encode the XOR constraints
xorPredicateName = "xorPredicate"


def cleanUp():
    """Deletes the working directory together with its contents.

    Any error raised while deleting (including a directory that does not exist)
    is ignored.
    """
    shutil.rmtree(workingDirectory, ignore_errors=True)

def usage():
    """Prints the command-line help and terminates the program via sys.exit().

    Used both for --help and for every malformed argument, so it never returns.
    """
    programName = os.path.basename(sys.argv[0])
    helpText = (
        programName + " [--help] [--a=args]* [--g=gArgs]* [--c=C] [--n=N] [--q=Q] [--s=S] [--t=T] [--dontAddShows] file",
        "A tool for calculating near-uniformly distributed answer sets",
        "usage:",
        "--help:\tdisplays this usage info",
        "--a=args:\tArguments that should be passed on to clasp.",
        "--g=Gargs:\tArguments that should be passed on to gringo.",
        "--c=C:\t\tThe maximum number of answer sets to pick per iteration\n\t\t(default: 1)",
        "--n=N:\t\tset an positive integer N as the number of answer sets to compute\n\t\t(default: 1)",
        "--q=Q:\t\tset Q as the probability for an atom to be included in an XOR\n\t\tconstraint (0.01 <= Q <= 0.5, default:0.5)",
        "--s=S:\t\tset an positive integer S as the initial number of constraints\n\t\t(default:log(X) where X is the number of atoms in the grounding)",
        "--t=T:\t\tset the time limit in seconds for clasp\n\t\t(default: 0, 0 = no time limit)",
        "--dontAddShows:\tDoesn't add #show expressions if supplied",
        "file:\tThe file that contains the ASP program",
    )
    for helpLine in helpText:
        print(helpLine)
    sys.exit()
    
def parseArgs():
    """Parses sys.argv into the module-level setting* variables.

    --a= and --g= may be repeated; their values are collected in settingArgs and
    settingArgsGringo. Every other option, and the input file, may be given at most
    once. usage() (which exits) is called for --help, for unknown or repeated
    options, for malformed values and when no input file is given. If the input
    file does not exist, a message is printed and the program exits with status 1.
    """
    global settingArgs, settingArgsGringo, settingC, settingN, settingQ
    global settingS, settingT, settingFile, settingAddShows

    # Single-use options (and "file") that have already been supplied
    seenOptions = set()

    def markSeen(option):
        # Giving a single-use option twice is a usage error
        if option in seenOptions:
            usage()
        seenOptions.add(option)

    for arg in sys.argv[1:]:  # skip the program name
        if arg == "--help":
            usage()
        elif arg.startswith("--a="):
            settingArgs.append(arg[4:])
        elif arg.startswith("--g="):
            settingArgsGringo.append(arg[4:])
        elif arg.startswith("--c="):
            markSeen("--c=")
            settingC = tryParsePositiveInteger(arg[4:])
        elif arg.startswith("--n="):
            markSeen("--n=")
            settingN = tryParsePositiveInteger(arg[4:])
        elif arg.startswith("--q="):
            markSeen("--q=")
            settingQ = tryPraseParameterQ(arg[4:])
        elif arg.startswith("--s="):
            markSeen("--s=")
            settingS = tryParsePositiveInteger(arg[4:])
        elif arg.startswith("--t="):
            markSeen("--t=")
            settingT = tryParseInteger(arg[4:])
        elif arg == "--dontAddShows":
            # Exact match, so that e.g. "--dontAddShowsX" is rejected as unknown
            markSeen(arg)
            settingAddShows = False
        elif not arg.startswith("-"):
            markSeen("file")
            settingFile = arg
        else:
            usage()

    if "file" not in seenOptions:
        usage()

    if not os.path.isfile(settingFile):
        print("The provided file was not found")
        # Non-zero status: this is an error, not a regular termination
        sys.exit(1)
    
def tryParsePositiveInteger(s):
    """Parses s as a strictly positive integer, calling usage for anything else.

    The digit validation is delegated to tryParseInteger; this function
    additionally rejects 0.

    :param s: The string that should be parsed
    :returns: the parsed integer
    """
    value = tryParseInteger(s)
    if not value:
        usage()
    return value
    
def tryParseInteger(s):
    """Converts s into a non-negative integer and calls usage if not possible.

    Only strings consisting solely of decimal digits are accepted (no sign, no
    whitespace, not empty). str.isdecimal() is used rather than str.isdigit():
    isdigit() also accepts characters such as superscripts ("²") that int()
    rejects with a ValueError.

    :param s: The string that should be parsed
    :returns: the parsed value (always >= 0)
    """
    if not s.isdecimal():
        usage()
    return int(s)

def tryPraseParameterQ(s):
    """Parses the --q= value: a real number in the closed interval [0.01, 0.5].

    Malformed input, NaN and values outside the interval lead to a call of usage.

    :param s: The string that should be parsed
    :returns: the parsed probability as a float
    """
    try:
        probability = float(s)
    except ValueError:
        usage()
    else:
        withinBounds = 0.01 <= probability <= 0.5
        if math.isnan(probability) or not withinBounds:
            usage()
        return probability
        
def _predicateSignature(atom):
    """Returns (name, arity) of a ground atom, e.g. 'p(f(1,2),"a,b")' -> ('p', 2).

    Only commas at parenthesis depth 1 and outside double-quoted strings separate
    arguments, so nested terms, tuples and string constants do not inflate the arity.
    """
    if "(" not in atom:
        return atom, 0
    openIndex = atom.index("(")
    arity = 1
    depth = 0
    inString = False
    escaped = False
    for character in atom[openIndex:]:
        if inString:
            if escaped:
                escaped = False
            elif character == "\\":
                escaped = True
            elif character == '"':
                inString = False
        elif character == '"':
            inString = True
        elif character == "(":
            depth += 1
        elif character == ")":
            depth -= 1
        elif character == "," and depth == 1:
            arity += 1
    return atom[:openIndex], arity


def getAtoms():
    """Grounds settingFile with gringo and returns the atoms of the grounding.

    Side effects:
      * copies settingFile to programCopyPath;
      * appends every newly seen "name/arity" signature to predicatesWithArities;
      * if settingAddShows is set and the program contains no "#show " statement,
        appends one "#show name/arity." line per signature to programCopyPath.

    :returns: the atoms found between the first two lines of gringo's output that
        consist only of "0" (the text following the first "<digit> " on each line).
    """
    regexAtom = re.compile('(?<=[0-9] ).*', re.I + re.S)

    # Ground the program using gringo. Line endings are normalised first so that
    # the "0" separator lines are recognised whether gringo emits "\n" or "\r\n".
    grounding = callGringo([settingFile]).replace("\r\n", "\n").split("\n")

    # Collect the atoms between the first and the second line that is exactly "0"
    atoms = []
    atomsBlockStarted = False
    for line in grounding:
        if not atomsBlockStarted:  # skip everything before the block with the atoms
            atomsBlockStarted = line == "0"
            continue
        if line == "0":  # end of the block with the atoms
            break
        match = regexAtom.search(line)
        if match is not None:  # skip lines that do not look like "<number> <atom>"
            atoms.append(match.group(0))

    # Create the working copy of the original file
    shutil.copyfile(settingFile, programCopyPath)
    with open(programCopyPath, "r") as f:
        originalFile = f.read()

    # Shows are only added when the program has no "#show " statement of its own
    addShows = settingAddShows and re.search(r"#show ", originalFile, re.I) is None

    with open(programCopyPath, "a") as f:
        for atom in atoms:
            predicateName, arity = _predicateSignature(atom)
            signature = predicateName + "/" + str(arity)
            if signature not in predicatesWithArities:
                predicatesWithArities.append(signature)
                if addShows:
                    # "\n" rather than os.linesep: text mode already translates it
                    # to the platform line ending
                    f.write("#show " + signature + ".\n")

    return atoms

def estimateS():
    """Estimates how many XOR constraints (the parameter S) should be used for sampling.

    Starts from settingS, or from int(log2(len(atoms))) when settingS is -1.
    While the constrained program stays satisfiable the estimate is increased;
    while it stays unsatisfiable it is decreased. The step is 1 in the first
    iteration and int(2 ** (iteration / 2)) afterwards. On a time-out (checkForSat()
    returning None) the estimate is first shifted by int(2 ** (iteration / 2)),
    up if the previous estimate was satisfiable and down otherwise. The result is
    then handled like an unsatisfiable one, and the next iteration starts a new
    set of constraints.

    :returns: the first satisfiable estimate after unsatisfiable ones, the last
        satisfiable estimate before an unsatisfiable one, or 1 if the estimate
        drops below 1.
    """
    # Initial guess: the user-supplied S or log2 of the number of atoms
    if settingS == -1:
        nextEstimation = int(math.log(len(atoms), 2))
    else:
        nextEstimation = settingS

    lastEstimationWasSatifiable = False  # True if the last estimation yielded at least one answer set
    iteration = 0  # index of the current iteration
    currentNumberOfConstraints = 0  # number of constraints already in the constraint file
    doReset = False  # True if the next iteration has to start a new set of constraints

    while True:
        iteration += 1

        # Only the constraints missing for the new estimation are drawn; the others
        # are kept when the constraint file is appended to.
        atomsToAdd = [chooseAtoms() for _ in range(0, nextEstimation - currentNumberOfConstraints)]

        # Append to the existing constraints only if the last set yielded an answer set
        # and no reset was requested; otherwise start a new set of constraints, so the
        # file holds exactly currentNumberOfConstraints + len(atomsToAdd) constraints.
        addConstraints(atomsToAdd, lastEstimationWasSatifiable and not doReset)
        doReset = False
        currentNumberOfConstraints += len(atomsToAdd)

        # True/False for satisfiable/unsatisfiable, None if the time limit was hit
        isSatisfiable = checkForSat()

        if isSatisfiable is None:
            doReset = True
            if lastEstimationWasSatifiable:
                nextEstimation += int(pow(2, iteration / 2))
            else:
                nextEstimation -= int(pow(2, iteration / 2))
            # no early exit: a time-out is handled like an unsatisfiable result below

        if iteration == 1:  # handle our first estimation
            previousEstimation = nextEstimation
            if isSatisfiable:
                lastEstimationWasSatifiable = True
                nextEstimation += 1
            else:
                lastEstimationWasSatifiable = False
                nextEstimation -= 1
        elif isSatisfiable:
            if lastEstimationWasSatifiable:
                # previous and current estimation are satisfiable => try more constraints
                previousEstimation = nextEstimation
                nextEstimation += int(pow(2, iteration / 2))
                lastEstimationWasSatifiable = True
            else:
                # first satisfiable estimation after unsatisfiable ones => take it
                return nextEstimation
        else:
            if lastEstimationWasSatifiable:
                # previous estimation was satisfiable, the current one is not => take the previous one
                return previousEstimation
            # still unsatisfiable => try fewer constraints
            previousEstimation = nextEstimation
            nextEstimation -= int(pow(2, iteration / 2))
            lastEstimationWasSatifiable = False

        if nextEstimation < 1:  # fallback if no number of constraints yielded an answer set
            return 1

        # An unsatisfiable set of constraints is discarded; the next one starts from scratch
        if not lastEstimationWasSatifiable:
            currentNumberOfConstraints = 0
        
def chooseAtoms():
    """Randomly selects a subset of the atoms for one XOR constraint.

    Every atom is included independently with probability settingQ / 100
    (settingQ holds a percentage once the arguments have been parsed).

    :returns: the chosen atoms, in the order of the global atoms list
    """
    # random.random() * 100 is uniform on [0, 100), so the test succeeds with
    # probability exactly settingQ / 100. randint(0, 100) would draw from 101
    # values and skew it (e.g. 50/101 instead of 0.5).
    return [atom for atom in atoms if random.random() * 100 < settingQ]
    
def _nextFreeXorPredicate():
    """Returns the next name xorPredicateName<index> whose signature "<name>/1" is not in
    predicatesWithArities, advancing xorPredicateIndex past the returned name."""
    global xorPredicateIndex
    while True:
        candidate = xorPredicateName + str(xorPredicateIndex)
        xorPredicateIndex = xorPredicateIndex + 1
        if (candidate + "/1") not in predicatesWithArities:
            return candidate


def _writeXorConstraint(programFile, xorPredicate, atomsForConstraint):
    """Writes one XOR constraint over the non-empty list atomsForConstraint.

    xorPredicate(i) is derived iff an odd number of the first i atoms hold. A final
    integrity constraint requires the total to be even or odd, chosen at random.
    """
    isEven = random.randint(0, 1) == 0
    currentAtom = None
    for currentIndex, atom in enumerate(atomsForConstraint, 1):
        currentAtom = xorPredicate + "(" + str(currentIndex) + ")"
        if currentIndex == 1:
            programFile.write(currentAtom + " :- " + atom + ".\n")
        else:
            previousAtom = xorPredicate + "(" + str(currentIndex - 1) + ")"
            programFile.write(currentAtom + " :- " + atom + ", not " + previousAtom + ".\n")
            programFile.write(currentAtom + " :- not " + atom + ", " + previousAtom + ".\n")
    if isEven:
        programFile.write(":- " + currentAtom + ".\n")
    else:
        programFile.write(":- not " + currentAtom + ".\n")


def addConstraints(atomsForConstraints, doAppend):
    """Adds XOR constraints to programWithConstraintsPath.

    :param atomsForConstraints: A list of atom lists; every non-empty list becomes
        one XOR constraint with a randomly chosen (even/odd) parity. Empty lists are skipped.
    :param doAppend: True if the constraints should be added to an existing constraint
        file (if one exists), False to start again from a fresh copy of programCopyPath.
    """
    global xorPredicateIndex

    # Start from a fresh copy of the program unless we append to an existing file
    # (copyfile replaces an existing destination).
    if (not doAppend) or (not os.path.isfile(programWithConstraintsPath)):
        shutil.copyfile(programCopyPath, programWithConstraintsPath)

    # Restart the numbering of the auxiliary predicates for a new set of constraints
    if not doAppend:
        xorPredicateIndex = 1

    # "\n" rather than os.linesep: the file is opened in text mode, which already
    # translates "\n" to the platform line ending ("\r\n" would become "\r\r\n").
    with open(programWithConstraintsPath, 'a') as programFile:
        programFile.write("\n")
        for atomsForConstraint in atomsForConstraints:
            if len(atomsForConstraint) > 0:
                _writeXorConstraint(programFile, _nextFreeXorPredicate(), atomsForConstraint)
    
    
def checkForSat():
    """Checks whether the current constrained program has at least one answer set.

    :returns: None if solving hit the time limit, otherwise True (satisfiable)
        or False (unsatisfiable).
    """
    outcome = solve(True, 1)
    if outcome is not None:
        return outcome[0] != 0
    return None

def solve(onlyCheckForSAT = False, answerSetsToChoose = 1):
    """Solves the current encoding (programWithConstraintsPath) using gringo and clasp.

    :param onlyCheckForSAT: True if clasp should stop after the first answer set,
        which is then returned without random selection. False if all answer sets
        should be enumerated and sampled from.
    :param answerSetsToChoose: The number of answer sets to return (ignored if
        onlyCheckForSAT = True).
    :returns: None if clasp reports that the time limit was reached. (0, None) if the
        program is unsatisfiable, or if answerSetsToChoose <= 0 while enumerating.
        Otherwise a tuple (X, Y), where Y is a list of answer-set lines as printed by clasp:
        with onlyCheckForSAT, X is 1 and Y holds the first answer set; otherwise X is the
        model count reported by clasp and Y holds min(answerSetsToChoose, X) answer sets
        picked uniformly at random without replacement.
    :raises Exception: if clasp's output contains neither SATISFIABLE nor UNSATISFIABLE.
    """
    if (not onlyCheckForSAT) and answerSetsToChoose <= 0:
        return (0, None)

    # --models=1 stops after the first answer set, --models=0 enumerates all of them
    if onlyCheckForSAT:
        solverResults = callSolver([programWithConstraintsPath], ["--models=1"])
    else:
        solverResults = callSolver([programWithConstraintsPath], ["--models=0"])

    # Normalise line endings once, so the "\n"-based checks and the line splitting
    # below agree whether clasp emits "\n" or "\r\n".
    solverResults = solverResults.replace("\r\n", "\n")

    if "TIME LIMIT   : 1" in solverResults:
        return None

    if "\nSATISFIABLE\n" not in solverResults:
        if "\nUNSATISFIABLE\n" not in solverResults:
            raise Exception("Failed to solve the current program using clasp: " + solverResults)
        return (0, None)

    # Each "Answer: k" line is directly followed by the line with the atoms of answer set k
    solverResultLines = solverResults.split("\n")
    resultOffset = solverResultLines.index("Answer: 1")

    if onlyCheckForSAT:  # pick the first answer set
        return (1, [solverResultLines[resultOffset + 1]])

    numberOfAnswerSets = int(re.findall("(?<=Models       : )[0-9]*", solverResults, re.I + re.S)[0])

    # The k-th answer set not yet picked occupies the two lines starting at
    # resultOffset + 2 * (k - 1); picked ones are deleted so the rest stay contiguous.
    selectedAnswerSets = []
    remaining = numberOfAnswerSets
    while answerSetsToChoose > 0 and remaining > 0:
        chosenIndex = random.randint(1, remaining)
        answerStart = resultOffset + 2 * (chosenIndex - 1)
        selectedAnswerSets.append(solverResultLines[answerStart + 1])
        del solverResultLines[answerStart:answerStart + 2]
        answerSetsToChoose -= 1
        remaining -= 1

    return (numberOfAnswerSets, selectedAnswerSets)

def callSolver(gringoParameters, claspParameters):
    """Pipes the output of gringo into clasp and returns clasp's standard output.

    :param gringoParameters: Arguments appended to "gringo" after settingArgsGringo
        (e.g. the program files).
    :param claspParameters: Arguments appended to "clasp" after the time limit
        (settingT) and settingArgs.
    :returns: clasp's standard output decoded as UTF-8.
    """
    gringoArguments = ["gringo"] + settingArgsGringo + gringoParameters
    claspArguments = ["clasp", "--time-limit=" + str(settingT)] + settingArgs + claspParameters

    # gringo's stderr is discarded instead of piped: a pipe that is never read
    # can fill up (gringo may print warnings there) and block gringo forever.
    gringoProcess = subprocess.Popen(gringoArguments, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
    claspProcess = subprocess.Popen(claspArguments, stdin=gringoProcess.stdout,
                                    stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    # Close our copy of gringo's stdout so gringo gets SIGPIPE if clasp exits early
    gringoProcess.stdout.close()
    claspOutput, _ = claspProcess.communicate()
    # Reap gringo so it does not linger as a zombie process
    gringoProcess.wait()

    return str(claspOutput, "utf-8")

def callGringo(parameters):
    """Runs gringo and returns what it wrote to standard output, decoded as UTF-8.

    settingArgsGringo is placed between the executable name and parameters.
    Standard error is captured and dropped.

    :param parameters: Further command-line arguments for gringo (e.g. input files)
    """
    command = ["gringo"] + settingArgsGringo + parameters
    gringo = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    standardOutput = gringo.communicate()[0]
    return standardOutput.decode("utf-8")

def findAndPrintAnswerSet(s, i, k):
    """Samples and prints up to k answer sets using s random XOR constraints.

    Repeatedly draws s fresh XOR constraints and solves until an answer set is found,
    then prints up to k randomly chosen answer sets numbered from i.

    :param s: The number of XOR constraints (the parameter of the xor sampling algorithm)
    :param i: The index of the first answer set printed by this call
    :param k: The maximum number of answer sets that should be printed; nothing is done if k <= 0
    :returns: a tuple (X, Y) where X is the suggested value of s for the next call and
        Y is the index of the next answer set.
    """
    # solve(False, 0) returns (0, None), which the loop below would treat as
    # "no answer set" forever, so k == 0 has to return here as well.
    if k <= 0:
        return (s, i)

    failedIterations = 0
    timeouts = 0
    while True:
        # Always start a new set of s constraints
        addConstraints([chooseAtoms() for _ in range(0, s)], False)

        result = solve(False, k)

        if result is None:  # time limit hit
            # from the 6th time-out on, every further time-out adds a constraint
            if timeouts > 4:
                s = s + 1
            timeouts = timeouts + 1
        elif result[0] == 0:  # no answer set
            # from the 5th failure on, every further failure removes a constraint
            failedIterations = failedIterations + 1
            if failedIterations > 4:
                s = s - 1
        else:
            for answerSet in result[1]:
                print("Answer set " + str(i) + ":")
                print(answerSet)
                i = i + 1

            # more than 10 answer sets: suggest one more constraint for the next call
            if result[0] > 10:
                s = s + 1

            return (s, i)

        if s < 1:  # fallback
            s = 2

if __name__=="__main__":
    # Prepare the working directory. An existing directory is never reused, because
    # cleanUp() deletes the working directory together with everything in it.
    try:
        if os.path.isdir(workingDirectory):
            print ("The working directory " + workingDirectory + " already exists. Aborting...")
            sys.exit(1)
        os.makedirs(workingDirectory)
    except OSError as e:
        # str(e): exceptions have no .message attribute in Python 3
        print("Failed to prepare the working directory " + workingDirectory + " : " + str(e))
        sys.exit(1)

    try:
        parseArgs()
        # chooseAtoms() compares settingQ against numbers between 0 and 100
        settingQ *= 100

        # Module-level assignment: chooseAtoms() and estimateS() read the global atoms
        atoms = getAtoms()
        if len(atoms) == 0:
            print("The grounding of the program contains no atoms. Aborting...")
            sys.exit(1)

        s = estimateS()

        # Print settingN answer sets, at most settingC per call
        i = 1
        while i <= settingN:
            s, i = findAndPrintAnswerSet(s, i, min(settingC, (settingN - i) + 1))

    except Exception as e:
        print("An unexpected exception occurred: " + str(e))
        sys.exit(1)
    finally:
        cleanUp()