{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleInstances #-}


-- Three implementations of the QRAM model: two as state monads, and
-- one merely logging what would have to be done.
-- 
-- The QRAM is abstracted away with the class QC.

module QRAM where

import Control.Applicative (Applicative (..))
import Data.Complex
import qualified Data.IntMap as IntMap
import Data.List (foldl')


-- Vectors are represented as plain lists, so lists get a pointwise Num
-- instance.  Binary operations are zip-based: the result is as long as the
-- shorter operand.  'fromInteger' produces a one-element list.
-- NOTE(review): this is an orphan instance on a Prelude type; a newtype
-- around the vector would avoid leaking it to importers.
instance (Num a) => Num [a] where
    xs + ys       = zipWith (+) xs ys
    xs * ys       = zipWith (*) xs ys
    negate        = map negate
    abs           = map abs
    signum        = map signum
    fromInteger k = [fromInteger k]


-- | Scalar multiplication: scale every component of a vector by @k@.
(*.) :: Num a => a -> [a] -> [a]
k *. v = map (k *) v


-- | Tensor (Kronecker) product of two vectors: each component of the first
-- vector scales a full copy of the second, and the copies are laid out in
-- order.
(***) :: Num a => [a] -> [a] -> [a]
xs *** ys = concatMap (*. ys) xs

-- | Squared Euclidean norm of a complex vector: the sum of |z|^2 over its
-- components.  For a state vector this is the total probability mass.
--
-- A strict left fold is used: the lazy 'foldl' would build a chain of
-- thunks as long as the vector (2^qubits elements) before evaluating any of
-- it.  The additions are performed in the same order as before.
sqnorm :: RealFloat a => [Complex a] -> a
sqnorm = foldl' (\acc (x :+ y) -> acc + x*x + y*y) 0

-- | Number of qubits described by a state vector: the floor of log2 of its
-- length, obtained by halving the length until one element is left.  The
-- empty list gives -1.
qlength :: [a] -> Int
qlength = halvings . length
  where
    halvings :: Int -> Int
    halvings 0 = -1
    halvings 1 = 0
    halvings k = 1 + halvings (k `div` 2)

-- | An abstract quantum memory @m@ whose probabilities have type @a@ (the
-- functional dependency makes @a@ determined by @m@).  In the list
-- instance, qubits are addressed by wire numbers starting at 1.
class RealFloat a => Qmem m a | m -> a where
    -- | The memory holding no qubit.
    qmem_empty :: m
    -- | Add a qubit in basis state |0> (False) or |1> (True).  The 'Int'
    -- result is meant to identify the new qubit.
    -- NOTE(review): confirm the exact meaning against each instance.
    qmem_new :: Bool -> m -> (m, Int)
    -- | Measure a wire, giving
    -- (probability of False, state if False, probability of True, state if True).
    -- In the list instance, 'qmem_meas' keeps the measured qubit while
    -- 'qmem_meas_strict' removes it from the memory.
    qmem_meas, qmem_meas_strict :: Int -> m -> (a, m, a, m)
    -- | Controlled NOT on two distinct wires (the first one is the control
    -- in the list instance).
    qmem_CNOT :: Int -> Int -> m -> m
    -- | Single-qubit gates: Hadamard, and the NOT / Id2 / NOT2 gates used by
    -- the teleportation algorithm.
    qmem_had, qmem_NOT, qmem_Id2, qmem_NOT2 :: Int -> m -> m



-- A quantum memory is just a list of 2^n complex numbers. The index gives the base in lexicographic order:
-- |00000> is the very first element, |11111> the very last one.
-- Haskell booleans encode the basis states as follows:
--   |0> == False
--   |1> == True


-- | State-vector implementation of 'Qmem'.  A register of k qubits is a list
-- of 2^k amplitudes.  Qubit 1 is the most significant bit of the amplitude
-- index (the first half of the list is "qubit 1 is |0>").  Wire numbers start
-- at 1; a wire number below 1 is an error.
instance RealFloat a => Qmem [Complex a] a where
    -- No qubit: a single amplitude equal to 1.
    qmem_empty = [1:+0]

    -- Append a qubit in state |0> (False) or |1> (True) as the new least
    -- significant qubit: each amplitude x becomes [x, 0] or [0, x].
    -- The Int returned is the wire number of the new qubit, i.e. the new
    -- number of qubits (qlength + 1).  This is the numbering qc_new uses.
    -- Previously this returned the old vector length, which is only equal
    -- to that number for the first two qubits.
    qmem_new b l = (concatMap place l, qlength l + 1)
      where
        place x
          | b         = [0:+0, x]
          | otherwise = [x, 0:+0]

    -- Non-destructive measurement of wire n.  Returns the probability of
    -- reading False, the state projected on False (same size, the True part
    -- zeroed), and likewise for True.  Projected states are not renormalised.
    qmem_meas n m
      | n < 1     = error ("no qbit number " ++ show n)
      | n == 1    = (sqnorm lo, lo ++ (0 *. hi), sqnorm hi, (0 *. lo) ++ hi)
      | otherwise =
          let (pF0, sF0, pT0, sT0) = qmem_meas (n-1) lo
              (pF1, sF1, pT1, sT1) = qmem_meas (n-1) hi
          in (pF0 + pF1, sF0 ++ sF1, pT0 + pT1, sT0 ++ sT1)
      where
        -- lo: the leading qubit is |0>; hi: the leading qubit is |1>
        (lo, hi) = splitAt (div (length m) 2) m

    -- Destructive measurement of wire n.  Same result shape as qmem_meas,
    -- but wire n is dropped, so each projected state has half as many
    -- amplitudes.  Wire numbers below 1 are now rejected explicitly instead
    -- of recursing forever.
    qmem_meas_strict n m
      | n < 1     = error ("no qbit number " ++ show n)
      | n == 1    = (sqnorm lo, lo, sqnorm hi, hi)
      | otherwise =
          let (pF0, sF0, pT0, sT0) = qmem_meas_strict (n-1) lo
              (pF1, sF1, pT1, sT1) = qmem_meas_strict (n-1) hi
          in (pF0 + pF1, sF0 ++ sF1, pT0 + pT1, sT0 ++ sT1)
      where
        (lo, hi) = splitAt (div (length m) 2) m

    -- Hadamard: |0> -> (|0>+|1>)/sqrt 2, |1> -> (|0>-|1>)/sqrt 2
    qmem_had n m = qmem_twobytwo (s, s, s, -s) n m
      where s = 1 / sqrt 2
    -- control n1, target n2
    qmem_CNOT n1 n2 mat = qmem_fourbyfour ((1,0,0,1),(0,0,0,0),(0,0,0,0),(0,1,1,0)) n1 n2 mat
    qmem_NOT n m = qmem_twobytwo (0,1,1,0) n m
    qmem_Id2 n m = qmem_twobytwo (1,0,0,-1) n m
    qmem_NOT2 n m = qmem_twobytwo (0,1,-1,0) n m



-- | Apply the single-qubit gate given by the 2x2 matrix
--
-- >   ( a c )
-- >   ( b d )
--
-- to wire @n@ of a state vector.  Wire 1 is the most significant bit of the
-- amplitude index.  |0> is sent to a|0> + b|1> and |1> to c|0> + d|1>.
-- Wire 0 is an error.
qmem_twobytwo :: (Eq n, Num n, Num a) => (a, a, a, a) -> n -> [a] -> [a]
qmem_twobytwo (_, _, _, _) 0 _ = error "no qbit number 0"
qmem_twobytwo (a, b, c, d) 1 amps = zipWith (+) fromZero fromOne
  where
    (zeros, ones) = splitAt (div (length amps) 2) amps
    -- contribution of the |0> half, then of the |1> half
    fromZero = map (a *) zeros ++ map (b *) zeros
    fromOne  = map (c *) ones  ++ map (d *) ones
qmem_twobytwo gate n amps = applyBelow zeros ++ applyBelow ones
  where
    -- the target wire is not the leading one: act on both halves alike
    (zeros, ones) = splitAt (div (length amps) 2) amps
    applyBelow = qmem_twobytwo gate (n - 1)

-- Apply a two-qubit gate to wires n1 and n2 of a state vector.  Wire 1 is
-- the most significant bit of the index, as for qmem_twobytwo.
-- The gate is a tuple (a,b,c,d) of 2x2 matrices, each in qmem_twobytwo's
-- (x1,x2,x3,x4) layout.  Together they form the 4x4 block matrix
--
--      ( a c )
--      ( b d )
--
-- where the block is selected by wire n1 and each block acts on wire n2.
-- For example, qmem_CNOT uses ((1,0,0,1),(0,0,0,0),(0,0,0,0),(0,1,1,0)):
-- identity on n2 when n1 is |0>, NOT on n2 when n1 is |1>.
-- Equal wires, or a wire numbered 0, raise an error.  When n1 > n2 the
-- matrix is re-expressed with the two qubits exchanged (swap_fourbyfour).
-- The recursion therefore always reaches the case n1 == 1 first.
qmem_fourbyfour (a,b,c,d) 0 0 m = error "2-qbits gates waits for non-equal wires"
qmem_fourbyfour (a,b,c,d) 1 1 m = error "2-qbits gates waits for non-equal wires"
qmem_fourbyfour (a,b,c,d) 0 n m = error "2-qbits gates waits for non-zero wire"
qmem_fourbyfour (a,b,c,d) n 0 m = error "2-qbits gates waits for non-zero wire"
-- n1 is the leading qubit: split on it, apply each block to wire n-1 of the
-- matching half, and add the contributions pointwise (Num [a] instance).
qmem_fourbyfour (a,b,c,d) 1 n m =
       let (x1,x2) = splitAt (div (length m) 2) m in
       let y1 = (qmem_twobytwo a (n-1) x1) ++ (qmem_twobytwo b (n-1) x1) in
       let y2 = (qmem_twobytwo c (n-1) x2) ++ (qmem_twobytwo d (n-1) x2) in y1 + y2
qmem_fourbyfour mat n1 n2 m =
       if (n1 == n2) then  error "2-qbits gates waits for non-equal wires"
       else if (n1 > n2) then qmem_fourbyfour (swap_fourbyfour mat) n2 n1 m
       else
       -- Neither wire is the leading qubit: the gate acts identically on
       -- both halves, with both wire numbers shifted down by one.
       let (x1,x2) = splitAt (div (length m) 2) m in
       let y1 = qmem_fourbyfour mat (n1-1) (n2-1) x1 in
       let y2 = qmem_fourbyfour mat (n1-1) (n2-1) x2 in
       y1 ++ y2

{-
  Re-express a two-qubit gate with its two qubits exchanged.  Writing the
  4x4 matrix as the blocks (a c; b d) of 2x2 matrices (each in
  qmem_twobytwo's (x1,x2,x3,x4) layout), exchanging the qubits swaps the two
  middle rows and the two middle columns:

  a1 a3  c1 c3        a1 c1  a3 c3
  a2 a4  c2 c4        b1 d1  b3 d3
                 -->
  b1 b3  d1 d3        a2 c2  a4 c4
  b2 b4  d2 d4        b2 d2  b4 d4

  On the tuple-of-tuples representation this is a plain transposition.
-}
swap_fourbyfour
  :: ((p1, p2, p3, p4), (q1, q2, q3, q4), (r1, r2, r3, r4), (s1, s2, s3, s4))
  -> ((p1, q1, r1, s1), (p2, q2, r2, s2), (p3, q3, r3, s3), (p4, q4, r4, s4))
swap_fourbyfour ((w1, w2, w3, w4), (x1, x2, x3, x4), (y1, y2, y3, y4), (z1, z2, z3, z4)) =
  ( (w1, x1, y1, z1)
  , (w2, x2, y2, z2)
  , (w3, x3, y3, z3)
  , (w4, x4, y4, z4)
  )


-- | The two-qubit basis states |00>, |01>, |10> and |11>.  The inner
-- qmem_new call adds the first (most significant) qubit and the outer call
-- adds the second.
ket00, ket01, ket10, ket11 :: [Complex Double]
ket00 = fst $ qmem_new False $ fst $ qmem_new False qmem_empty
ket01 = fst $ qmem_new True  $ fst $ qmem_new False qmem_empty
ket10 = fst $ qmem_new False $ fst $ qmem_new True  qmem_empty
ket11 = fst $ qmem_new True  $ fst $ qmem_new True  qmem_empty



-- ----------------------------------------------------------------------------
-- ----------------------------------------------------------------------------
-- IMPLEMENTATIONS
--

-- | Monads with a strength: combine two computations into a single one
-- whose result is the pair of their results.
class Monad m => StrongMonad m where
    strength :: m x -> m y -> m (x, y)

-- A handle on a qubit.  In PQM the Int is the qubit's wire number (its
-- position in the state vector).  In PQMStrict it is used as a key into the
-- memory's IntMap.  FakeQM always uses 0.
data Qbit = Qbit Int
  deriving Show

-- | Quantum-computation interface over a type constructor @m@.  Operations
-- receive their qubit arguments already wrapped in @m@ and return the
-- extended computation.
class QC m where
   -- | Allocate a fresh qubit in basis state |0> (False) or |1> (True).
   qc_new  :: Bool -> m Qbit
   -- | CNOT on a pair of qubits (the first is passed to qmem_CNOT as the
   -- control in the PQM implementations).
   qc_CNOT :: m (Qbit, Qbit) -> m (Qbit, Qbit)
   -- | Measure a qubit, producing the boolean outcome.
   qc_meas :: m Qbit -> m Bool
   -- | Single-qubit gates: Hadamard, and the NOT / Id2 / NOT2 gates used by
   -- the teleportation algorithm.
   qc_had, qc_NOT, qc_Id2, qc_NOT2 :: m Qbit -> m Qbit



-- ----------------------------------------------------------------------------
-- ----------------------------------------------------------------------------
-- An implementation of a quantum machine with a non-destructive measure
-- (i.e. no garbage collection)
--

-- We never delete any qubit. New qubits are simply added at the end, and
-- measurements never remove a qubit from the state. Qubits can therefore be
-- referred to directly by their index in the array, with no extra overhead.

-- | A state vector of complex amplitudes.
type Qram = [Complex Double]

-- | The memory with no qubit: a single amplitude equal to 1.
empty_qram :: Qram
empty_qram = [1 :+ 0]

-- The state monad. The probabilistic superposition is merely a list:
-- we do not renormalize the Qram, so the actual probability of reaching
-- the corresponding branch is just the squared norm of the qram.

-- A probabilistic quantum computation: starting from a memory, it yields the
-- list of possible branches.  Each branch pairs the (unnormalised) memory it
-- reached with the value it computed.
data PQM a = PQM (Qram -> [(Qram, a)])

-- | Show a computation by running it from the empty memory and showing the
-- resulting list of (memory, result) branches.
instance (Show b) => Show (PQM b) where
  show (PQM run) = show $ run empty_qram


-- Functor and Applicative are superclasses of Monad since GHC 7.10
-- (base 4.8); without them this instance is rejected by current compilers.
-- Both follow the same branch-threading semantics as (>>=) below.
instance Functor PQM where
    fmap f (PQM run) = PQM $ \qram -> map (\(q, x) -> (q, f x)) (run qram)

instance Applicative PQM where
    pure a = PQM $ \q -> [(q, a)]
    PQM runF <*> PQM runX =
       PQM $ \qram ->
          concatMap (\(q, f) -> map (\(q', x) -> (q', f x)) (runX q)) (runF qram)

-- | Sequencing: every branch of the first computation is continued from the
-- memory it reached; the resulting branch lists are concatenated.
instance Monad PQM where
    (>>=) (PQM pqm) f =
       PQM $ \qram ->
          concatMap (\(q, x) -> let (PQM next) = f x in next q) (pqm qram)
    return = pure


-- | Pair two computations: every branch of the first one is continued with
-- the second, starting from the memory that branch reached.
instance StrongMonad PQM where
     strength (PQM runA) (PQM runB) =
       PQM $ \start ->
         concatMap (\(mid, a) -> map (\(final, b) -> (final, (a, b))) (runB mid))
                   (runA start)



-- | Amplitudes of a single qubit in a basis state: |0> for False, |1> for True.
qc_make_qbit :: Bool -> [Complex Double]
qc_make_qbit b
  | b         = [0, 1]
  | otherwise = [1, 0]


-- | 'QC' operations on 'PQM'.  A 'Qbit' holds the qubit's wire number
-- directly, since qubits are never removed from the memory.
instance QC PQM where
     -- the new qubit is appended last, so its wire number is the new count
     qc_new b =
        PQM $ \q -> [(q *** qc_make_qbit b, Qbit (qlength q + 1))]

     qc_had  = pqm_apply_gate qmem_had
     qc_NOT  = pqm_apply_gate qmem_NOT
     qc_Id2  = pqm_apply_gate qmem_Id2
     qc_NOT2 = pqm_apply_gate qmem_NOT2

     qc_CNOT (PQM run) =
        PQM $ \qram ->
             map (\(q, (Qbit c, Qbit t)) -> (qmem_CNOT c t q, (Qbit c, Qbit t))) (run qram)

     qc_meas (PQM run) = PQM $ \qram -> concatMap qc_meas_aux (run qram)

-- Apply a single-qubit gate, addressed by wire number, to every branch.
pqm_apply_gate :: (Int -> Qram -> Qram) -> PQM Qbit -> PQM Qbit
pqm_apply_gate gate (PQM run) =
     PQM $ \qram -> map (\(q, Qbit t) -> (gate t q, Qbit t)) (run qram)

        
-- Non-destructive measurement of one branch: split it into the branch where
-- the qubit reads False and the one where it reads True.  The probabilities
-- are not stored; they remain as the squared norms of the unnormalised
-- memories.
qc_meas_aux (q, Qbit t) = [(qF, False), (qT, True)]
  where
    (_, qF, _, qT) = qmem_meas t q




-- ----------------------------------------------------------------------------
-- ----------------------------------------------------------------------------
-- An implementation of a quantum machine with a destructive measure
-- We need pointers.
--

-- Here, we delete the qubit that was just measured, so a reference to a
-- qubit cannot rely on its index alone. We need a level of indirection,
-- using an IntMap from qubit IDs to their current indices. A new qubit is
-- assigned an unused ID. The first Int in the tuple is the last ID created;
-- it starts at 0 and is increased by one for each new qubit, so the first
-- ID handed out is 1.

-- | (last qubit ID created, state vector, map from qubit ID to wire number)
type QramStrict = (Int, [Complex Double], IntMap.IntMap Int)

-- | The empty memory: ID counter at 0, no qubit, empty ID map.
empty_qram_strict :: QramStrict
empty_qram_strict = (0, [1 :+ 0], IntMap.empty)

-- We follow the same policy as before with respect to the
-- probabilistic distribution.

-- Same shape as PQM, over the ID-indexed memory: from a memory, the list of
-- branches, each pairing the (unnormalised) memory reached with its value.
data PQMStrict a = PQMS (QramStrict -> [(QramStrict, a)])

-- | Show a computation by running it from the empty memory and showing the
-- resulting list of (memory, result) branches.
instance (Show b) => Show (PQMStrict b) where
  show (PQMS run) = show $ run empty_qram_strict


-- Functor and Applicative are superclasses of Monad since GHC 7.10
-- (base 4.8); without them this instance is rejected by current compilers.
-- Both follow the same branch-threading semantics as (>>=) below.
instance Functor PQMStrict where
    fmap f (PQMS run) = PQMS $ \qram -> map (\(q, x) -> (q, f x)) (run qram)

instance Applicative PQMStrict where
    pure a = PQMS $ \q -> [(q, a)]
    PQMS runF <*> PQMS runX =
       PQMS $ \qram ->
          concatMap (\(q, f) -> map (\(q', x) -> (q', f x)) (runX q)) (runF qram)

-- | Sequencing: every branch of the first computation is continued from the
-- memory it reached; the resulting branch lists are concatenated.
instance Monad PQMStrict where
    (>>=) (PQMS pqm) f =
       PQMS $ \qram ->
          concatMap (\(q, x) -> let (PQMS next) = f x in next q) (pqm qram)
    return = pure


-- | Pair a fixed value with the result of every branch, keeping each
-- branch's memory unchanged.
aux_fun :: a -> [(q,b)] -> [(q,(a,b))]
aux_fun a = map (fmap (\b -> (a, b)))

-- | Pair two computations: every branch of the first one is continued with
-- the second, starting from the memory that branch reached.
instance StrongMonad PQMStrict where
     strength (PQMS runA) (PQMS runB) =
       PQMS $ \start ->
         concatMap (\(mid, a) -> map (\(final, b) -> (final, (a, b))) (runB mid))
                   (runA start)

-- | 'QC' operations on 'PQMStrict'.  A 'Qbit' holds a qubit ID; every
-- operation translates it to the qubit's current wire number through the
-- memory's IntMap, because measurements remove wires and shift later ones.
instance QC PQMStrict where
     -- Allocate the next ID (n+1), map it to the new last wire, and return a
     -- handle on that ID.  The handle used to be the wire number itself.
     -- That number equals the ID only until the first measurement shifts the
     -- wires; afterwards the handle designated another qubit.
     qc_new b =
        PQMS $ \(n, q, m) ->
             let ident = n + 1
                 loc   = qlength q + 1
             in [((ident, q *** qc_make_qbit b, IntMap.insert ident loc m), Qbit ident)]

     qc_had  = pqms_apply_gate qmem_had
     qc_NOT  = pqms_apply_gate qmem_NOT
     qc_Id2  = pqms_apply_gate qmem_Id2
     qc_NOT2 = pqms_apply_gate qmem_NOT2

     qc_CNOT (PQMS run) =
        PQMS $ \qram ->
             let f ((n1, q1, m1), (Qbit t11, Qbit t12)) =
                   ((n1, qmem_CNOT (m1 IntMap.! t11) (m1 IntMap.! t12) q1, m1), (Qbit t11, Qbit t12))
             in map f (run qram)

     qc_meas (PQMS run) =
        PQMS $ \qram -> concatMap qc_meas_aux_strict (run qram)

-- Apply a single-qubit gate to every branch, translating the qubit ID into
-- its current wire number through the memory's IntMap.
pqms_apply_gate :: (Int -> [Complex Double] -> [Complex Double]) -> PQMStrict Qbit -> PQMStrict Qbit
pqms_apply_gate gate (PQMS run) =
     PQMS $ \qram ->
          map (\((n, q, m), Qbit t) -> ((n, gate (m IntMap.! t) q, m), Qbit t)) (run qram)


-- Destructive measurement of one branch.  Besides measuring, the qubit is
-- removed from the bookkeeping:
--   1) measure wire p = m[t] (qmem_meas_strict drops that wire),
--   2) delete ID t from the map,
--   3) decrement every wire number not below p, since the wires after p
--      moved down by one.
-- The ID counter n is left unchanged, so IDs are never reused.
qc_meas_aux_strict ((n, q, m), Qbit t) = [((n, qF, m'), False), ((n, qT, m'), True)]
  where
    wire = m IntMap.! t
    (_, qF, _, qT) = qmem_meas_strict wire q
    (kept, shifted) = IntMap.partition (< wire) (IntMap.delete t m)
    m' = IntMap.union kept (IntMap.map (subtract 1) shifted)



-- A fake quantum machine: no amplitudes are computed.  Each branch only
-- records a textual log of the operations performed, with the branch's value.
data FakeQM a = FQM [(String,a)]

-- | Show the logs of all branches, each followed by a newline; the branch
-- values themselves are not shown.
instance (Show b) => Show (FakeQM b) where
  show (FQM branches) = unlines (map fst branches)


-- Functor and Applicative are superclasses of Monad since GHC 7.10
-- (base 4.8); without them this instance is rejected by current compilers.
-- Both are consistent with (>>=) below.
instance Functor FakeQM where
    fmap f (FQM l) = FQM $ map (\(s, a) -> (s, f a)) l

instance Applicative FakeQM where
    pure a = FQM [("", a)]
    FQM lf <*> FQM lx =
       FQM $ concatMap (\(s, f) -> map (\(t, x) -> (s ++ t, f x)) lx) lf

-- | Sequencing: each branch's log is prefixed to the logs of all its
-- continuations.
instance Monad FakeQM where
    (>>=) (FQM l) f = FQM $ concatMap (\(s,a) -> let (FQM l') = f a in map (\(t,b) -> (s ++ t, b)) l') l
    return = pure


-- | Pair two computations by plain monadic sequencing.
instance StrongMonad FakeQM where
     strength x y = x >>= \a -> y >>= \b -> return (a, b)


-- | Logging implementation of 'QC'.  Every gate appends one line to the log
-- of each branch.  Measurement forks every branch into a True outcome and a
-- False outcome.  All qubits are represented by @Qbit 0@.
instance QC FakeQM where
     qc_new b = FQM [("New " ++ show b ++ "\n", Qbit 0)]
     qc_had  = fake_log "Had\n"
     qc_CNOT = fake_log "CNOT\n"
     qc_meas (FQM l) =
        FQM $ concatMap (\(s, _) -> [(s ++ "meas True\n", True), (s ++ "meas False\n", False)]) l
     qc_NOT  = fake_log "NOT\n"
     qc_Id2  = fake_log "Id2\n"
     qc_NOT2 = fake_log "NOT2\n"

-- Append a fixed line to the log of every branch, leaving the values as is.
fake_log :: String -> FakeQM a -> FakeQM a
fake_log line (FQM l) = FQM $ map (\(s, a) -> (s ++ line, a)) l
