module arch.node.engines.mempool_worker_behaviour;

import arch.node.engines.mempool_worker_messages open;
import arch.node.engines.mempool_worker_config open;
import arch.node.engines.mempool_worker_environment open;
import arch.node.engines.shard_messages open;
import arch.node.engines.executor_messages open;
import arch.node.engines.executor_config open;
import arch.node.engines.executor_environment open;

import prelude open;
import Stdlib.Data.Nat open;
import Stdlib.Data.List as List;
import Stdlib.Data.Set as Set;
import arch.node.types.basics open;
import arch.node.types.identities open;
import arch.node.types.messages open;
import arch.node.types.engine open;
import arch.node.types.anoma as Anoma open;

import arch.system.state.resource_machine.notes.runnable open;

--- Placeholder transaction signature. The candidate is ignored; the result
--- wraps the decimal rendering of the fingerprint as an Ed25519 signature.
sign
  (txfp : TxFingerprint)
  (_ : TransactionCandidate KVSKey KVSKey Executable)
  : Signature :=
  Signature.Ed25519Signature (natToString txfp);

--- Placeholder transaction hash: returns the fingerprint itself and ignores
--- the candidate.
hash
  (txfp : TxFingerprint)
  (_ : TransactionCandidate KVSKey KVSKey Executable)
  : Hash :=
  txfp;

-- Mempool worker actions carry no meaningful argument, so an argument is Unit.
syntax alias MempoolWorkerActionArgument := Unit;

--- List of arguments handed from a guard to its action (the guards in this
--- module always pass the empty list).
MempoolWorkerActionArguments : Type := List MempoolWorkerActionArgument;

--- Action type specialised to the mempool worker's types.
MempoolWorkerAction : Type :=
  Action MempoolWorkerLocalCfg MempoolWorkerLocalState MempoolWorkerMailboxState MempoolWorkerTimerHandle MempoolWorkerActionArguments Anoma.Msg Anoma.Cfg Anoma.Env;

--- Input record received by a mempool worker action.
MempoolWorkerActionInput : Type :=
  ActionInput MempoolWorkerLocalCfg MempoolWorkerLocalState MempoolWorkerMailboxState MempoolWorkerTimerHandle MempoolWorkerActionArguments Anoma.Msg;

--- Effect record produced by a mempool worker action.
MempoolWorkerActionEffect : Type :=
  ActionEffect MempoolWorkerLocalState MempoolWorkerMailboxState MempoolWorkerTimerHandle Anoma.Msg Anoma.Cfg Anoma.Env;

--- Executable action program for the mempool worker.
MempoolWorkerActionExec : Type :=
  ActionExec MempoolWorkerLocalCfg MempoolWorkerLocalState MempoolWorkerMailboxState MempoolWorkerTimerHandle MempoolWorkerActionArguments Anoma.Msg Anoma.Cfg Anoma.Env;

--- Handles a TransactionRequest message.
---
--- Behaviour when the trigger is a TransactionRequest:
--- - The candidate gets fingerprint gensym + 1 and is recorded in the local
---   state.
--- - A fresh executor engine is spawned for it.
--- - Each shard owning any of the candidate's read or write keys receives one
---   KVSAcquireLock request carrying only that shard's keys.
--- - A TransactionAck is sent back to the requester.
---
--- Any other trigger yields none.
transactionRequestAction
  {{rinst : Runnable KVSKey KVSDatum Executable ProgramState}}
  (input : MempoolWorkerActionInput)
  : Option MempoolWorkerActionEffect :=
  let
    env := ActionInput.env input;
    cfg := ActionInput.cfg input;
    local := EngineEnv.localState env;
    trigger := ActionInput.trigger input;
    keyToShard := MempoolWorkerLocalCfg.keyToShard (EngineCfg.cfg cfg);
  in case getEngineMsgFromTimestampedTrigger trigger of
       | some emsg :=
         case emsg of {
           | EngineMsg.mk@{
               sender := sender;
               target := _;
               mailbox := _;
               msg := Anoma.Msg.MempoolWorker (MempoolWorkerMsg.TransactionRequest request);
             } :=
             let
               fingerprint := MempoolWorkerLocalState.gensym local + 1;
               worker_id := getEngineIDFromEngineCfg cfg;
               candidate := TransactionRequest.tx request;
               txLabel := TransactionCandidate.label candidate;
               -- Computed once; shared by the executor config and the
               -- per-shard lock requests below.
               read_keys := Set.fromList (TransactionLabel.read txLabel);
               write_keys := Set.fromList (TransactionLabel.write txLabel);
               -- The fingerprint is mixed into the executor name. A name built
               -- only from the worker's identity is identical for every
               -- transaction this worker handles. All spawned executors would
               -- then share one EngineID: later entries would overwrite
               -- earlier ones in transactionEngines, and shards could not
               -- address the executors individually.
               executor_name :=
                 nameGen
                   (natToString fingerprint)
                   (nameGen "executor" (snd worker_id) worker_id)
                   worker_id;
               executor_id := mkPair none executor_name;
               executorCfg :=
                 Anoma.Cfg.CfgExecutor
                   EngineCfg.mk@{
                     node := EngineCfg.node cfg;
                     -- Copies the node id from the parent engine.
                     name := executor_name;
                     cfg :=
                       ExecutorLocalCfg.mk@{
                         timestamp := fingerprint;
                         executable :=
                           TransactionCandidate.executable candidate;
                         lazy_read_keys := Set.empty;
                         eager_read_keys := read_keys;
                         will_write_keys := write_keys;
                         may_write_keys := Set.empty;
                         worker := worker_id;
                         issuer := sender;
                         keyToShard := keyToShard;
                       };
                   };
               executorEnv :=
                 Anoma.Env.EnvExecutor
                   EngineEnv.mk@{
                     localState :=
                       ExecutorLocalState.mk@{
                         program_state := Runnable.startingState {{rinst}};
                         completed_reads := Map.empty;
                         completed_writes := Map.empty;
                       };
                     mailboxCluster := Map.empty;
                     acquaintances := Set.empty;
                     timers := [];
                   };
               newState :=
                 local@MempoolWorkerLocalState{
                   gensym := fingerprint;
                   transactions :=
                     Map.insert
                       fingerprint
                       candidate
                       (MempoolWorkerLocalState.transactions local);
                   transactionEngines :=
                     Map.insert
                       executor_id
                       fingerprint
                       (MempoolWorkerLocalState.transactionEngines local);
                 };
               newEnv := env@EngineEnv{localState := newState};
               shards :=
                 Set.toList
                   (Set.map keyToShard (Set.union read_keys write_keys));
               -- One lock request per shard, carrying only the keys that
               -- keyToShard assigns to that shard. Shards are compared as full
               -- EngineIDs (node and name), the same way the shard set above
               -- was built, so same-named shards on different nodes are not
               -- conflated.
               shardMsgs :=
                 map
                   \{shard :=
                     let
                       onShard (key : KVSKey) : Bool := keyToShard key == shard;
                       lockRequest :=
                         KVSAcquireLockMsg.mkKVSAcquireLockMsg@{
                           lazy_read_keys := Set.empty;
                           eager_read_keys := Set.filter onShard read_keys;
                           will_write_keys := Set.filter onShard write_keys;
                           may_write_keys := Set.empty;
                           worker := worker_id;
                           executor := executor_id;
                           timestamp := fingerprint;
                         };
                     in EngineMsg.mk@{
                          sender := worker_id;
                          target := shard;
                          mailbox := some 0;
                          msg :=
                            Anoma.Msg.Shard
                              (ShardMsg.KVSAcquireLock lockRequest);
                        }}
                   shards;
               ackMsg :=
                 EngineMsg.mk@{
                   sender := worker_id;
                   target := sender;
                   mailbox := some 0;
                   msg :=
                     Anoma.Msg.MempoolWorker
                       (MempoolWorkerMsg.TransactionAck
                         TransactionAck.mkTransactionAck@{
                           tx_hash := hash fingerprint candidate;
                           batch_number :=
                             MempoolWorkerLocalState.batch_number local;
                           batch_start := 0;
                           worker_id := worker_id;
                           signature := sign fingerprint candidate;
                         });
                 };
             in some
               ActionEffect.mk@{
                 env := newEnv;
                 msgs := ackMsg :: shardMsgs;
                 timers := [];
                 engines := [mkPair executorCfg executorEnv];
               }
           | _ := none
         }
       | _ := none;

--- Decides whether transaction txNum holds a lock from every shard responsible
--- for the relevant keys of tx:
--- - isWrite = true: the write keys.
--- - isWrite = false: the read keys.
---
--- Only lock messages whose timestamp equals txNum are counted. The shard that
--- granted a lock is the first component of each entry in locks.
allLocksAcquired
  (keyToShard : KVSKey -> EngineID)
  (isWrite : Bool)
  (tx : TransactionCandidate KVSKey KVSKey Executable)
  (txNum : TxFingerprint)
  (locks : List (Pair EngineID KVSLockAcquiredMsg))
  : Bool :=
  let
    txLabel := TransactionCandidate.label tx;
    relevantKeys :=
      case isWrite of
        | false := TransactionLabel.read txLabel
        | true := TransactionLabel.write txLabel;
    isForThisTx (entry : Pair EngineID KVSLockAcquiredMsg) : Bool :=
      KVSLockAcquiredMsg.timestamp (snd entry) == txNum;
    grantingShards := Set.fromList (map fst (List.filter isForThisTx locks));
    requiredShards := Set.fromList (map keyToShard relevantKeys);
  in Set.isSubset requiredShards grantingShards;

--- Walks fingerprints upward from current. It returns the last fingerprint
--- reached before either:
--- - a fingerprint missing from transactions, or
--- - a transaction that does not yet hold all locks of the requested kind
---   (write if isWrite, read otherwise).
---
--- prev is the value returned when the very first step already stops. Called
--- with current = 1 and prev = 0, the result is the largest N such that
--- transactions 1..N are all present and fully locked.
terminating
findMaxConsecutiveLocked
  (keyToShard : KVSKey -> EngineID)
  (isWrite : Bool)
  (transactions : Map
    TxFingerprint
    (TransactionCandidate KVSKey KVSKey Executable))
  (locks : List (Pair EngineID KVSLockAcquiredMsg))
  (current : TxFingerprint)
  (prev : TxFingerprint)
  : TxFingerprint :=
  case Map.lookup current transactions of
    | some tx :=
      case allLocksAcquired keyToShard isWrite tx current locks of
        | false := prev
        | true :=
          findMaxConsecutiveLocked keyToShard isWrite transactions locks (current + 1) current
    | none := prev;

--- Returns the set of shards responsible for at least one read or write key of
--- any transaction in the given map.
getAllShards
  (keyToShard : KVSKey -> EngineID)
  (transactions : Map
    TxFingerprint
    (TransactionCandidate KVSKey KVSKey Executable))
  : Set EngineID :=
  let
    shardsOf
      (tx : TransactionCandidate KVSKey KVSKey Executable) : List EngineID :=
      let
        txLabel := TransactionCandidate.label tx;
      in map
        keyToShard
        (TransactionLabel.read txLabel ++ TransactionLabel.write txLabel);
  in Set.fromList (List.concatMap shardsOf (Map.values transactions));

--- Handles a KVSLockAcquired message from a shard.
---
--- Steps:
--- 1. Prepends the (sender, lock) pair to locks_acquired.
--- 2. Recomputes, starting from fingerprint 1, the highest fingerprints up to
---    which every transaction holds all its write locks and all its read
---    locks.
--- 3. Stores those values as seen_all_writes and seen_all_reads.
--- 4. Sends both values as UpdateSeenAll messages to every shard touched by a
---    known transaction: all write updates first, then all read updates.
---
--- Any other trigger yields none.
lockAcquiredAction
  (input : MempoolWorkerActionInput) : Option MempoolWorkerActionEffect :=
  let
    env := ActionInput.env input;
    cfg := ActionInput.cfg input;
    local := EngineEnv.localState env;
    trigger := ActionInput.trigger input;
    keyToShard := MempoolWorkerLocalCfg.keyToShard (EngineCfg.cfg cfg);
    -- This worker's own identity, used as the sender of every update.
    workerId := getEngineIDFromEngineCfg cfg;
  in case getEngineMsgFromTimestampedTrigger trigger of
       | some emsg :=
         case emsg of {
           | EngineMsg.mk@{
               msg := Anoma.Msg.Shard (ShardMsg.KVSLockAcquired lockMsg);
               sender := sender;
             } :=
             let
               txs := MempoolWorkerLocalState.transactions local;
               newLocks :=
                 mkPair sender lockMsg
                   :: MempoolWorkerLocalState.locks_acquired local;
               maxConsecutiveWrite :=
                 findMaxConsecutiveLocked keyToShard true txs newLocks 1 0;
               maxConsecutiveRead :=
                 findMaxConsecutiveLocked keyToShard false txs newLocks 1 0;
               newState :=
                 local@MempoolWorkerLocalState{
                   locks_acquired := newLocks;
                   seen_all_writes := maxConsecutiveWrite;
                   seen_all_reads := maxConsecutiveRead;
                 };
               newEnv := env@EngineEnv{localState := newState};
               allShards := getAllShards keyToShard txs;
               makeUpdateMsg
                 (target : EngineID)
                 (isWrite : Bool)
                 (upTo : TxFingerprint)
                 : EngineMsg Anoma.Msg :=
                 EngineMsg.mk@{
                   sender := workerId;
                   target := target;
                   mailbox := some 0;
                   msg :=
                     Anoma.Msg.Shard
                       (ShardMsg.UpdateSeenAll
                         UpdateSeenAllMsg.mkUpdateSeenAllMsg@{
                           timestamp := upTo;
                           write := isWrite;
                         });
                 };
               writeMessages :=
                 map
                   \{shard := makeUpdateMsg shard true maxConsecutiveWrite}
                   (Set.toList allShards);
               readMessages :=
                 map
                   \{shard := makeUpdateMsg shard false maxConsecutiveRead}
                   (Set.toList allShards);
             in some
               ActionEffect.mk@{
                 env := newEnv;
                 msgs := writeMessages ++ readMessages;
                 timers := [];
                 engines := [];
               }
           | _ := none
         }
       | _ := none;

--- Handles an ExecutorFinished message.
---
--- If the sender is registered in transactionEngines, the summary is stored in
--- execution_summaries under that executor's transaction fingerprint, and no
--- messages are sent. Unknown senders and any other trigger yield none.
executorFinishedAction
  (input : MempoolWorkerActionInput) : Option MempoolWorkerActionEffect :=
  let
    env := ActionInput.env input;
    local := EngineEnv.localState env;
    knownExecutors := MempoolWorkerLocalState.transactionEngines local;
  in case getEngineMsgFromTimestampedTrigger (ActionInput.trigger input) of
       | none := none
       | some emsg :=
         case emsg of {
           | EngineMsg.mk@{
               sender := sender;
               msg := Anoma.Msg.Executor (ExecutorMsg.ExecutorFinished summary);
             } :=
             case Map.lookup sender knownExecutors of {
               | none := none
               | some txfp :=
                 let
                   summaries :=
                     Map.insert
                       txfp
                       summary
                       (MempoolWorkerLocalState.execution_summaries local);
                   updatedState :=
                     local@MempoolWorkerLocalState{execution_summaries := summaries};
                 in some
                   ActionEffect.mk@{
                     env := env@EngineEnv{localState := updatedState};
                     msgs := [];
                     timers := [];
                     engines := [];
                   }
             }
           | _ := none
         };

--- Runs transactionRequestAction as a one-element action sequence.
transactionRequestActionLabel : MempoolWorkerActionExec := ActionExec.Seq [transactionRequestAction];

--- Runs lockAcquiredAction as a one-element action sequence.
lockAcquiredActionLabel : MempoolWorkerActionExec := ActionExec.Seq [lockAcquiredAction];

--- Runs executorFinishedAction as a one-element action sequence.
executorFinishedActionLabel : MempoolWorkerActionExec := ActionExec.Seq [executorFinishedAction];

--- Guard type specialised to the mempool worker's types.
MempoolWorkerGuard : Type :=
  Guard MempoolWorkerLocalCfg MempoolWorkerLocalState MempoolWorkerMailboxState MempoolWorkerTimerHandle MempoolWorkerActionArguments Anoma.Msg Anoma.Cfg Anoma.Env;

--- Output of a mempool worker guard (selected action plus its arguments).
MempoolWorkerGuardOutput : Type :=
  GuardOutput MempoolWorkerLocalCfg MempoolWorkerLocalState MempoolWorkerMailboxState MempoolWorkerTimerHandle MempoolWorkerActionArguments Anoma.Msg Anoma.Cfg Anoma.Env;

--- Guard evaluation strategy for the mempool worker.
MempoolWorkerGuardEval : Type :=
  GuardEval MempoolWorkerLocalCfg MempoolWorkerLocalState MempoolWorkerMailboxState MempoolWorkerTimerHandle MempoolWorkerActionArguments Anoma.Msg Anoma.Cfg Anoma.Env;

--- Selects transactionRequestActionLabel when the trigger carries a
--- MempoolWorker TransactionRequest message; otherwise none.
transactionRequestGuard
  {{Runnable KVSKey KVSDatum Executable ProgramState}}
  (trigger : TimestampedTrigger MempoolWorkerTimerHandle Anoma.Msg)
  (cfg : MempoolWorkerCfg)
  (env : MempoolWorkerEnv)
  : Option MempoolWorkerGuardOutput :=
  case getEngineMsgFromTimestampedTrigger trigger of
    | none := none
    | some emsg :=
      case EngineMsg.msg emsg of {
        | Anoma.Msg.MempoolWorker (MempoolWorkerMsg.TransactionRequest _) :=
          some GuardOutput.mk@{action := transactionRequestActionLabel; args := []}
        | _ := none
      };

--- Selects lockAcquiredActionLabel when the trigger carries a Shard
--- KVSLockAcquired message; otherwise none.
lockAcquiredGuard
  (trigger : TimestampedTrigger MempoolWorkerTimerHandle Anoma.Msg)
  (cfg : MempoolWorkerCfg)
  (env : MempoolWorkerEnv)
  : Option MempoolWorkerGuardOutput :=
  case getEngineMsgFromTimestampedTrigger trigger of
    | none := none
    | some emsg :=
      case EngineMsg.msg emsg of {
        | Anoma.Msg.Shard (ShardMsg.KVSLockAcquired _) :=
          some GuardOutput.mk@{action := lockAcquiredActionLabel; args := []}
        | _ := none
      };

--- Selects executorFinishedActionLabel when the trigger carries an Executor
--- ExecutorFinished message; otherwise none.
executorFinishedGuard
  (trigger : TimestampedTrigger MempoolWorkerTimerHandle Anoma.Msg)
  (cfg : MempoolWorkerCfg)
  (env : MempoolWorkerEnv)
  : Option MempoolWorkerGuardOutput :=
  case getEngineMsgFromTimestampedTrigger trigger of
    | none := none
    | some emsg :=
      case EngineMsg.msg emsg of {
        | Anoma.Msg.Executor (ExecutorMsg.ExecutorFinished _) :=
          some GuardOutput.mk@{action := executorFinishedActionLabel; args := []}
        | _ := none
      };

--- Engine behaviour type specialised to the mempool worker's types.
MempoolWorkerBehaviour : Type :=
  EngineBehaviour MempoolWorkerLocalCfg MempoolWorkerLocalState MempoolWorkerMailboxState MempoolWorkerTimerHandle MempoolWorkerActionArguments Anoma.Msg Anoma.Cfg Anoma.Env;

--- The mempool worker's behaviour. Guards are listed in this order: transaction
--- requests, lock acquisitions, executor completions. They are combined with
--- GuardEval.First, presumably meaning the first guard returning some decides
--- the action (confirm against GuardEval).
mempoolWorkerBehaviour : MempoolWorkerBehaviour :=
  let
    orderedGuards := [transactionRequestGuard; lockAcquiredGuard; executorFinishedGuard];
  in EngineBehaviour.mk@{guards := GuardEval.First orderedGuards};