module prelude;

import Stdlib.Trait open public;
import Stdlib.Trait.Ord open using {Ordering; mkOrd; Equal; isEqual} public;
import Stdlib.Trait.Eq open using {==} public;
import Stdlib.Debug.Fail open using {failwith};
import Stdlib.Data.Fixity open public;
import Stdlib.Function open using {
  <<;
  >>;
  const;
  id;
  flip;
  <|;
  |>;
  iterate;
  >->;
};
import Stdlib.Trait.Functor.Polymorphic as Functor;
import Stdlib.Trait.Applicative open using {Applicative; mkApplicative} public;

open Applicative public;

import Stdlib.Trait.Monad open using {Monad; mkMonad} public;

open Monad public;

--- Flattens a doubly wrapped monadic value by binding the outer layer and
--- returning each inner value unchanged.
join {M : Type -> Type} {A} {{Monad M}} (mma : M (M A)) : M A :=
  bind mma \{inner := inner};

-- Note: `join` above is defined in terms of the `bind` method of the `Monad` trait.
--- A two-argument type constructor that can be mapped over in both of its
--- type arguments at once.
trait
type Bifunctor (F : Type -> Type -> Type) :=
  mkBifunctor@{
    -- Maps the first type argument with the first function and the second
    -- type argument with the second function.
    bimap {A B C D} : (A -> C) -> (B -> D) -> F A B -> F C D;
  };

--- A two-argument type constructor whose nesting can be re-associated in
--- either direction.
trait
type AssociativeProduct (F : Type -> Type -> Type) :=
  mkAssociativeProduct@{
    -- Re-nests F A (F B C) as F (F A B) C.
    assocLeft {A B C} : F A (F B C) -> F (F A B) C;
    -- Re-nests F (F A B) C as F A (F B C).
    assocRight {A B C} : F (F A B) C -> F A (F B C);
  };

--- A two-argument type constructor whose two type arguments can be exchanged.
trait
type CommutativeProduct (F : Type -> Type -> Type) :=
  mkCommutativeProduct@{
    -- Exchanges the roles of the two type arguments.
    swap {A B} : F A B -> F B A;
  };

--- A two-argument type constructor F together with a unit type U that can be
--- added to, or removed from, either side of F.
trait
type UnitalProduct U (F : Type -> Type -> Type) :=
  mkUnitalProduct@{
    -- Places a value on the right of the unit.
    unitLeft {A} : A -> F U A;
    -- Removes the unit on the left.
    unUnitLeft {A} : F U A -> A;
    -- Places a value on the left of the unit.
    unitRight {A} : A -> F A U;
    -- Removes the unit on the right.
    unUnitRight {A} : F A U -> A;
  };

--- A container type T whose elements can be visited with effects drawn from
--- any Applicative F, producing the rebuilt container inside F.
trait
type Traversable (T : Type -> Type) :=
  mkTraversable@{
    -- Functor instance that every Traversable instance carries.
    {{functorI}} : Functor T;
    -- Foldable instance that every Traversable instance carries.
    {{foldableI}} : Polymorphic.Foldable T;
    -- Turns a container of effectful values into one effectful container.
    sequence
      : {F : Type -> Type}
        -> {A : Type}
        -> {{Applicative F}}
        -> T (F A)
        -> F (T A);
    -- Applies an effectful function to every element and collects the
    -- results into F (T B).
    traverse
      : {F : Type -> Type}
        -> {A B : Type}
        -> {{Applicative F}}
        -> (A -> F B)
        -> T A
        -> F (T B);
  };

import Stdlib.Data.Bool as Bool open using {
  Bool;
  true;
  false;
  &&;
  ||;
  not;
  or;
  and;
} public;

--- A named constant equal to true.
verdad : Bool := true;

--- Exclusive or. The result is true exactly when a and b differ.
xor (a b : Bool) : Bool :=
  case a of
    | true := not b
    | false := b;

--- Negated conjunction. The result is true unless both a and b are true.
nand (a b : Bool) : Bool := not a || not b;

--- Negated disjunction. The result is true only when both a and b are false.
nor (a b : Bool) : Bool := not a && not b;

import Stdlib.Data.Nat as Nat open using {
  Nat;
  zero;
  suc;
  natToString;
  +;
  sub;
  *;
  div;
  mod;
  ==;
  <=;
  >;
  <;
  min;
  max;
} public;

--- The natural number ten.
ten : Nat := 10;

--- Predecessor of n, saturating at zero. pred 0 is 0, and pred (suc k) is k.
pred (n : Nat) : Nat := sub n 1;

--- Converts a Boolean to a natural number. true maps to 1 and false maps to 0.
boolToNat (b : Bool) : Nat :=
  if
    | b := 1
    | else := 0;

--- Returns true exactly when n is zero.
isZero (n : Nat) : Bool :=
  case n of
    | suc _ := false
    | zero := true;

--- Returns true when n leaves no remainder on division by two.
isEven (n : Nat) : Bool :=
  case mod n 2 of
    | zero := true
    | suc _ := false;

--- Returns true when n leaves remainder one on division by two.
isOdd (n : Nat) : Bool := mod n 2 == 1;

--- Right fold over the natural numbers below n. foldNat z f n computes
--- f (n - 1) (f (n - 2) (... (f 0 z))), and returns z when n is zero.
terminating
foldNat {B} (z : B) (f : Nat -> B -> B) (n : Nat) : B :=
  case n of
    | suc prev :=
      let
        rest : B := foldNat z f prev;
      in f prev rest
    | zero := z;

--- Applies f to x exactly n times.
iter {A} (f : A -> A) (n : Nat) (x : A) : A := foldNat x (const f) n;

--- base raised to the power exponent, computed by repeated multiplication
--- starting from 1 (so any base to the power 0 is 1).
exp (base : Nat) (exponent : Nat) : Nat := iter ((*) base) exponent 1;

--- Factorial: the product of 1 through n, with factorial 0 equal to 1.
factorial : Nat -> Nat := foldNat 1 \{k acc := acc * suc k};

--- Greatest common divisor by the Euclidean algorithm. gcd a 0 is a.
terminating
gcd (a b : Nat) : Nat :=
  case b of
    | suc _ := gcd b (mod a b)
    | zero := a;

--- Least common multiple of a and b. Returns zero when either argument is zero.
lcm (a b : Nat) : Nat :=
  if
    | b == 0 := zero
    | a == 0 := zero
    | else := div (a * b) (gcd a b);

import Stdlib.Data.String as String open using {String; ++str} public;

--- A greeting string constant.
hello : String := "Hello, World!";

--- Three-way comparison of two strings. Declared as an axiom, so no Juvix
--- implementation is given here.
axiom stringCmp : String -> String -> Ordering;

--- Ord instance for String. Comparison is delegated entirely to the
--- stringCmp axiom.
instance
StringOrd : Ord String :=
  mkOrd@{
    cmp := stringCmp;
  };

--- Byte strings are represented here simply as String.
ByteString : Type := String;

--- The empty byte string, represented as the empty String.
emptyByteString : ByteString := "";

import Stdlib.Data.Unit as Unit open using {Unit; unit} public;

--- The Unit value unit, bound to a name.
unitValue : Unit := unit;

--- Discards its argument and returns unit.
trivial {A} : A -> Unit := \{_ := unit};

--- An abstract type declared as an axiom, with no constructors given here.
--- Presumably intended to be uninhabited, as its use with explode suggests.
axiom Empty : Type;

--- Eliminates a value of Empty into any type. Declared as an axiom, so it has
--- no implementation here.
axiom explode {A} : Empty -> A;

import Stdlib.Data.Pair as Pair;
open Pair using {Pair} public;
open Pair using {,};

import Stdlib.Data.Pair as Pair open using {ordProductI; eqProductI} public;
import Stdlib.Data.Fixity open;

syntax operator mkPair none;
syntax alias mkPair := ,;

--- An example pair of a natural number and a Boolean.
pair : Pair Nat Bool := mkPair 42 true;

--- Projects the first component of a pair.
fst {A B} (p : Pair A B) : A :=
  case p of
    | mkPair x _ := x;

--- Projects the second component of a pair.
snd {A B} (p : Pair A B) : B :=
  case p of
    | mkPair _ y := y;

--- Pairs are commutative: swapping exchanges the two components.
instance
PairCommutativeProduct : CommutativeProduct Pair :=
  mkCommutativeProduct@{
    swap := \{(mkPair a b) := mkPair b a};
  };

--- Nested pairs can be re-associated to the left or to the right.
instance
PairAssociativeProduct : AssociativeProduct Pair :=
  mkAssociativeProduct@{
    -- (a, (b, c)) becomes ((a, b), c).
    assocLeft := \{(mkPair a (mkPair b c)) := mkPair (mkPair a b) c};
    -- ((a, b), c) becomes (a, (b, c)).
    assocRight := \{(mkPair (mkPair a b) c) := mkPair a (mkPair b c)};
  };

--- Unit acts as a two-sided unit for Pair. A value can be paired with unit on
--- either side and recovered again by projection.
instance
PairUnitalProduct : UnitalProduct Unit Pair :=
  mkUnitalProduct@{
    unitLeft := \{a := mkPair unit a};
    unUnitLeft := snd;
    unitRight := \{a := mkPair a unit};
    -- Lambda over the implicit type argument that simply returns fst.
    unUnitRight := \{{A} := fst};
  };

--- Maps both components of a pair independently.
instance
PairBifunctor : Bifunctor Pair :=
  mkBifunctor@{
    bimap := \{f g (mkPair a b) := mkPair (f a) (g b)};
  };

--- Applies two functions to the same input and pairs their results.
fork {A B C} (f : C -> A) (g : C -> B) (c : C) : Pair A B :=
  let
    a : A := f c;
    b : B := g c;
  in mkPair a b;

import Stdlib.Data.Result.Base as Result;
open Result using {Result; ok; error} public;

syntax alias Either := Result;
syntax alias left := error;
syntax alias right := ok;

--- Example Either value built with left.
thisString : Either String Nat := left "Error!";

--- Example Either value built with right.
thisNumber : Either String Nat := right 42;

--- Returns true when e was built with left.
isLeft {A B} (e : Either A B) : Bool :=
  case e of
    | right _ := false
    | _ := true;

--- Returns true when e was built with right.
isRight {A B} (e : Either A B) : Bool :=
  case e of
    | left _ := false
    | _ := true;

--- Extracts the left payload of e, or returns d when e is a right value.
fromLeft {A B} (e : Either A B) (d : A) : A :=
  case e of
    | right _ := d
    | left value := value;

--- Extracts the right payload of e, or returns d when e is a left value.
fromRight {A B} (e : Either A B) (d : B) : B :=
  case e of
    | left _ := d
    | right value := value;

--- Exchanges the two sides of an Either, turning left into right and vice versa.
swapEither {A B} (e : Either A B) : Either B A :=
  case e of
    | right b := left b
    | left a := right a;

--- Either is commutative, with swapEither exchanging its two sides.
instance
EitherCommutativeProduct : CommutativeProduct Either :=
  mkCommutativeProduct@{
    swap := swapEither;
  };

--- Maps a left payload with f or a right payload with g, keeping the side.
eitherBimap {A B C D} (f : A -> C) (g : B -> D) (e : Either A B) : Either C D :=
  case e of
    | right b := right (g b)
    | left a := left (f a);

--- Bifunctor instance for Either, implemented by eitherBimap.
instance
EitherBifunctor : Bifunctor Either :=
  mkBifunctor@{
    bimap := eitherBimap;
  };

--- Extracts the right payload. A left value would carry an Empty, which is
--- discharged with explode.
unUnitLeftEither {A} (e : Either Empty A) : A :=
  case e of
    | right a := a
    | left impossible := explode impossible;

--- Extracts the left payload. A right value would carry an Empty, which is
--- discharged with explode.
unUnitRightEither {A} (e : Either A Empty) : A :=
  case e of
    | right impossible := explode impossible
    | left a := a;

--- Empty acts as a two-sided unit for Either. unitLeft and unitRight inject
--- with right and left respectively, and the unUnit functions discharge the
--- Empty side with explode.
instance
EitherUnitalProduct : UnitalProduct Empty Either :=
  mkUnitalProduct@{
    unitLeft := right;
    unUnitLeft := unUnitLeftEither;
    -- Lambda over the implicit type argument that returns the left injection.
    unitRight := \{{A} := left};
    unUnitRight := unUnitRightEither;
  };

--- Eliminates an Either by applying f to a left payload or g to a right payload.
fuse {A B C} (f : A -> C) (g : B -> C) (e : Either A B) : C :=
  case e of
    | right b := g b
    | left a := f a;

--- Re-associates a right-nested Either to the left.
assocLeftEither {A B C} (e : Either A (Either B C)) : Either (Either A B) C :=
  case e of
    | left a := left (left a)
    | right (left b) := left (right b)
    | right (right c) := right c;

--- Re-associates a left-nested Either to the right.
assocRightEither {A B C} (e : Either (Either A B) C) : Either A (Either B C) :=
  case e of
    | left (left a) := left a
    | left (right b) := right (left b)
    | right c := right (right c);

--- AssociativeProduct instance for Either, implemented by assocLeftEither and
--- assocRightEither.
instance
EitherAssociativeProduct : AssociativeProduct Either :=
  mkAssociativeProduct@{
    assocLeft := assocLeftEither;
    assocRight := assocRightEither;
  };

import Stdlib.Data.Maybe as Maybe;
open Maybe using {Maybe; just; nothing};

syntax alias Option := Maybe;
syntax alias some := just;
syntax alias none := nothing;

--- Returns true when x holds no value.
isNone {A} (x : Option A) : Bool :=
  case x of
    | some _ := false
    | _ := true;

--- Returns true when x holds a value.
isSome {A} (x : Option A) : Bool :=
  case x of
    | none := false
    | some _ := true;

--- Returns the value inside x, or default when x is none.
fromOption {A} (x : Option A) (default : A) : A :=
  case x of
    | some value := value
    | none := default;

--- Eliminates an Option. Applies f to a present value, or returns default.
option {A B} (o : Option A) (default : B) (f : A -> B) : B :=
  case o of
    | some value := f value
    | none := default;

--- Keeps the value in opt only if it satisfies p. Otherwise returns none.
filterOption {A} (p : A -> Bool) (opt : Option A) : Option A :=
  case opt of
    | none := none
    | some value :=
      if
        | not (p value) := none
        | else := some value;

import Stdlib.Data.List as List open using {
  List;
  nil;
  ::;
  isElement;
  head;
  tail;
  length;
  take;
  drop;
  ++;
  reverse;
  any;
  all;
  zip;
} public;

--- The list 1, 2, 3 built from explicit cons cells.
numbers : List Nat := 1 :: 2 :: 3 :: nil;

--- The same list as numbers, written with list-literal syntax.
niceNumbers : List Nat := [1; 2; 3];

--- Zero-based index of the first element that satisfies predicate, or none
--- when no element does.
findIndex {A} (predicate : A -> Bool) : List A -> Option Nat :=
  let
    -- idx is the position of the head of the list being scanned.
    go (idx : Nat) : List A -> Option Nat
      | nil := none
      | (y :: ys) :=
        if
          | predicate y := some idx
          | else := go (suc idx) ys;
  in go zero;

--- Last element of lst, or default when lst is empty.
last {A} (lst : List A) (default : A) : A :=
  case reverse lst of
    | nil := default
    | (final :: _) := final;

--- All elements of lst except the last one, in their original order.
--- Returns nil for an empty or single-element list.
most {A} (lst : List A) : List A := reverse (tail (reverse lst));

--- Appends x to the end of xs.
snoc {A} (xs : List A) (x : A) : List A := xs ++ (x :: nil);

--- Splits a non-empty list into its head and tail. Returns none for nil.
uncons {A} : List A -> Option (Pair A (List A))
  | (first :: rest) := some (mkPair first rest)
  | nil := none;

--- Splits a non-empty list into all-but-the-last elements (in original order)
--- and the last element. Returns none for the empty list.
unsnoc {A} : List A -> Option (Pair (List A) A)
  | nil := none
  | (x :: xs) :=
    -- Recurse on the tail. When the tail is empty, x itself is the last element.
    case unsnoc xs of
      | none := some (mkPair nil x)
      | some (mkPair initPart final) := some (mkPair (x :: initPart) final);

--- Builds a list from a seed. step either stops with none, or yields an element
--- and the next seed.
terminating
unfold {A B} (step : B -> Option (Pair A B)) (seed : B) : List A :=
  case step seed of
    | some (mkPair x next) := x :: unfold step next
    | none := nil;

--- Splits a list of pairs into the list of first components and the list of
--- second components, preserving order.
unzip {A B} (xs : List (Pair A B)) : Pair (List A) (List B) :=
  foldr
    \{p acc := mkPair (fst p :: fst acc) (snd p :: snd acc)}
    (mkPair nil nil)
    xs;

--- Separates a list of Either values into the left payloads and the right
--- payloads, each in original order.
partitionEither {A B} (es : List (Either A B)) : Pair (List A) (List B) :=
  let
    step (e : Either A B) (acc : Pair (List A) (List B)) : Pair (List A) (List B) :=
      case e of
        | right b := mkPair (fst acc) (b :: snd acc)
        | left a := mkPair (a :: fst acc) (snd acc);
  in foldr step (mkPair nil nil) es;

--- Classifies every element of es with f, then separates lefts from rights.
partitionEitherWith
  {A B C} (f : C -> Either A B) (es : List C) : Pair (List A) (List B) :=
  partitionEither <| map f es;

--- Keeps the present values of a list of Options, dropping every none.
catOptions {A} : List (Option A) -> List A
  | nil := nil
  | (none :: rest) := catOptions rest
  | (some x :: rest) := x :: catOptions rest;

--- Element of lst with the greatest f value, or none when lst is empty.
--- Ties keep the element nearest the end of the list, since the fold runs
--- from the right and only a strictly greater value replaces the current best.
maximumBy {A B} {{Ord B}} (f : A -> B) (lst : List A) : Option A :=
  let
    pick (candidate : A) (best : Option A) : Option A :=
      case best of
        | none := some candidate
        | some current :=
          if
            | f candidate > f current := some candidate
            | else := some current;
  in foldr pick none lst;

--- Element of lst with the smallest f value, or none when lst is empty.
--- Ties keep the element nearest the end of the list, since the fold runs
--- from the right and only a strictly smaller value replaces the current best.
minimalBy {A B} {{Ord B}} (f : A -> B) (lst : List A) : Option A :=
  let
    pick (candidate : A) (best : Option A) : Option A :=
      case best of
        | none := some candidate
        | some current :=
          if
            | f candidate < f current := some candidate
            | else := some current;
  in foldr pick none lst;

--- Traversable instance for List. Both operations recurse over the list,
--- combining each element's effect with the result for the tail via
--- liftA2 (::), and start from pure nil.
--- NOTE(review): the functorI and foldableI fields are not given explicitly,
--- presumably they are filled in by instance resolution. Confirm.
instance
traversableListI : Traversable List :=
  mkTraversable@{
    -- Turns a list of effectful values into one effectful list.
    sequence
      {F : Type -> Type}
      {A}
      {{appF : Applicative F}}
      (xs : List (F A))
      : F (List A) :=
      let
        -- Prepends the value produced by x to the list produced by acc.
        cons : F A -> F (List A) -> F (List A)
          | x acc := liftA2 (::) x acc;

        -- Recursive walk. The empty list becomes pure nil.
        go : List (F A) -> F (List A)
          | nil := pure nil
          | (x :: xs) := cons x (go xs);
      in go xs;

    -- Applies f to every element and collects the results into F (List B).
    traverse
      {F : Type -> Type}
      {A B}
      {{appF : Applicative F}}
      (f : A -> F B)
      (xs : List A)
      : F (List B) :=
      let
        -- Applies f to x and prepends its result to the list produced by acc.
        cons : A -> F (List B) -> F (List B)
          | x acc := liftA2 (::) (f x) acc;

        -- Recursive walk. The empty list becomes pure nil.
        go : List A -> F (List B)
          | nil := pure nil
          | (x :: xs) := cons x (go xs);
      in go xs;
  };

--- Splits list into consecutive chunks of chunkSize elements. The final chunk
--- may be shorter. A chunk size of zero yields nil.
terminating
chunksOf {A} : (chunkSize : Nat) -> (list : List A) -> List (List A)
  | zero _ := nil
  | _ nil := nil
  | n xs :=
    let
      chunk : List A := take n xs;
      rest : List A := drop n xs;
    in chunk :: chunksOf n rest;

--- All contiguous windows of windowSize consecutive elements of list, in order.
--- Returns nil when windowSize is zero or larger than the length of list.
sliding {A} : (windowSize : Nat) -> (list : List A) -> List (List A)
  | zero _ := nil
  | n xs :=
    let
      -- remaining is the length of the suffix being scanned. Tracking it here
      -- computes the list length once instead of once per window.
      terminating
      go (remaining : Nat) : List A -> List (List A)
        | nil := nil
        | (y :: ys) :=
          if
            | remaining < n := nil
            | else := take n (y :: ys) :: go (sub remaining 1) ys;
    in go (length xs) xs;

--- Splits a list into its longest prefix of elements satisfying p, and the
--- remainder starting at the first element that fails p.
span {A} (p : A -> Bool) : List A -> Pair (List A) (List A)
  | nil := mkPair nil nil
  | (x :: xs) :=
    if
      | not (p x) := mkPair nil (x :: xs)
      | else :=
        case span p xs of
          | mkPair matched remainder := mkPair (x :: matched) remainder;

--- Splits a list into runs of adjacent elements. Each run starts at some
--- element x and extends while eq x holds for the following elements.
terminating
groupBy {A} (eq : A -> A -> Bool) : List A -> List (List A)
  | nil := nil
  | (x :: xs) :=
    let
      split : Pair (List A) (List A) := span (eq x) xs;
    in (x :: fst split) :: groupBy eq (snd split);

--- Groups adjacent equal elements using the Eq instance.
group {A} {{Eq A}} : List A -> List (List A) := groupBy \{a b := a == b};

--- Removes duplicates detected with eq, keeping the first occurrence of each
--- element in original order. Each element is compared against every element
--- kept so far.
nubBy {A} (eq : A -> A -> Bool) : List A -> List A :=
  let
    -- kept holds the retained elements, most recently kept first.
    go : List A -> List A -> List A
      | kept nil := reverse kept
      | kept (x :: xs) :=
        if
          | any (eq x) kept := go kept xs
          | else := go (x :: kept) xs;
  in go nil;

--- Removes duplicates using the Eq instance, keeping first occurrences.
nub {A} {{Eq A}} : List A -> List A := nubBy \{a b := a == b};

--- All sublists of a list (every choice of keeping or dropping each element).
--- The sublists without the head come first, followed by those with it.
powerlists {A} : List A -> List (List A)
  | nil := [nil]
  | (x :: xs) :=
    let
      subsets : List (List A) := powerlists xs;
    in subsets ++ map \{s := x :: s} subsets;

import Stdlib.Data.Set as Set public;

open Set using {Set; difference; union; eqSetI; ordSetI; isSubset} public;

--- Set built from a list containing duplicates. Only the distinct values
--- 1, 2 and 3 remain in the set.
uniqueNumbers : Set Nat := Set.fromList [1; 2; 2; 2; 3];

--- Applies f to every element of set and collects the results into a new set.
setMap {A B} {{Ord B}} (f : A -> B) (set : Set A) : Set B :=
  Set.fromList <| map f <| Set.toList set;

--- Flattens a set of sets into the union of all of its member sets.
setJoin {A} {{Ord A}} (sets : Set (Set A)) : Set A :=
  for (merged := Set.empty) (member in sets) {
    Set.union merged member
  };

--- Computes the disjoint union of two ;Set;s.
--- Returns ok with the union when s1 and s2 share no element, and error
--- carrying their intersection otherwise.
disjointUnion {T} {{Ord T}} (s1 s2 : Set T) : Result (Set T) (Set T) :=
  let
    common : Set T := Set.intersection s1 s2;
  in case common of
    | Set.empty := ok (Set.union s1 s2)
    | _ := error common;

--- Elements that belong to exactly one of s1 and s2.
symmetricDifference {A} {{Ord A}} (s1 s2 : Set A) : Set A :=
  union (difference s1 s2) (difference s2 s1);

--- The set of all pairs whose first component comes from s1 and whose second
--- component comes from s2. Empty when either input is empty.
cartesianProduct
  {A B} {{Ord A}} {{Ord B}} (s1 : Set A) (s2 : Set B) : Set (Pair A B) :=
  -- Insert every pair directly into a single accumulator. The previous
  -- approach built a set of per-element sets, which required comparing whole
  -- sets on insertion, and then joined them.
  for (acc := Set.empty) (a in s1) {
    for (inner := acc) (b in s2) {
      Set.insert (mkPair a b) inner
    }
  };

--- The set of all subsets of s.
powerset {A} {{Ord A}} (s : Set A) : Set (Set A) :=
  Set.fromList <| map Set.fromList <| powerlists <| Set.toList s;

--- True when isSubset holds for set1 and set2 and the two sets are not equal.
isProperSubset {A} {{Eq A}} {{Ord A}} (set1 set2 : Set A) : Bool :=
  not (set1 == set2) && isSubset set1 set2;

import Stdlib.Data.Map as Map public;

open Map using {Map} public;

--- Example map from numeric codes to token name strings.
codeToken : Map Nat String := Map.fromList [1, "BTC"; 2, "ETH"; 3, "ANM"];

--- Updates the entry for k with updateFn and returns the value previously
--- stored at k (if any) together with the resulting map. If k is absent the
--- map is returned unchanged. If updateFn yields none the entry is deleted,
--- otherwise it is replaced with the new value.
updateLookupWithKey
  {Key Value}
  {{Ord Key}}
  (updateFn : Key -> Value -> Option Value)
  (k : Key)
  (map : Map Key Value)
  : Pair (Option Value) (Map Key Value) :=
  let
    previous : Option Value := Map.lookup k map;
    applyUpdate (v : Value) : Map Key Value :=
      case updateFn k v of
        | some updated := Map.insert k updated map
        | none := Map.delete k map;
    updatedMap : Map Key Value :=
      case previous of
        | some v := applyUpdate v
        | none := map;
  in mkPair previous updatedMap;

--- Applies fun to every key of map, keeping the associated values.
--- NOTE(review): when fun sends several keys to the same new key, which value
--- survives depends on Map.fromList and on the order of the list built below,
--- which prepends entries while iterating Map.toList. Confirm this is intended.
mapKeys
  {Key1 Key2 Value}
  {{Ord Key2}}
  (fun : Key1 -> Key2)
  (map : Map Key1 Value)
  : Map Key2 Value :=
  Map.fromList
    -- Each entry is prepended to the accumulator, so the list handed to
    -- Map.fromList is in the reverse of the iteration order.
    (for (acc := nil) (k, v in Map.toList map) {
      (fun k, v) :: acc
    });

--- Keeps only the entries of map whose key is a member of validKeys.
restrictKeys
  {Key Value}
  {{Ord Key}}
  (map : Map Key Value)
  (validKeys : Set.Set Key)
  : Map Key Value :=
  for (kept := Map.empty) (k, v in map) {
    if
      | not (Set.isMember k validKeys) := kept
      | else := Map.insert k v kept
  };

--- Drops every entry of map whose key is a member of invalidKeys.
withoutKeys
  {Key Value}
  {{Ord Key}}
  (map : Map Key Value)
  (invalidKeys : Set.Set Key)
  : Map Key Value :=
  for (kept := Map.empty) (k, v in map) {
    if
      | not (Set.isMember k invalidKeys) := Map.insert k v kept
      | else := kept
  };

--- Splits map into the entries whose value satisfies predicate (first
--- component) and the remaining entries (second component).
mapPartition
  {Key Value}
  {{Ord Key}}
  (predicate : Value -> Bool)
  (map : Map Key Value)
  : Pair (Map Key Value) (Map Key Value) :=
  for (accepted, rejected := Map.empty, Map.empty) (k, v in map) {
    if
      | not (predicate v) := accepted, Map.insert k v rejected
      | else := Map.insert k v accepted, rejected
  };

--- Splits map into the entries whose key and value satisfy predicate (first
--- component) and the remaining entries (second component).
partitionWithKey
  {Key Value}
  {{Ord Key}}
  (predicate : Key -> Value -> Bool)
  (map : Map Key Value)
  : Pair (Map Key Value) (Map Key Value) :=
  for (accepted, rejected := Map.empty, Map.empty) (k, v in map) {
    if
      | not (predicate k v) := accepted, Map.insert k v rejected
      | else := Map.insert k v accepted, rejected
  };

--- Applies f to every value of map. Entries where f yields none are dropped,
--- and the others keep their key with the new value.
mapOption
  {Key Value1 Value2}
  {{Ord Key}}
  (f : Value1 -> Option Value2)
  (map : Map Key Value1)
  : Map Key Value2 :=
  for (result := Map.empty) (k, v in map) {
    case f v of
      | some mapped := Map.insert k mapped result
      | none := result
  };

--- Like mapOption, but f also receives the key of each entry.
mapOptionWithKey
  {Key Value1 Value2}
  {{Ord Key}}
  (f : Key -> Value1 -> Option Value2)
  (map : Map Key Value1)
  : Map Key Value2 :=
  for (result := Map.empty) (k, v in map) {
    case f k v of
      | some mapped := Map.insert k mapped result
      | none := result
  };

--- Applies f to every value of map and splits the entries by the outcome.
--- Entries where f yields a left value go into the first map, and entries
--- where f yields a right value go into the second.
--- The type variables are named L and R so they do not shadow the imported
--- Result type that the Either alias refers to.
mapEither
  {Key Value L R}
  {{Ord Key}}
  (f : Value -> Either L R)
  (map : Map Key Value)
  : Pair (Map Key L) (Map Key R) :=
  for (lefts, rights := Map.empty, Map.empty) (k, v in map) {
    case f v of
      | left l := Map.insert k l lefts, rights
      | right r := lefts, Map.insert k r rights
  };

--- Like mapEither, but f also receives the key of each entry.
--- The type variables are named L and R so they do not shadow the imported
--- Result type that the Either alias refers to.
mapEitherWithKey
  {Key Value L R}
  {{Ord Key}}
  (f : Key -> Value -> Either L R)
  (map : Map Key Value)
  : Pair (Map Key L) (Map Key R) :=
  for (lefts, rights := Map.empty, Map.empty) (k, v in map) {
    case f k v of
      | left l := Map.insert k l lefts, rights
      | right r := lefts, Map.insert k r rights
  };

--- Placeholder inhabiting every type, declared as an axiom with no
--- implementation. Presumably evaluating it cannot succeed, so use only as a stub.
axiom undef {A} : A;

--- Marker for unfinished code. Like undef, an axiom inhabiting every type.
axiom TODO {A} : A;

import Stdlib.Data.Set.AVL as AVLTree public;

open AVLTree using {AVLTree} public;