-- | Code to reduce a (possibly non-square) integer matrix toward Smith Normal Form.
module Algorithms.CL.SNF (
    matrixFromList,
    structureConstantsFromMatrix,
    classNumberFromMatrix,
    testSNF
) where

import Algorithms.CL.SNFMatrix
import Algorithms.CL.Auxiliary
import Data.Array
import Data.Maybe
import Data.List (find)
import Control.Exception

-- | Make the pivot @m(j,j)@ divide every other entry of column @j@.
--
--   Only unimodular row operations are used (adding an integer multiple of
--   one row to another), so the invariant factors of the matrix -- and hence
--   the class number -- are unchanged. In particular row @j@ must NOT be
--   scaled by a Bezout coefficient: unless that coefficient is +/-1 the step
--   is not invertible over the integers and multiplies the determinant by it.
--
--   For each row @k /= j@ whose column-@j@ entry is not a multiple of the
--   pivot, the Euclidean algorithm is run on rows @j@ and @k@, driven by their
--   column-@j@ entries. Afterwards @m(j,j)@ is +/- the gcd of the two old
--   entries and @m(k,j)@ is 0. The new pivot divides the old one, so rows
--   handled earlier in the fold remain divisible by it.
improvePivot :: (Show int, Integral int) => SNFMatrix int -> Int -> SNFMatrix int
improvePivot m j =
    foldl improveAgainst m [ k | k <- row_list m, k /= j ]
    where
        improveAgainst acc k
            | entry == 0            = acc              -- nothing to improve
            | pivot == 0            = euclid acc k j   -- zero pivot: start from row k
            | pivot `divides` entry = acc
            | otherwise             = euclid acc j k
            where
                pivot = mtx_elem acc j j
                entry = mtx_elem acc k j

        -- One Euclidean step on column j: reduce row `dst` modulo row `src`
        -- (whose column-j entry is nonzero), then swap roles. The remainder's
        -- absolute value strictly decreases, so this terminates; when it hits
        -- 0 the gcd sits in row `src`.
        euclid mat src dst
            | remainder == 0 = settle reduced src
            | otherwise      = euclid reduced dst src
            where
                quotient  = mtx_elem mat dst j `div` mtx_elem mat src j
                reduced   = addmulrow mat dst src (negate quotient)
                remainder = mtx_elem reduced dst j

        -- Ensure the gcd ends up in the pivot row j. If it is in another row,
        -- row j has 0 in column j, so "row j += row src; row src -= row j"
        -- moves the gcd into row j and leaves 0 behind (still unimodular).
        settle mat src
            | src == j  = mat
            | otherwise = addmulrow (addmulrow mat j src 1) src j (-1)

-- | Zero out every entry of column @j@ except the pivot at @(j,j)@, using row
--   operations. 'improvePivot' runs first so that the pivot divides the rest
--   of the column. Each other row then has the appropriate multiple of row
--   @j@ subtracted from it. Calls 'error' (with a dump of the matrix) if an
--   entry is still not a multiple of the pivot.
eliminateCol :: (Show int, Integral int) => SNFMatrix int -> Int -> SNFMatrix int
eliminateCol m j = foldl clearRow prepared otherRows
    where
        prepared  = improvePivot m j
        otherRows = filter (/= j) (row_list m)

        clearRow acc k
            | pivot `divides` entry = addmulrow acc k j (negate (entry `div` pivot))
            | otherwise = error ("not dividies! j=" ++ show j ++ ",k=" ++ show k ++ ", matrix:" ++ unlines (printSNF acc))
            where
                pivot = mtx_elem acc j j
                entry = mtx_elem acc k j

-- | True when every entry of column @col@ is zero, ignoring the entry in row
--   @row@. A column index at or past the matrix width counts as all-zero.
isColZeroExcept :: (Show int, Integral int) => SNFMatrix int -> Int -> Int -> Bool
isColZeroExcept m row col
    | col >= cols m = True
    | otherwise     = not (any (/= 0) [ mtx_elem m k col | k <- row_list m, k /= row ])

-- | Clear row @j@ and column @j@, leaving at most one nonzero entry, at
--   @(j,j)@. The result has the same orientation (rows x columns) as the input.
--
--   If @m(j,j)@ is zero, a nonzero entry is first moved to the pivot: from
--   column @j@ by a row swap or, failing that, from row @j@ by a column swap.
--   Row @j@ must be searched too. Otherwise a zero pivot with a nonzero row
--   would reach 'eliminateCol', which divides by the pivot. If row and column
--   @j@ are both entirely zero, the matrix is returned unchanged.
--
--   When called for @j = 0, 1, ...@ in order, earlier rows and columns are
--   already clear. Any swap partner therefore lies past @j@, and earlier work
--   is not disturbed.
eliminateColRow :: (Show int, Integral int) => SNFMatrix int -> Int -> SNFMatrix int
eliminateColRow m j = reduce False (movePivot m)
    where
        movePivot mat
            | mtx_elem mat j j /= 0                  = mat
            | Just k <- nonzeroRowIn mat             = swaprows mat j k
            | Just k <- nonzeroRowIn (transpose mat) = transpose (swaprows (transpose mat) j k)
            | otherwise                              = mat

        -- First row k /= j with a nonzero entry in column j, if there is one.
        nonzeroRowIn mat
            | isColZeroExcept mat j j = Nothing
            | otherwise = find (\k -> mtx_elem mat k j /= 0) [ k | k <- row_list mat, k /= j ]

        -- Alternately clear column j and (through a transpose) row j until
        -- both are clear. `flipped` records whether `mat` is currently the
        -- transpose of the input's orientation, so it can be undone at the end.
        reduce flipped mat
            | isColZeroExcept mat j j && isColZeroExcept (transpose mat) j j =
                if flipped then transpose mat else mat
            | otherwise = reduce (not flipped) (transpose (eliminateCol mat j))


-- | Diagonalise the matrix by clearing row and column @j@ for each diagonal
--   position @j@ in turn. The diagonal entries are neither sorted nor forced
--   into a divisibility chain @d1 | d2 | ...@. The result is therefore not a
--   proper Smith Normal Form, but it suffices when only the nonzero diagonal
--   entries (or their product) are needed.
--
--   Only the @min rows cols@ diagonal positions are visited. Once those rows
--   and columns are cleared, every remaining entry lies in a cleared row (or
--   column) and is already zero. Iterating over every column instead would ask
--   for the entry @(j,j)@ with @j@ past the last row whenever the matrix has
--   more columns than rows.
makeDiagonalSNFLikeMatrix :: (Show int, Integral int) => SNFMatrix int -> SNFMatrix int
makeDiagonalSNFLikeMatrix m =
    foldl eliminateColRow m [0 .. min (rows m) (cols m) - 1]

-- | The absolute values of the nonzero diagonal entries of the matrix after
--   'makeDiagonalSNFLikeMatrix', in diagonal order. They are deliberately not
--   sorted into SNF order: callers only need their product, and the product
--   does not depend on the order.
structureConstantsFromMatrix :: (Show int, Integral int) => SNFMatrix int -> [int]
structureConstantsFromMatrix m =
    [ abs d | d <- diagonalEntries, d /= 0 ]
    where
        reduced         = makeDiagonalSNFLikeMatrix m
        diagonalEntries = map (\k -> mtx_elem reduced k k) [0 .. min (rows reduced) (cols reduced) - 1]

-- | The class number: the product of 'structureConstantsFromMatrix', i.e. of
--   the absolute values of the nonzero diagonal entries after
--   diagonalisation. It is 1 when there are no nonzero diagonal entries.
classNumberFromMatrix :: (Show int, Integral int) => SNFMatrix int -> int
classNumberFromMatrix = product . structureConstantsFromMatrix


-- | 3x3 sample matrix. Its determinant is 6528, so the expected class number
--   is 6528.
testData :: [[Int]]
testData =
    [ [  8,  16,  16 ]
    , [ 32,   6,  12 ]
    , [  8,  -4, -16 ]
    ]

-- | 5x7 sample matrix with some large entries.
testData2 :: [[Int]]
testData2 =
    [ [     5,       1,    5, 253,    15, -725,     1 ]
    , [   253,       2, 1001,  11,    23,  273, 14079 ]
    , [     1, -185861,  -28,  11,    91,   29, -2717 ]
    , [  -319,       1,  -19,  11,  3146,    1,    -1 ]
    , [ 19285,    -493,  145,  25, -1482,    1,  6647 ]
    ]

-- | 2x3 (non-square) sample matrix. Its invariant factors are 4 and 12
--   (gcd of entries 4, gcd of 2x2 minors 48), so the expected class number
--   is 48.
testData3 :: [[Int]]
testData3 =
    [ [ 4, 8, 4 ]
    , [ 8, 4, 8 ]
    ]

-- | Smoke test for the Smith Normal Form code. Prints the structure constants
--   and then the class number of 'testData3'. Substitute 'testData' or
--   'testData2' to exercise the other samples.
testSNF :: IO ()
testSNF = do
    let sample = matrixFromList testData3
    print (structureConstantsFromMatrix sample)
    print (classNumberFromMatrix sample)