-- | A data structure representing an /m/×/n/ matrix and some operations on it.
--
--   The Smith normal form (SNF) transformation is similar to Gaussian elimination,
--   but works over the integers, which form a principal ideal domain. We implement
--   a handful of elementary operations on rows only; because transposition is cheap
--   (it just flips a flag), the same operations can be applied to columns by
--   transposing first.
module Algorithms.CL.SNFMatrix where

import Data.Array
import Control.Exception

-- | An integer matrix (/a/[sub /ij/]) kept as a 2-D array plus a transpose
--   flag, so that transposing never copies the data.
--
--   The two 'Int' fields describe the /stored/ array. When the flag is
--   'True', the logical matrix is the transpose of what is stored (see
--   'rows', 'cols' and 'idx').
data SNFMatrix a = SNFMatrix
    Int                   -- ^ /m/: number of rows of the stored array
    Int                   -- ^ /n/: number of columns of the stored array
    (Array (Int, Int) a)  -- ^ /a/: element storage, indexed by (row, column)
    Bool                  -- ^ /isTranspose/: when 'True', the logical matrix is the transpose of /a/
    deriving (Show)

-- | Transpose a matrix in O(1): the stored array and dimensions are
--   reused unchanged; only the transpose flag is inverted.
transpose :: SNFMatrix a -> SNFMatrix a
transpose m = case m of
    SNFMatrix r c store flipped -> SNFMatrix r c store (not flipped)

-- | Number of rows of the logical matrix. When the transpose flag is
--   set, this is the column count of the stored array.
rows :: SNFMatrix a -> Int
rows (SNFMatrix storedRows storedCols _ flipped)
    | flipped   = storedCols
    | otherwise = storedRows

-- | Number of columns of the logical matrix. When the transpose flag is
--   set, this is the row count of the stored array.
cols :: SNFMatrix a -> Int
cols (SNFMatrix storedRows storedCols _ flipped)
    | flipped   = storedRows
    | otherwise = storedCols

-- | All logical row indices, @0@ up to @rows m - 1@ (empty for a
--   matrix with no rows).
row_list :: SNFMatrix a -> [Int]
row_list m = enumFromTo 0 (rows m - 1)

-- | All logical column indices, @0@ up to @cols m - 1@ (empty for a
--   matrix with no columns).
col_list :: SNFMatrix a -> [Int]
col_list m = enumFromTo 0 (cols m - 1)

-- | Map a logical (row, column) position to the index of the stored
--   array. The two coordinates are swapped when the transpose flag is set.
idx :: SNFMatrix a -> Int -> Int -> (Int, Int)
idx (SNFMatrix _ _ _ flipped) i j
    | flipped   = (j, i)
    | otherwise = (i, j)

-- | Get the element at logical row @i@, column @j@ (0-based). The
--   transpose flag is honoured via 'idx'.
--
--   Throws 'IndexOutOfBounds' (an 'ArrayException') when the position lies
--   outside the stored dimensions. This includes negative indices, so
--   callers always see this exception rather than a generic array error
--   from '!'.
mtx_elem :: (Show a) => SNFMatrix a -> Int -> Int -> a
mtx_elem m@(SNFMatrix rows cols mtx _) i j
    | inRange ((0, 0), (rows - 1, cols - 1)) (ri, rj) = mtx ! (ri, rj)
    | otherwise =
        throw (IndexOutOfBounds ("Bad index i=" ++ show i ++ ",j=" ++ show j ++ ",m=" ++ show m))
    where (ri, rj) = idx m i j

-- | Scale one logical row: every entry @a[row][j]@ becomes
--   @a[row][j] * mul@. Because it works on the logical view, a transposed
--   matrix has a column of its stored array scaled.
mulrow :: (Show a, Num a) => SNFMatrix a -> Int -> a -> SNFMatrix a
mulrow m@(SNFMatrix r c store flipped) row mul =
    let scaled = [ (idx m row j, mtx_elem m row j * mul) | j <- col_list m ]
    in  SNFMatrix r c (store // scaled) flipped

-- | Exchange two logical rows. Both replacement rows are read from the
--   original matrix @m@, so the second update never sees the first one's
--   writes.
swaprows :: (Show a) => SNFMatrix a -> Int -> Int -> SNFMatrix a
swaprows m@(SNFMatrix r c store flipped) row1 row2 =
    SNFMatrix r c (store // intoRow1 // intoRow2) flipped
    where
        -- row1's slots receive row2's old values, and vice versa
        intoRow1 = [ (idx m row1 j, mtx_elem m row2 j) | j <- col_list m ]
        intoRow2 = [ (idx m row2 j, mtx_elem m row1 j) | j <- col_list m ]

-- | Row operation @row_tgt := row_src * mul + row_tgt@ on logical rows.
--   All values are read from the original matrix @m@.
addmulrow :: (Show a, Num a) => SNFMatrix a -> Int -> Int -> a -> SNFMatrix a
addmulrow m@(SNFMatrix r c store flipped) rowtgt rowsrc mul =
    SNFMatrix r c (store // updates) flipped
    where
        updates  = map update (col_list m)
        update j = (idx m rowtgt j, mtx_elem m rowsrc j * mul + mtx_elem m rowtgt j)

-- | Render an 'SNFMatrix' as lines of text for debugging.
--
--   * The first line gives the logical dimensions @<rows,cols>@. It is
--     suffixed with @ (transposed)@ when the transpose flag is set.
--   * The second line is the raw 'show' of the underlying array, which is
--     in storage layout.
--   * Each further line is one logical row. Entries are read through the
--     transpose flag, so a transposed matrix prints as the matrix it
--     represents rather than as its storage.
printSNF :: (Show a) => SNFMatrix a -> [String]
printSNF (SNFMatrix storedRows storedCols mtx trans) =
    [ "mtx <" ++ show nRows ++ "," ++ show nCols ++ ">" ++ marker ] ++
    [ "{" ++ show mtx ++ "}" ] ++
    [ printRow i | i <- [0 .. nRows - 1] ]
    where
        (nRows, nCols)
            | trans     = (storedCols, storedRows)
            | otherwise = (storedRows, storedCols)
        marker = if trans then " (transposed)" else ""
        -- logical element (i, j): swap the coordinates when transposed
        at i j
            | trans     = mtx ! (j, i)
            | otherwise = mtx ! (i, j)
        printRow i = concat [ " " ++ show (at i j) | j <- [0 .. nCols - 1] ]

-- | Construct an 'SNFMatrix' from a non-empty list of rows.
--
--   Element @(i, j)@ of the result is the @j@-th entry of the @i@-th inner
--   list, using 0-based indices. The transpose flag is cleared.
--
--   Calls 'error' if the outer list is empty or if the inner lists do not
--   all have the same length. Ragged input would otherwise yield an array
--   with undefined or out-of-range elements.
matrixFromList :: [[a]] -> SNFMatrix a
matrixFromList [] = error "matrixFromList: empty list"
matrixFromList list@(firstRow:_)
    | any ((/= nCols) . length) list = error "matrixFromList: rows have different lengths"
    | otherwise                      = SNFMatrix nRows nCols mtx False
    where
        nRows = length list
        nCols = length firstRow
        -- 'Ix' orders tuple indices row-major, so concatenating the rows in
        -- order fills the array exactly as the per-element assignment did.
        mtx   = listArray ((0, 0), (nRows - 1, nCols - 1)) (concat list)