#!/usr/bin/env Rscript

library(parallel)

# Prefix pasted directly in front of the "gfmix_for_opt" program name by
# GFmix() (no separator is inserted, so a directory must end with "/").
# With the empty default the bare program name is used.
bindir <- ""

# Utility function: split one string and return its pieces directly.
#
# Args:
#   x:     String to split; only the result for the first element is returned.
#   split: Separator passed to strsplit(). The default "" splits the string
#          into single characters.
#
# Returns:
#   Character vector of substrings.
strsplit1 <- function(x, split = "") {
  # strsplit() returns a list with one entry per input string; unwrap it.
  pieces <- strsplit(x, split = split)
  pieces[[1]]
}

# Utility function: the 20 amino acid one-letter codes used throughout this
# script, in a fixed order (A, R, N, D, C, Q, E, G, H, I, L, K, M, F, P, S,
# T, W, Y, V).
#
# Returns:
#   Character vector of length 20.
generateAlphabet <- function() {
  strsplit("ARNDCQEGHILKMFPSTWYV", split = "")[[1]]
}

# Function for tabulating amino acid counts.
#
# Args:
#   df: Alignment data.frame with columns IDs and Alignment (the layout
#       produced by readAln()).
#
# Returns:
#   data.frame with one row per sequence (row names taken from df$IDs) and one
#   column per residue of generateAlphabet(). Characters that are not in the
#   alphabet are not counted.
countAA <- function(df) {
  
  residues <- generateAlphabet()
  
  # Count the residues of one aligned sequence. Using a factor with fixed
  # levels keeps zero counts and a stable column order.
  tallySequence <- function(aligned) {
    table(factor(strsplit1(aligned, ""), levels = residues))
  }
  
  tallies <- do.call(rbind.data.frame, lapply(df$Alignment, tallySequence))
  colnames(tallies) <- residues
  rownames(tallies) <- df$IDs
  
  return(tallies)
}

# Function for reading in multiple sequence alignment.
#
# Args:
#   infile: Path to the alignment file.
#   format: One of "phylip" / "phylip-sequential" (one record per line),
#           "phylip-interleaved" (blocks of n lines, IDs only in the first
#           block) or "fasta".
#
# Returns:
#   data.frame with columns IDs and Alignment, one row per sequence. For the
#   interleaved PHYLIP and FASTA formats the columns are list columns (one
#   string per element), as produced by data.frame(cbind(id, <list>)).
#
# Raises:
#   An error for an unsupported format, a FASTA file without any ">" header,
#   or an interleaved PHYLIP file whose header/record count is inconsistent.
readAln <- function(infile, format = "phylip-interleaved") {
  
  supported <- c("phylip", "phylip-sequential", "phylip-interleaved", "fasta")
  if (!format %in% supported) {
    stop(paste0("Unsupported alignment format: '", format, "'. Expected one of: ",
                paste(supported, collapse = ", "), "."))
  }
  
  # Read in file.
  lines <- readLines(infile)
  
  # Split the non-blank lines into whitespace-separated fields. Trimming
  # first avoids an empty leading field on indented lines.
  tokenize <- function(body) {
    body <- trimws(body)
    strsplit(body[nzchar(body)], "\\s+")
  }
  
  # Parsing multiple sequence alignment in sequential PHYLIP format.
  if (format %in% c("phylip", "phylip-sequential")) {
    # Every record line is "<ID> <sequence>"; the sequence may be written in
    # several space-separated chunks, which are pasted back together.
    records <- tokenize(lines[-c(1)])
    id <- vapply(records, function(r) r[1], character(1))
    alignment <- vapply(records, function(r) paste0(r[-c(1)], collapse = ""),
                        character(1))
    
    df <- data.frame(cbind(id, alignment))
  }
  
  # Parsing multiple sequence alignment in interleaved PHYLIP format.
  else if (format == "phylip-interleaved") {
    
    # Extract number of sequences from the "<ntaxa> <nsites>" header line.
    dimensions <- strsplit(trimws(lines[c(1)]), "\\s+")[[1]]
    dimensions <- suppressWarnings(as.numeric(dimensions[nchar(dimensions) > 0]))
    ntaxa <- dimensions[1]
    
    # Extract ID and sequence lines while removing dimensions in header and
    # dropping blank lines.
    records <- tokenize(lines[-c(1)])
    
    if (is.na(ntaxa) || ntaxa < 1 || length(records) < ntaxa) {
      stop(paste0("Malformed interleaved PHYLIP file '", infile,
                  "': header does not match the number of sequence lines."))
    }
    
    # IDs are the first field of the first n lines, n = number of sequences.
    id <- vapply(records[seq_len(ntaxa)], function(r) r[1], character(1))
    
    # The i-th sequence is spread over lines i, i + n, i + 2n, ...; drop the
    # leading ID field and paste the remaining fields into one string.
    alignment <- lapply(seq_len(ntaxa), function(i) {
      paste0(unlist(records[seq(i, length(records), by = ntaxa)])[-c(1)],
             collapse = "")
    })
    
    # Bind ID and alignment columns together in a data.frame.
    df <- data.frame(cbind(id, alignment))
  }
  
  # Parsing multiple sequence alignment in FASTA format.
  else {
    # Header lines are the ones that start with ">".
    is_header <- startsWith(lines, ">")
    if (!any(is_header)) {
      stop(paste0("No FASTA records (lines starting with '>') found in '",
                  infile, "'."))
    }
    id <- sub("^>", "", lines[is_header])
    
    # Number each line by the record it belongs to (0 = before first header),
    # then gather the sequence lines of every record. Using factor levels
    # keeps records without sequence lines as empty strings.
    record <- cumsum(is_header)
    is_sequence <- !is_header & record > 0
    chunks <- split(lines[is_sequence],
                    factor(record[is_sequence], levels = seq_along(id)))
    
    # Paste the (single- or multi-line) sequence of each record together.
    alignment <- lapply(chunks, paste0, collapse = "")
    
    # Bind ID and alignment columns together in a data.frame.
    df <- data.frame(cbind(id, alignment))
  }
  
  # Add names to dataframe.
  colnames(df) <- c("IDs", "Alignment")
  
  return(df)
}

# Function for writing out multiple sequence alignment.
#
# Args:
#   alignment: data.frame with columns IDs and Alignment.
#   outfile:   Output path WITHOUT extension; ".phy" or ".fasta" is appended.
#   format:    "phylip", "phylip-sequential" or "phylip-interleaved" (all are
#              written as one record per line) or "fasta".
#
# Returns:
#   The path of the file that was written (outfile plus the suffix).
#
# Raises:
#   An error for an unsupported format.
writeAln <- function(alignment, outfile, format = "phylip") {
  
  # Coerce once so that list, factor and integer columns are all handled.
  ids <- as.character(alignment$IDs)
  sequences <- as.character(alignment$Alignment)
  
  ntaxa <- nrow(alignment)
  nsites <- nchar(sequences)[1]
  
  if (format %in% c("phylip", "phylip-sequential", "phylip-interleaved")) {
    dimensions <- paste0(" ", ntaxa, " ", nsites)
    
    # IDs are padded to a 10-character name field. IDs longer than that
    # cannot be padded (a negative repeat count is an error), so they get a
    # single separating space instead.
    id_width <- nchar(ids)
    pad_width <- ifelse(id_width > 10, 1, 10 - id_width)
    records <- paste0(ids, strrep(" ", pad_width), sequences)
    lines <- c(dimensions, records)
    
    suffix <- ".phy"
  }
  
  else if (format == "fasta") {
    lines <- paste0(">", ids, "\n", sequences, collapse = "\n")
    
    suffix <- ".fasta"
  }
  
  else {
    stop(paste0("Unsupported alignment format: '", format, "'."))
  }
  
  outfile <- paste0(outfile, suffix)
  
  cat(lines, sep = "\n", file = outfile)
  return(outfile)
}

# Convert sequence IDs in files to 0, 1, 2...
#
# Reads the alignment, tree file and root file, replaces every sequence ID by
# its 0-based row index in the alignment, and writes the converted files next
# to the originals with "_INT" appended to each path (writeAln() additionally
# appends its format suffix to the alignment path).
#
# Args:
#   msa_path:      Path to the alignment.
#   treefile_path: Path to the tree file.
#   rootfile_path: Path to the root file.
#   format:        Alignment format, passed to readAln() and writeAln().
#
# Returns:
#   Nothing useful; called for its side effects (three files written and a
#   summary line printed to stdout).
convertIDs <- function(msa_path, treefile_path, rootfile_path, format) {
  
  msa <- readAln(msa_path, format = format)
  treefile <- readLines(treefile_path)
  # Read from the path argument rather than relying on a global `rootfile`.
  rootfile <- readLines(rootfile_path)
  
  old_ids <- as.character(msa$IDs)
  new_ids <- seq_len(nrow(msa)) - 1L
  
  # IDs are matched literally (fixed = TRUE) so that characters such as "."
  # or "|" in an ID are not interpreted as a regular expression.
  # NOTE(review): replacement is still plain substring matching applied one
  # ID at a time, so an ID that occurs inside another ID (or inside a number
  # already present in the file) can be replaced unintentionally.
  for (i in seq_along(old_ids)) {
    treefile <- gsub(old_ids[i], new_ids[i], treefile, fixed = TRUE)
    rootfile <- gsub(old_ids[i], new_ids[i], rootfile, fixed = TRUE)
  }
  
  msa$IDs <- new_ids
  
  con_msa_path <- paste0(msa_path, "_INT")
  con_treefile_path <- paste0(treefile_path, "_INT")
  con_rootfile_path <- paste0(rootfile_path, "_INT")
  
  writeAln(msa, con_msa_path, format)
  write(treefile, con_treefile_path)
  write(rootfile, con_rootfile_path)
  
  prnt <- paste("Converted IDs in input alignment, tree and rootfile to 0-index integers.",
                "File paths:", con_msa_path, ",", con_treefile_path,
                " and", con_rootfile_path, ".")
  write(prnt, stdout())
}

# Running GFmix within GFselector.
#
# Builds a command line for the "gfmix_for_opt" program (prefixed with the
# file-level `bindir`), runs it, and converts whatever it prints to numeric.
#
# Args:
#   msa_path:  Alignment path (-s).
#   mixture:   Mixture file (-f).
#   tree:      Tree file (-t).
#   iqtree:    IQ-TREE file (-i).
#   rootfile:  Root file (-r).
#   g_class:   Character vector of G-class residues, collapsed into one
#              string for -gclass.
#   f_class:   Character vector of F-class residues, collapsed for -fclass.
#   plusF:     When FALSE (default) the "-d" flag is appended.
#   optimizer: Accepted but not used inside this function.
#   afo:       Optional vector; comma-joined and passed as -afo.
#   params:    Optional value passed as -x.
#
# Returns:
#   Numeric conversion of the program's captured standard output.
GFmix <- function(msa_path, mixture, tree, iqtree, rootfile,
                  g_class, f_class, plusF = FALSE, optimizer = FALSE,
                  afo = NULL, params = NULL) {
  
  executable <- paste0(bindir, "gfmix_for_opt")
  
  # Mandatory part of the command line.
  command <- paste(executable, "-s", msa_path, "-t", tree, "-i", iqtree,
                   "-f", mixture, "-r", rootfile,
                   "-gclass", paste0(g_class, collapse = ""),
                   "-fclass", paste0(f_class, collapse = ""))
  
  # Optional switches, appended in a fixed order.
  if (!is.null(afo)) {
    command <- paste(command, "-afo", paste0(afo, collapse = ","))
  }
  if (!plusF) {
    command <- paste(command, "-d")
  }
  if (!is.null(params)) {
    command <- paste(command, "-x", params)
  }
  
  # Capture the program output and interpret it as a number.
  output <- system(command, intern = TRUE)
  
  return(as.numeric(output))
  
}

# Function for running binomial test.
#
# Splits the taxa into two groups and, for every residue column of `counts`,
# computes a two-proportion z-score comparing its frequency in group 1 with
# its frequency in group 2. Residues are then binned by the z-score.
#
# Args:
#   dendrogram: Clustering tree cut into two clusters with cutree() when
#               `group_1` is not given; the larger cluster becomes group 1.
#   group_1:    Optional character vector of taxon IDs (row names of
#               `counts`) forming group 1; all other taxa form group 2.
#   counts:     Per-taxon residue counts (rows = taxa, columns = residues).
#   critical:   Z-score threshold. z > critical -> "G-class",
#               z < -critical -> "F-class", otherwise "O-class".
#
# Returns:
#   Character matrix with columns "AA", "Z-score" (rounded to 2 decimals) and
#   "Assignment", sorted by decreasing z-score. Residues whose z-score is
#   NaN/NA (e.g. no counts in either group) are dropped by sort().
#
# Raises:
#   An error when neither `dendrogram` nor `group_1` is given, when `group_1`
#   names taxa that are not rows of `counts`, or when either group is empty.
binomialBinning <- function(dendrogram = NULL, group_1 = NULL, counts, critical = 1.96) {
  
  if (is.null(group_1)) {
    if (is.null(dendrogram)) {
      stop("Either 'dendrogram' or 'group_1' must be supplied.")
    }
    clusters <- cutree(dendrogram, 2)
    clusters <- lapply(seq_along(unique(clusters)),
                       function(i) names(clusters[clusters == i]))
    
    # The larger cluster is group 1; everything else is group 2.
    group_1 <- unlist(clusters[which.max(lengths(clusters))])
    group_2 <- unlist(clusters[!unlist(lapply(clusters, function(x) all(x %in% group_1)))])
  }
  
  else {
    # Unknown IDs would otherwise turn every z-score into NA.
    unknown <- setdiff(group_1, rownames(counts))
    if (length(unknown) > 0) {
      stop(paste0("Group 1 contains IDs that are not in the alignment: ",
                  paste(unknown, collapse = ", "), "."))
    }
    group_2 <- rownames(counts)[!rownames(counts) %in% group_1]
  }
  
  if (length(group_1) == 0 || length(group_2) == 0) {
    stop("Both groups must contain at least one taxon.")
  }
  
  group_1_prnt <- paste("Group 1: ", paste(group_1, collapse = ", "))
  group_2_prnt <- paste("Group 2: ", paste(group_2, collapse = ", "))
  write(group_1_prnt, stdout())
  write(group_2_prnt, stdout())
  
  # Residue totals per group.
  x1 <- colSums(counts[group_1, , drop = FALSE])
  x2 <- colSums(counts[group_2, , drop = FALSE])
  
  n1 <- sum(x1)
  n2 <- sum(x2)
  
  # Difference of proportions and its standard error under the pooled
  # proportion.
  pdiff <- (x1 / n1) - (x2 / n2)
  phat <- (x1 + x2) / (n1 + n2)
  
  se <- sqrt(phat * (1 - phat) * (1/n1 + 1/n2))
  
  zscores <- sort(pdiff / se, decreasing = TRUE)
  
  AAs <- names(zscores)
  
  assignment <- ifelse(zscores > critical, "G-class",
                       ifelse(zscores < -critical, "F-class", "O-class"))
  
  zscores <- cbind("AA" = AAs, "Z-score" = round(zscores, 2), "Assignment" = assignment)
  
  return(zscores)
}

# Generate all 1 amino acid swaps between G/O/F classes.
#
# From the current partition, one residue at a time is moved G -> O, O -> G,
# O -> F or F -> O. Moves that would leave the G or the F class empty are
# skipped. (The previous is.null() guard could never trigger because
# subsetting with a negative index yields a zero-length vector, not NULL.)
#
# Args:
#   g_class, o_class, f_class: Character vectors of residues in each class.
#
# Returns:
#   data.frame with list columns G, O and F. The first row is named
#   "Starting" and holds the unchanged partition; every further row is one
#   candidate partition, in the order G->O, O->G, O->F, F->O.
generateSwaps <- function(g_class, o_class, f_class) {
  
  # One candidate partition, or NULL when G or F would become empty.
  makeSwap <- function(ng, no, nf) {
    if (length(ng) == 0 || length(nf) == 0) {
      return(NULL)
    }
    list("G" = ng, "O" = no, "F" = nf)
  }
  
  # Apply `move` to every index 1..n and stack the surviving candidates as
  # rows of a list-matrix (NULL when there are none).
  collectSwaps <- function(n, move) {
    do.call(rbind, lapply(seq_len(n), move))
  }
  
  # G -> O
  gto <- collectSwaps(length(g_class), function(i) {
    makeSwap(ng = g_class[-c(i)], no = c(g_class[i], o_class), nf = f_class)
  })
  
  # O -> G
  otg <- collectSwaps(length(o_class), function(i) {
    makeSwap(ng = c(g_class, o_class[i]), no = o_class[-c(i)], nf = f_class)
  })
  
  # O -> F
  otf <- collectSwaps(length(o_class), function(i) {
    makeSwap(ng = g_class, no = o_class[-c(i)], nf = c(o_class[i], f_class))
  })
  
  # F -> O
  fto <- collectSwaps(length(f_class), function(i) {
    makeSwap(ng = g_class, no = c(o_class, f_class[i]), nf = f_class[-c(i)])
  })
  
  starting <- list("G" = g_class, "O" = o_class, "F" = f_class)
  starting <- rbind(starting)
  
  swaps <- rbind.data.frame("Starting" = starting, gto, otg, otf, fto)
  
  return(swaps)
  
}

# Chi2 optimization routine.
#
# Greedy search over G/O/F class assignments. Each iteration scores the
# current assignment and every single-residue swap from generateSwaps() with
# the chi-squared statistic of the taxa x (G count, F count) table, then moves
# to the best-scoring swap as long as it beats the current assignment.
#
# Args:
#   msa:           Alignment data.frame, passed to countAA().
#   binomial_test: Table with AA and Assignment columns; supplies the starting
#                  classes unless `garpfymink` is TRUE.
#   garpfymink:    If TRUE, start from G = G/A/R/P and F = F/Y/M/I/N/K.
#   limit:         Iteration at which the search stops even if the last move
#                  was an improvement; NULL means no limit.
#   track_swaps:   If TRUE, add a "Swap times" list column with the time
#                  spent on each swap of the iteration.
#   threads:       Passed to mclapply() as mc.cores.
#
# Returns:
#   data.frame with one row per iteration and columns G, O, F, Criterion,
#   Optimizer ("Chi2"), Iterations, Iteration time, Optimization time (and
#   Swap times when requested).
optimizeByChi <- function(msa, binomial_test, garpfymink = FALSE, limit = NULL,
                          track_swaps = FALSE, threads = 1) {
  
  counts <- countAA(msa)
  
  # Starting partition: fixed GARP/FYMINK classes or the binomial-test bins.
  if (garpfymink) {
    g_class <- c("G", "A", "R", "P")
    f_class <- c("F", "Y", "M", "I", "N", "K")
    o_class <- generateAlphabet()[!generateAlphabet() %in% c(g_class, f_class)]
  }
  else{
    g_class <- binomial_test$AA[binomial_test$Assignment == "G-class"]
    o_class <- binomial_test$AA[binomial_test$Assignment == "O-class"]
    f_class <- binomial_test$AA[binomial_test$Assignment == "F-class"]
  }
  
  
  # With limit = Inf the `itr == limit` test below is never TRUE.
  if (is.null(limit)) {
    limit <- Inf
  }
  
  optimize <- TRUE
  itr <- 1
  # Holds the start time while running; overwritten with the elapsed seconds
  # once the search terminates.
  opt_time <- Sys.time()
  
  itr_table <- data.frame()
  
  while(optimize) {
    
    itr_time <- Sys.time()
    
    # Row "Starting" is the current partition, the other rows are candidates.
    swaps <- generateSwaps(g_class, o_class, f_class)
    
    nswaps <- seq(nrow(swaps))
    
    opt_prnt <- paste0("Iteration ", itr, ": Assessing ",
                       length(nswaps), " swaps from starting point ",
                       paste0(g_class, collapse = ""), "/",
                       paste0(o_class, collapse = ""), "/",
                       paste0(f_class, collapse = ""), ".")
    
    write(opt_prnt, stdout())
    
    # Score every row (including "Starting") in parallel.
    swaps$Criterion <- mclapply(nswaps, function(i) {
      
      swap <- swaps[i,]
      
      # Count columns of the residues in this candidate's G and F classes.
      g_counts <- counts[unlist(swap$G), drop = FALSE]
      f_counts <- counts[unlist(swap$F), drop = FALSE]
      
      # Per-taxon totals for each class.
      g_counts <- rowSums(g_counts)
      f_counts <- rowSums(f_counts)
      
      swap_time <- Sys.time()
      chi2 <- chisq.test(cbind(g_counts, f_counts))
      chi2 <- as.numeric(chi2$statistic)
      swap_time <- as.numeric(difftime(Sys.time(), swap_time ,units = "secs"))
      
      return(list("Chi2" = chi2, "Swap" = swap_time))
    }, mc.cores = threads)
    
    # Unpack timings first, then replace the list column by the scores.
    swap_times <- unlist(lapply(swaps$Criterion, function(x) x$Swap))
    swaps$Criterion <- unlist(lapply(swaps$Criterion, function(x) x$Chi2))
    
    best <- swaps[which.max(swaps$Criterion),]
    starting <- swaps["Starting",]
    
    # Strict improvement required; otherwise the search has converged.
    if (unlist(best$Criterion) > unlist(starting$Criterion)) {
      
      g_class <- unlist(best$G)
      o_class <- unlist(best$O)
      f_class <- unlist(best$F)
      
      itr_time <- as.numeric(difftime(Sys.time(), itr_time ,units = "secs"))
      curr_time <- as.numeric(difftime(Sys.time(), opt_time ,units = "secs"))
      
      if (itr == limit) {
        
        # Iteration limit reached: record the improved partition and stop.
        opt_time <- curr_time
        
        opt_prnt <- paste0("Iteration ", itr, ": G/O/F at iteration limit: ",
                           paste0(g_class, collapse = ""), "/",
                           paste0(o_class, collapse = ""), "/",
                           paste0(f_class, collapse = ""), ". ",
                           "Iteration time: ", round(itr_time, 3), ", ",
                           "Final optimization time: ", round(opt_time, 3), ".")
        
        itr_row <- cbind.data.frame(best, "Optimizer" = "Chi2",
                                    "Iterations" = itr, "Iteration time" = itr_time,
                                    "Optimization time" = opt_time)
        
        if (track_swaps) {
          itr_row <- cbind.data.frame(itr_row, "Swap times" = I(list(swap_times)))
        }
        
        itr_table <- rbind(itr_table, itr_row)
        
        write(opt_prnt, stdout())
        optimize <- FALSE
        
      } else {
        
        # Improvement found: record it and continue from the new partition.
        opt_prnt <- paste0("Iteration ", itr, ": Found new starting point ",
                           paste0(g_class, collapse = ""), "/",
                           paste0(o_class, collapse = ""), "/",
                           paste0(f_class, collapse = ""), ". ",
                           "Iteration time (s): ", round(itr_time, 3), ", ",
                           "Current optimization time (s): ", round(curr_time, 3), ".")
        
        itr_row <- cbind.data.frame(best, "Optimizer" = "Chi2",
                                    "Iterations" = itr, "Iteration time" = itr_time,
                                    "Optimization time" = curr_time)
        
        if (track_swaps) {
          itr_row <- cbind.data.frame(itr_row, "Swap times" = I(list(swap_times)))
        }
        
        itr_table <- rbind(itr_table, itr_row)
        
        itr <- itr + 1
        write(opt_prnt, stdout())
      }
      
      
    } else {
      # No swap beats the current partition: record it as the final row.
      g_class <- unlist(starting$G)
      o_class <- unlist(starting$O)
      f_class <- unlist(starting$F)
      
      itr_time <- as.numeric(difftime(Sys.time(), itr_time ,units = "secs"))
      opt_time <- as.numeric(difftime(Sys.time(), opt_time ,units = "secs"))
      
      opt_prnt <- paste0("Iteration ", itr, ": G/O/F at optimization limit: ",
                         paste0(g_class, collapse = ""), "/",
                         paste0(o_class, collapse = ""), "/",
                         paste0(f_class, collapse = ""), ". ",
                         "Iteration time: ", round(itr_time, 3), ", ",
                         "Final optimization time: ", round(opt_time, 3), ".")
      
      itr_row <- cbind.data.frame(starting, "Optimizer" = "Chi2",
                                  "Iterations" = itr, "Iteration time" = itr_time,
                                  "Optimization time" = opt_time)
      
      if (track_swaps) {
        itr_row <- cbind.data.frame(itr_row, "Swap times" = I(list(swap_times)))
      }
      
      itr_table <- rbind(itr_table, itr_row)
      
      write(opt_prnt, stdout())
      optimize <- FALSE
    }
    
  }
  
  # Replace the inherited row names by 1..n.
  rownames(itr_table) <- seq(nrow(itr_table))
  return(itr_table)
}

# "Original GFmix" optimization routine.
#
# Same greedy search as optimizeByChi(), but every candidate partition is
# scored by the value returned from GFmix() (stored as "lnL") instead of a
# chi-squared statistic; the candidate with the largest value wins.
#
# Args:
#   msa_path:      Path to the alignment, passed to GFmix().
#   binomial_test: Table with AA and Assignment columns; supplies the starting
#                  classes unless `garpfymink` is TRUE.
#   garpfymink:    If TRUE, start from G = G/A/R/P and F = F/Y/M/I/N/K.
#   limit:         Iteration at which the search stops even if the last move
#                  was an improvement; NULL means no limit.
#   tree, iqtree, rootfile, mixture: Input files forwarded to GFmix().
#   track_swaps:   If TRUE, add a "Swap times" list column with the time
#                  spent on each swap of the iteration.
#   threads:       Passed to mclapply() as mc.cores.
#
# Returns:
#   data.frame with one row per iteration and columns G, O, F, Criterion,
#   Optimizer ("GFmix"), Iterations, Iteration time, Optimization time (and
#   Swap times when requested).
optimizeByOGF <- function(msa_path, binomial_test, garpfymink = FALSE, limit = NULL,
                          tree = NULL, iqtree = NULL, rootfile = NULL,
                          mixture = NULL, track_swaps = FALSE, threads = 1) {
  
  # Starting partition: fixed GARP/FYMINK classes or the binomial-test bins.
  if (garpfymink) {
    g_class <- c("G", "A", "R", "P")
    f_class <- c("F", "Y", "M", "I", "N", "K")
    o_class <- generateAlphabet()[!generateAlphabet() %in% c(g_class, f_class)]
  }
  else{
    g_class <- binomial_test$AA[binomial_test$Assignment == "G-class"]
    o_class <- binomial_test$AA[binomial_test$Assignment == "O-class"]
    f_class <- binomial_test$AA[binomial_test$Assignment == "F-class"]
  }
  
  # With limit = Inf the `itr == limit` test below is never TRUE.
  if (is.null(limit)) {
    limit <- Inf
  }
  
  optimize <- TRUE
  itr <- 1
  # Holds the start time while running; overwritten with the elapsed seconds
  # when the search converges.
  opt_time <- Sys.time()
  
  itr_table <- data.frame()
  
  while(optimize) {
    
    itr_time <- Sys.time()
    
    # Row "Starting" is the current partition, the other rows are candidates.
    swaps <- generateSwaps(g_class, o_class, f_class)
    
    nswaps <- seq(nrow(swaps))
    
    opt_prnt <- paste0("Iteration ", itr, ": Assessing ",
                       length(nswaps), " swaps from starting point ",
                       paste0(g_class, collapse = ""), "/",
                       paste0(o_class, collapse = ""), "/",
                       paste0(f_class, collapse = ""), ".")
    
    write(opt_prnt, stdout())
    
    # Score every row (including "Starting") in parallel, one GFmix() call
    # per row.
    swaps$Criterion <- mclapply(nswaps, function(i) {
      
      swap <- swaps[i,]
      # Local to this worker function; the outer g_class/f_class are untouched.
      g_class <- unlist(swap$G)
      f_class <- unlist(swap$F)
      
      swap_time <- Sys.time()
      
      lnl <- GFmix(msa_path, mixture, tree, iqtree, rootfile, g_class, f_class,
                   plusF = FALSE, optimizer = TRUE)
      
      swap_time <- as.numeric(difftime(Sys.time(), swap_time ,units = "secs"))
      
      return(list("lnL" = lnl, "Swap" = swap_time))
    }, mc.cores = threads)
    
    # Unpack timings first, then replace the list column by the scores.
    swap_times <- unlist(lapply(swaps$Criterion, function(x) x$Swap))
    swaps$Criterion <- unlist(lapply(swaps$Criterion, function(x) x$lnL))
    
    best <- swaps[which.max(swaps$Criterion),]
    starting <- swaps["Starting",]
    
    # Strict improvement required; otherwise the search has converged.
    if (unlist(best$Criterion) > unlist(starting$Criterion)) {
      
      g_class <- unlist(best$G)
      o_class <- unlist(best$O)
      f_class <- unlist(best$F)
      
      itr_time <- as.numeric(difftime(Sys.time(), itr_time ,units = "secs"))
      curr_time <- as.numeric(difftime(Sys.time(), opt_time ,units = "secs"))
      
      if (itr == limit) {
        
        # Iteration limit reached: record the improved partition and stop.
        opt_time <- curr_time
        
        opt_prnt <- paste0("Iteration ", itr, ": G/O/F at iteration limit: ",
                           paste0(g_class, collapse = ""), "/",
                           paste0(o_class, collapse = ""), "/",
                           paste0(f_class, collapse = ""), ". ",
                           "Iteration time: ", round(itr_time, 3), ", ",
                           "Final optimization time: ", round(opt_time, 3), ".")
        
        itr_row <- cbind.data.frame(best, "Optimizer" = "GFmix",
                                    "Iterations" = itr, "Iteration time" = itr_time,
                                    "Optimization time" = curr_time)
        
        if (track_swaps) {
          itr_row <- cbind.data.frame(itr_row, "Swap times" = I(list(swap_times)))
        }
        
        itr_table <- rbind(itr_table, itr_row)
        
        write(opt_prnt, stdout())
        optimize <- FALSE
        
      } else {
        
        # Improvement found: record it and continue from the new partition.
        opt_prnt <- paste0("Iteration ", itr, ": Found new starting point ",
                           paste0(g_class, collapse = ""), "/",
                           paste0(o_class, collapse = ""), "/",
                           paste0(f_class, collapse = ""), ". ",
                           "Iteration time (s): ", round(itr_time, 3), ", ",
                           "Current optimization time (s): ", round(curr_time, 3), ".")
        
        itr_row <- cbind.data.frame(best, "Optimizer" = "GFmix",
                                    "Iterations" = itr, "Iteration time" = itr_time,
                                    "Optimization time" = curr_time)
        
        if (track_swaps) {
          itr_row <- cbind.data.frame(itr_row, "Swap times" = I(list(swap_times)))
        }
        
        itr_table <- rbind(itr_table, itr_row)
        
        itr <- itr + 1
        write(opt_prnt, stdout())
      }
      
      
    } else {
      # No swap beats the current partition: record it as the final row.
      g_class <- unlist(starting$G)
      o_class <- unlist(starting$O)
      f_class <- unlist(starting$F)
      
      itr_time <- as.numeric(difftime(Sys.time(), itr_time ,units = "secs"))
      opt_time <- as.numeric(difftime(Sys.time(), opt_time ,units = "secs"))
      
      opt_prnt <- paste0("Iteration ", itr, ": G/O/F at optimization limit: ",
                         paste0(g_class, collapse = ""), "/",
                         paste0(o_class, collapse = ""), "/",
                         paste0(f_class, collapse = ""), ". ",
                         "Iteration time: ", round(itr_time, 3), ", ",
                         "Final optimization time: ", round(opt_time, 3), ".")
      
      itr_row <- cbind.data.frame(starting, "Optimizer" = "GFmix",
                                  "Iterations" = itr, "Iteration time" = itr_time,
                                  "Optimization time" = opt_time)
      
      if (track_swaps) {
        itr_row <- cbind.data.frame(itr_row, "Swap times" = I(list(swap_times)))
      }
      
      itr_table <- rbind(itr_table, itr_row)
      
      write(opt_prnt, stdout())
      optimize <- FALSE
    }
    
  }
  
  # Replace the inherited row names by 1..n.
  rownames(itr_table) <- seq(nrow(itr_table))
  return(itr_table)
  
}

# Format optimization output.
#
# Collapses the residue vectors in the G, O and F columns into single strings
# (e.g. c("G", "A") -> "GA"), joins the per-swap timings with commas when a
# "Swap times" column is present, and converts everything to character.
#
# Args:
#   optimization: data.frame with list columns G, O, F and, optionally,
#                 "Swap times", plus any further columns.
#
# Returns:
#   Character matrix with one row per row of `optimization` and the same
#   column names. A single-row table is returned as a 1-row matrix rather
#   than the dimensionless vector apply() would produce on its own.
formatOptTable <- function(optimization) {
  
  # Collapse every element of a list column into one string.
  collapseColumn <- function(column, sep) {
    lapply(column, function(i) paste(i, collapse = sep))
  }
  
  optimization$G <- collapseColumn(optimization$G, "")
  optimization$O <- collapseColumn(optimization$O, "")
  optimization$F <- collapseColumn(optimization$F, "")
  
  # The column only exists when swap timings were tracked; assigning an empty
  # list to a missing column would be an error.
  if ("Swap times" %in% names(optimization)) {
    optimization$`Swap times` <- collapseColumn(optimization$`Swap times`, ",")
  }
  
  formatted <- apply(optimization, 2, as.character)
  
  # apply() simplifies a one-row result to a named vector; restore the
  # 1 x ncol matrix shape so callers always get a table.
  if (is.null(dim(formatted))) {
    formatted <- t(formatted)
  }
  
  return(formatted)
  
}

# Handling arguments.
# commandArgs() also contains R's own start-up arguments, so only positions 5
# and up are inspected.
# NOTE(review): the offset of 5 presumably matches how Rscript launches this
# file - confirm if the script is started in another way.
args <- commandArgs()
iarg <- length(args)

# Defaults for every setting that can be changed on the command line.
# (`frmt` is not used anywhere in this part of the script; kept as is.)
seqfile <- groupfile <- iqtreefile <- treefile <- rootfile <- NULL
mixture <- frmt <- limit <- NULL
format <- "phylip"
criterion <- "bi"
threads <- 1
outfile <- stdout()
convert <- FALSE

# Options that must be followed by a value; "-conv" is a plain flag.
value_opts <- c("-aln", "-fmt", "-crit", "-group", "-mix", "-iqtree",
                "-tree", "-root", "-iter", "-out", "-cpu")

# Read arguments from command line. Tokens are scanned from last to first:
# a token that does not start with "-" is remembered as the pending value and
# consumed by the option that precedes it. Unrecognised options are ignored.
val <- NULL
while (iarg >= 5) {
  if (substring(args[iarg], 1, 1) == "-") {
    opt <- args[iarg]
    
    # Fail early instead of using an undefined or stale value.
    if (opt %in% value_opts && is.null(val)) {
      stop(paste0("Check command line: option ", opt, " requires a value."),
           call. = FALSE)
    }
    
    switch(opt,
           "-aln"    = { seqfile <- val },
           "-fmt"    = { format <- val },
           "-crit"   = { criterion <- val },
           "-group"  = { groupfile <- val },
           "-mix"    = { mixture <- val },
           "-iqtree" = { iqtreefile <- val },
           "-tree"   = { treefile <- val },
           "-root"   = { rootfile <- val },
           "-iter"   = { limit <- as.numeric(val) },
           "-out"    = { outfile <- val },
           "-conv"   = { convert <- TRUE },
           "-cpu"    = {
             threads <- suppressWarnings(as.integer(val))
             if (is.na(threads) || threads < 1) {
               stop(paste0("Check command line: -cpu expects a positive ",
                           "integer, got '", val, "'."), call. = FALSE)
             }
           })
    
    # A value belongs to exactly one option.
    val <- NULL
  } else {
    val <- args[iarg]
  }
  iarg <- iarg - 1
}

# Other arguments.
garpfymink <- FALSE
track_swaps <- TRUE

# An alignment is always required.
if (is.null(seqfile)) {
  stop("Check command line: no alignment specified (-aln).")
}

# Check that necessary input files for GFmix optimization have been specified.
if (criterion %in% c("gfmix")) {
  if (any(is.null(mixture), is.null(treefile),
          is.null(iqtreefile), is.null(rootfile))) {
    stop(paste0("Check command line: input missing for SS/lnL optimization",
                " (-mix, -iqtree, -tree or -root)."))
  }
}

# Read alignment and count amino acids.
alignment <- readAln(seqfile, format = format)
counts <- countAA(alignment)

# If a groupfile is specified, its first line lists the taxa of group 1
# (separated by spaces). Otherwise cluster taxa into groups based on Chi2
# residuals.
if (!is.null(groupfile)) {
  group_1 <- strsplit1(readLines(groupfile), " ")
  # Repeated spaces produce empty tokens, which are not taxon IDs.
  group_1 <- group_1[nzchar(group_1)]
  
  binomial_test <- data.frame(binomialBinning(group_1 = group_1, counts = counts))
} else {
  chi_residuals <- chisq.test(counts)$residuals
  chi_dend <- hclust(dist(chi_residuals), method = "average")
  
  # Run binomial test and present results.
  binomial_test <- data.frame(binomialBinning(dendrogram = chi_dend, counts = counts))
}

# Convert IDs to 0-th index integers, if specified.
if (convert) {
  
  # convertIDs() rewrites the tree and root files as well, so both are needed.
  if (is.null(treefile) || is.null(rootfile)) {
    stop("Check command line: -conv requires both -tree and -root.")
  }
  
  convertIDs(seqfile, treefile, rootfile, format)
  treefile <- paste0(treefile, "_INT")
  rootfile <- paste0(rootfile, "_INT")
  
  # Point at the converted alignment. convertIDs() writes it to
  # "<seqfile>_INT" plus the suffix chosen by writeAln(): ".fasta" for FASTA
  # and ".phy" for every PHYLIP variant.
  if (format == "fasta") {
    seqfile <- paste0(seqfile, "_INT.fasta")
  }
  else {
    seqfile <- paste0(seqfile, "_INT.phy")
  }
  
}

# Handle G/F class determination.
if (criterion == "bi") {
  
  # Binomial bins are used as they are; the timing columns are filled with
  # zeros so that the table has the same layout as the optimizer output.
  g_class <- binomial_test[binomial_test$Assignment == "G-class",]$AA
  o_class <- binomial_test[binomial_test$Assignment == "O-class",]$AA
  f_class <- binomial_test[binomial_test$Assignment == "F-class",]$AA
  
  optimization <- cbind.data.frame("G" = I(list(g_class)), "O" = I(list(o_class)),
                                   "F" = I(list(f_class)), "Criterion" = 0,
                                   "Optimizer" = "Binomial", "Iterations" = 0,
                                   "Iteration time" = 0, "Optimization time" = 0,
                                   "Swap times" = I(list(0)))
  
} else if (criterion == "chi") {
  optimization <- optimizeByChi(msa = alignment, binomial_test = binomial_test,
                                garpfymink = garpfymink,
                                limit = limit, track_swaps = track_swaps,
                                threads = threads)
} else if (criterion == "gfmix") {
  optimization <- optimizeByOGF(msa_path = seqfile,
                                binomial_test = binomial_test,
                                garpfymink = garpfymink,
                                limit = limit, track_swaps = track_swaps,
                                tree = treefile,
                                iqtree = iqtreefile, rootfile = rootfile,
                                mixture = mixture,
                                threads = threads)
} else {
  # Without this branch an unknown value only surfaces later as
  # "object 'optimization' not found".
  stop(paste0("Check command line: unknown criterion '", criterion,
              "' (-crit); expected 'bi', 'chi' or 'gfmix'."))
}

# Format optimization output.
optimization <- formatOptTable(optimization)

# A single-row table can come back as a plain named vector (no columns)
# instead of a matrix; transpose it into a 1-row matrix so that write.table()
# prints one row with one field per column.
if (is.null(ncol(optimization))) {
  optimization <- t(optimization)
}

# Write optimization output to outfile (or screen).
write.table(optimization, file = outfile, sep = "\t", quote = FALSE, row.names = FALSE)
