# BrightSeq V0.0.3 --- by Andre Zenz --- free under GPLv3
# https://zenzlab.com/brightseq

import argparse
import itertools

# Version string shown in the start-up banner printed by the script.
CONST_VERSION = "V0.0.3"

# Mass constants for the free nucleosides.
# NOTE(review): the values look like average molecular masses in g/mol (Da);
# units and source are not stated anywhere in this file -- confirm.
#
# The constants without the DEOXY prefix are not referenced by any code in
# the visible source.
CONST_ADENOSINE_MASS = 267.241
CONST_GUANOSINE_MASS = 283.241
CONST_CYTIDINE_MASS = 243.217
CONST_THYMIDINE_MASS = 258.228
CONST_BTADINE_MASS = 362.199
# Deoxy variants: C, A, G and T seed nucMasses below; the "BTADINE" value is
# only used by the legacy calcMasszOLD().
CONST_DEOXYADENOSINE_MASS = 251.242
CONST_DEOXYGUANOSINE_MASS = 267.241
CONST_DEOXYCYTIDINE_MASS = 227.217
CONST_DEOXYTHYMIDINE_MASS = 242.228
CONST_DEOXYBTADINE_MASS = 344.967
# Added once per residue and subtracted once per sequence in calcMass().
CONST_PHOSPHATE_MASS = 61.947

# Symbols and masses of user-defined modifications; filled while the
# --custom_modifications option is parsed.
modifications = []
modMasses = []
# Parallel lists: nucMasses[i] is the mass belonging to nucleosides[i]
# (calcMass() looks masses up via nucleosides.index()). Custom modifications
# are appended to both lists at start-up.
nucleosides=['C','A','G','T']
nucMasses=[CONST_DEOXYCYTIDINE_MASS, CONST_DEOXYADENOSINE_MASS, CONST_DEOXYGUANOSINE_MASS, CONST_DEOXYTHYMIDINE_MASS]  


def seqToOccurances(inSeq):
    """Summarise how often each known nucleoside symbol occurs in ``inSeq``.

    For every entry of the module-level ``nucleosides`` list (in list order)
    one "<count>x<symbol>" item followed by a tab is emitted; the items are
    concatenated into a single string, so the result ends with a tab.
    """
    pieces = [
        "{0}x{1}\t".format(inSeq.count(symbol), symbol)
        for symbol in nucleosides
    ]
    return "".join(pieces)

def seqConstitution(inSeq):
    """Return the occurrence count of every known nucleoside in ``inSeq``.

    The returned list is ordered like the module-level ``nucleosides`` list,
    i.e. element i is ``inSeq.count(nucleosides[i])``.
    """
    return [inSeq.count(symbol) for symbol in nucleosides]
    
def calcMass(inSeq):
    """Return the calculated mass of the sequence ``inSeq``.

    ``inSeq`` is any sized iterable of nucleoside symbols (a string or a
    tuple of symbols). Each symbol contributes its entry from ``nucMasses``
    plus one phosphate; one phosphate is removed again for the whole
    sequence. When the ``--phosphorylated`` option is set, 77.974 is added.

    An empty sequence has mass 0. A symbol that is not in ``nucleosides``
    makes ``list.index`` raise ValueError.
    """
    if not len(inSeq):
        return 0

    total = 0
    for symbol in inSeq:
        position = nucleosides.index(symbol)
        total += (nucMasses[position] + CONST_PHOSPHATE_MASS)
    # n residues are joined by n-1 phosphates, so drop the surplus one.
    total -= CONST_PHOSPHATE_MASS

    if argument.phosphorylated:
        total += 77.974
    return total

def calcMasszOLD(inSeq):
    """Legacy mass calculation restricted to the symbols B, G, A, C and T.

    Counts each of the five symbols in ``inSeq``, weights the count with the
    matching deoxy-nucleoside mass plus one phosphate, and removes one
    phosphate for the whole sequence. 77.974 is added when the
    ``--phosphorylated`` option is set. An empty sequence yields 0.
    """
    # Order matters: the terms are summed in exactly this sequence.
    residueMasses = (
        ('B', CONST_DEOXYBTADINE_MASS),
        ('G', CONST_DEOXYGUANOSINE_MASS),
        ('A', CONST_DEOXYADENOSINE_MASS),
        ('C', CONST_DEOXYCYTIDINE_MASS),
        ('T', CONST_DEOXYTHYMIDINE_MASS),
    )
    seqMass = 0
    for symbol, nucleosideMass in residueMasses:
        seqMass += inSeq.count(symbol) * (nucleosideMass + CONST_PHOSPHATE_MASS)
    seqMass -= CONST_PHOSPHATE_MASS

    if argument.phosphorylated:
        seqMass += 77.974
    # The empty sequence is defined as mass 0, overriding the arithmetic above.
    return 0 if len(inSeq) == 0 else seqMass

def expandToWidth(string, width):
    """Right-pad ``string`` with spaces until it is ``width`` characters long.

    Strings that already have ``width`` or more characters are returned
    unchanged (nothing is ever truncated).
    """
    missing = width - len(string)
    if missing > 0:
        return string + " " * missing
    return string

def printTable(content, headers):
    """Print ``content`` as an ASCII table with a header row to stdout.

    content: sequence of rows; every row is a sequence of cells. Cells are
             rendered with ``str()``.
    headers: sequence of column titles, rendered with ``str()``.

    The output is framed by an empty line before and after the table.

    Columns are identified by their position. The previous implementation
    located a cell's column with ``line.index(col)``, which returns the
    first column holding an equal value, so rows with repeated values (for
    example two identical counts) were measured and padded against the wrong
    column. An empty ``content`` now prints a header-only table instead of
    raising IndexError.
    """
    # The table is as wide as the header row or the longest data row,
    # whichever has more columns.
    columnCount = len(headers)
    for row in content:
        columnCount = max(columnCount, len(row))
    columnWidth = [0] * columnCount

    # Sizing rule (unchanged): whenever a cell is longer than the width
    # recorded so far, the column grows to that cell's length plus 2.
    for row in content:
        for position, cell in enumerate(row):
            cellLength = len(str(cell))
            if columnWidth[position] < cellLength:
                columnWidth[position] = cellLength + 2
    for position, title in enumerate(headers):
        titleLength = len(str(title))
        if columnWidth[position] < titleLength:
            columnWidth[position] = titleLength + 2

    # Every column contributes "+", one dash of padding on each side and
    # `size` dashes for the cell itself.
    tableBorder = "".join("+-" + "-" * size + "-" for size in columnWidth) + "+"

    print("")
    print(tableBorder)
    headerLine = "|" + "".join(
        " " + str(title).ljust(columnWidth[position]) + " |"
        for position, title in enumerate(headers)
    )
    print(headerLine)
    print(tableBorder)

    for row in content:
        lineBuffer = "| " + "".join(
            str(cell).ljust(columnWidth[position]) + " | "
            for position, cell in enumerate(row)
        )
        print(lineBuffer)
    print(tableBorder)
    print("")

# --- Command-line interface -------------------------------------------------
# Every option takes a value (none of them is a bare flag): the boolean-like
# options (-r, -fss, -p) are switched on by passing any non-empty value,
# for example "-p 1".
parser = argparse.ArgumentParser(description = "Description for my parser")
parser.add_argument("-H", "--Help", help = "Show this help text", required = False, default = "")
parser.add_argument("-s", "--sequence", help = "sequence to be matched (default 5->3)", required = False, default = "")
parser.add_argument("-r", "--reverse", help = "reverses sequence for 3'->5' input", required = False, default = False)
parser.add_argument("-ms", "--mass", help = "mass of the measured oligo", required = False, default = False)
parser.add_argument("-msd", "--mass_delta", help = "mass-range to be searched", required = False, default = "1")
parser.add_argument("-fss", "--full_sub_sequences", help = "show masses of sub-sequences that fall outside of the searched mass range", required = False, default = False)
parser.add_argument("-p", "--phosphorylated", help = "5' phosphorylation modification is applied", required = False, default = False)
parser.add_argument("-c", "--custom_modifications", help = "add custom modifications by passing a string like 'M:123.456' where M is the sequence code and 123.456 is the free nucleoside mass. Multiple can be devided by character '+', i.e B:123.000+K:333.213", required = False, default = "")
argument = parser.parse_args()
status = False


def _parseFloatOption(text, optionName):
    """Convert an option value to float; exit via parser.error() if it is not a number."""
    try:
        return float(text)
    except ValueError:
        # parser.error() prints the usage plus this message and exits,
        # instead of showing the user a raw traceback.
        parser.error("{0} expects a number, got '{1}'".format(optionName, text))


if argument.Help:
    print("You have used '-H' or '--Help' with argument: {0}".format(argument.Help))
    status = True

sequence = argument.sequence if argument.sequence else ""
mass = _parseFloatOption(argument.mass, "--mass") if argument.mass else 0
# An explicitly empty --mass_delta used to leave massDelta undefined, which
# crashed the later mass search with a NameError; fall back to the option's
# default of 1 instead.
massDelta = _parseFloatOption(argument.mass_delta, "--mass_delta") if argument.mass_delta else 1.0

if argument.reverse:
    sequence = argument.sequence[::-1]

if argument.custom_modifications:
    # Format: SYMBOL:MASS entries separated by '+', e.g. "B:123.000+K:333.213".
    for mod in argument.custom_modifications.split("+"):
        parts = mod.split(":")
        if len(parts) < 2 or not parts[0]:
            parser.error("custom modification '{0}' is not of the form SYMBOL:MASS".format(mod))
        modifications.append(str(parts[0]))
        modMasses.append(_parseFloatOption(parts[1], "--custom_modifications"))
    # Custom symbols take part in the composition search and in mass
    # look-ups exactly like the built-in nucleosides.
    nucleosides.extend(modifications)
    nucMasses.extend(modMasses)

# Only hint at the main inputs when neither of them was supplied; the old
# hint was printed on every normal run and advertised a nonexistent -o option.
if not status and not (argument.sequence or argument.mass):
    print("Maybe you want to use -s or -ms as arguments ? See -h for all options.")


# --- Banner and option summary ----------------------------------------------
print("")
print("BrightSeq "+CONST_VERSION+" --- by Andre Zenz --- free under GPLv3")
print("")
if argument.custom_modifications:
    print("Custom Modifications were specified:")
    modHeaders = ["symbol", "mass"]
    modContent = list(zip(modifications, modMasses))
    printTable(modContent, modHeaders)
if argument.phosphorylated:
    print("5'-phosphorylation modification was selected: all masses shown account for this modification!")
matchMasses = []
matches = []
# NOTE(review): the output labels sequences as 3'->5' while the --sequence
# help text says "default 5->3" -- confirm which orientation is intended.
headers = ["mass", "sequence (3'->5')", "note"]

# --- Composition search -----------------------------------------------------
# Enumerate every nucleoside composition (order-independent combination) of
# growing length and keep those whose mass lies strictly inside
# (mass - massDelta, mass + massDelta).
matchMass = []
maxMass = 0
seqLength = 0
delimiter = ""
matchFound = False
# Smallest possible mass gain per additional residue. The extended stop
# criterion below is only valid (and only terminates) when it is positive,
# which a custom modification with a strongly negative mass could violate.
lightestStep = min(nucMasses) + CONST_PHOSPHATE_MASS
lightestAtLength = 0
# The original bound (heaviest composition >= mass + 1000) is kept, but the
# search additionally continues while the lightest composition of the last
# length is still below the search window: otherwise light compositions of
# longer sequences were skipped for large masses or heavy custom
# modifications.
while maxMass < (int(mass)+1000) or (lightestStep > 0 and lightestAtLength < (mass+massDelta)):
    seqLength += 1
    lightestAtLength = float("inf")
    for seq in itertools.combinations_with_replacement(nucleosides, seqLength):
        cmass = calcMass(seq)
        if cmass > maxMass:
            maxMass = cmass
        if cmass < lightestAtLength:
            lightestAtLength = cmass
        if cmass > (mass-massDelta) and cmass < (mass+massDelta):
            matchMasses.append(round(cmass, 3))
            matches.append(seq)
            matchFound = True

if not matchFound:
    print("no match found")
else:
    print("A sequence with the following constituent nucleotides would match your search mass of "+str(mass)+" within the given mass error range of "+str(massDelta)+" :")
    constitutionTableContent = []
    # matchMasses already holds the rounded mass of every match, so it is
    # reused instead of recomputing calcMass() per row.
    for matchedMass, s in zip(matchMasses, matches):
        line = [str(matchedMass)]
        line.extend(seqConstitution(s))
        constitutionTableContent.append(line)
    constitutionTableHeaders = nucleosides[:]
    constitutionTableHeaders.insert(0, "mass")
    printTable(constitutionTableContent, constitutionTableHeaders)

# --- Check of the user-supplied sequence and its end-trimmed fragments -------
if sequence:
    print("You have put in a search sequence of 3'-"+sequence+"-5'")
    sequenceMass = calcMass(sequence)
    print("This sequence has a calculated mass of "+str(round(sequenceMass, 3)), end=" ")
    if sequenceMass > mass-massDelta and sequenceMass < mass+massDelta:
        print("and therefore does matches your input search mass of "+str(mass)+" within the given error of "+str(massDelta)+"")
    else:
        print("and therefore does NOT matches your input search mass of "+str(mass)+" within the given error of "+str(massDelta)+"")
        print("Fragment sequences that are shortened on either the 3' or 5' end were searched for possible mass-matches:")
        trimmedSequences = []
        trimmedOperations = []
        trimmedMatches = []
        trimmedMatchMasses = []
        trimmedMatchOperations = []
        # For every cut position record the suffix (start of the string
        # removed) and the prefix (end of the string removed) together with
        # a label describing the trim. cursor == 0 yields the full sequence
        # and the empty sequence.
        for cursor in range(len(sequence)):
            trimmedSequences.append(sequence[cursor:(len(sequence))])
            trimmedOperations.append("3' trimmed N-"+str(cursor))
            trimmedSequences.append(sequence[0:cursor])
            trimmedOperations.append("5' trimmed N-"+str(len(sequence)-cursor))

        trimmedSequenceMasses = []
        # Walk fragments and labels in lock-step. Looking the label up with
        # trimmedSequences.index(seq) returned the label of the FIRST equal
        # fragment, so identical fragments produced from different ends
        # (e.g. in a homopolymer) were reported with the wrong trim label.
        for seq, operation in zip(trimmedSequences, trimmedOperations):
            cmass = calcMass(seq)
            trimmedSequenceMasses.append(round(cmass,3))
            if cmass > (mass-massDelta) and cmass < (mass+massDelta):
                trimmedMatches.append(seq)
                trimmedMatchMasses.append(round(cmass,3))
                trimmedMatchOperations.append(operation)

        if argument.full_sub_sequences:
            print("   -> Option --full-sub-sequences! Masses of all checked trimmed sequences is listed:")
            sequencesWithMasses = list(sorted(zip(trimmedSequenceMasses, trimmedSequences, trimmedOperations)))
            printTable(sequencesWithMasses, headers)
        else:
            if len(trimmedMatches)==0:
                print("   -> No trimmed sequences matched the search mass")
            else:
                sequencesWithMasses = list(sorted(zip(trimmedMatchMasses, trimmedMatches, trimmedMatchOperations)))
                printTable(sequencesWithMasses, headers)

    