-- | Implements Tarjan's algorithm for computing the strongly connected
-- components of a graph.  For more details see:
-- <http://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm>
{-# LANGUAGE Rank2Types, Trustworthy #-}
module Data.Graph.ArraySCC(scc) where

import Data.Graph(Graph,Vertex)
import Data.Array.ST(STUArray, newArray, readArray, writeArray)
import Data.Array as A
import Data.Array.Unsafe(unsafeFreeze)
import Control.Monad.ST
import Control.Monad(ap)

-- | Computes the strongly connected components (SCCs) of the graph in
-- O(#edges + #vertices) time.  The resulting tuple contains:
--
--   * A list of the SCCs, each paired with a unique identifier of type
--     'Int'.  The list is topologically sorted with respect to the
--     component graph: if a vertex in component @C1@ has an edge to a
--     vertex in a different component @C2@, then @C1@ appears before @C2@
--     in the list.  Identifiers are handed out in the opposite order,
--     starting from 1, so @C2@ gets a smaller identifier than @C1@.
--
--   * An O(1) mapping from vertices in the original graph to the identifier
--     of their SCC.  This mapping will raise an \"out of bounds\"
--     exception if it is applied to integers that do not correspond to
--     vertices in the input graph.
--
-- This function assumes that the adjacency lists in the original graph
-- mention only nodes that are in the graph. Violating this assumption
-- will result in an \"out of bounds\" array exception.
scc :: Graph -> ([(Int,[Vertex])], Vertex -> Int)
scc g = runST (
  do -- 0 marks a vertex as unvisited, so node numbers and SCC
     -- identifiers both start at 1 (see the initial state below).
     ixes <- newArray (bounds g) 0
     lows <- newArray (bounds g) 0
     s <- roots g ixes lows (S [] 1 [] 1) (indices g)
     -- Every vertex has now been visited and assigned to an SCC, so 'ixes'
     -- holds SCC identifiers.  It is not written after this point, which
     -- is what makes the in-place freeze safe.
     sccm <- unsafeFreeze ixes
     return (sccs s, (sccm !))
  )

-- | Common leading arguments of the traversal workers ('roots',
-- 'from_root' and 'check_adj'); the parameter @a@ is the remainder of
-- each worker's type.
--
-- Meaning of a value stored in the index array (the second argument):
--
--   *  0:  node not visited
--
--   * -ve: node is on the stack with the given number
--
--   * +ve: node belongs to the SCC with the given number
type Func s a
   = Graph                   -- the original graph
  -> STUArray s Vertex Int   -- index in DFS traversal, or SCC for vertex
  -> STUArray s Vertex Int   -- least reachable node (low-link)
  -> S                       -- traversal state
  -> a

-- | Pure state threaded through the depth-first traversal.
data S = S { stack    :: ![Vertex]
             -- ^ Traversal stack: vertices visited but not yet assigned
             -- to an SCC, most recently visited first.
           , num      :: !Int
             -- ^ Number to give the next visited node (starts at 1,
             -- because 0 in the index array means \"not visited\").
           , sccs     :: ![(Int,[Vertex])]
             -- ^ Finished SCCs with their identifiers, most recently
             -- finished first.
           , next_scc :: !Int
             -- ^ Identifier to give the next finished SCC.
           }


-- | Starts a depth-first search from each vertex of the list that has not
-- been visited yet, threading the state through in list order, and
-- returns the final state.
roots :: Func s ([Vertex] -> ST s S)
roots g ixes lows st (v:vs) =
  do i <- readArray ixes v
     -- Index 0 means unvisited; any other value means an earlier search
     -- already reached v, so it does not start a new one.
     st' <- if i == 0 then from_root g ixes lows st v else return st
     roots g ixes lows st' vs
roots _ _ _ st [] = return st


-- | Visits vertex @v@ (expected to be unvisited): numbers it, pushes it on
-- the traversal stack and explores its successors.  If @v@ turns out to be
-- the root of its SCC, the SCC is popped off the stack, prepended to
-- 'sccs', and the index of each of its vertices is overwritten with the
-- SCC's identifier.
from_root :: Func s (Vertex -> ST s S)
from_root g ixes lows s v =
  do let me = num s
     -- A negative index marks v as being on the stack; its low-link starts
     -- out as its own number.
     writeArray ixes v (negate me)
     writeArray lows v me
     newS <- check_adj g ixes lows
                       s { stack = v : stack s, num = me + 1 } v (g ! v)

     x <- readArray lows v
     -- If v reaches a node numbered lower than itself, it is not the root
     -- of its SCC: leave it on the stack for an ancestor to collect.
     if x < me then return newS else
       -- v and everything above it on the stack form v's SCC.
       case span (/= v) (stack newS) of
         (above, root : rest) ->
           do let this = root : above
                  n    = next_scc newS
              mapM_ (\i -> writeArray ixes i n) this
              return S { stack    = rest
                       , num      = num newS
                       , sccs     = (n, this) : sccs newS
                       , next_scc = n + 1
                       }
         _ -> error ("bug in scc---vertex not on the stack: " ++ show v)

-- | Processes the successors of @v@ one at a time, lowering @v@'s entry in
-- the low-link array where appropriate, and returns the state after all
-- of them have been handled.
check_adj :: Func s (Vertex -> [Vertex] -> ST s S)
check_adj g ixes lows st v (v':vs) =
  do i <- readArray ixes v'
     case compare i 0 of
       -- v' is unvisited: explore it, then take the smaller of the two
       -- low-links.
       EQ ->
         do newS <- from_root g ixes lows st v'
            new_low <- min `fmap` readArray lows v `ap` readArray lows v'
            writeArray lows v new_low
            check_adj g ixes lows newS v vs
       -- v' is on the stack with number (negate i): that number may lower
       -- v's low-link.
       LT ->
         do j <- readArray lows v
            writeArray lows v (min j (negate i))
            check_adj g ixes lows st v vs
       -- v' already belongs to a finished SCC: the edge does not affect v.
       GT -> check_adj g ixes lows st v vs
check_adj _ _ _ st _ [] = return st