public

Corecords.hs

owner: mangoiv, created: 04.11.2024, uuid: d204d07f-6292-4eb7-9aff-4cabc78394b9
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordUpdate #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ImplicitPrelude #-}
{-# OPTIONS_GHC -Wall -Wno-unused-imports 
    -fno-show-error-context -fprint-explicit-kinds #-}

module Corecords where

import Data.Bool (bool)
import Data.Coerce (coerce)
import Data.Kind (Constraint, Type)
import Data.SOP (NP (..), Proxy (Proxy))
import GHC.TypeLits (Symbol)

-- The machinery in this module makes the following possible:
--
-- passing named arguments to an ordinary function using record-update syntax
--
-- >>> ex
-- 65.7
--
-- Every named argument of 'exp'' is supplied, so this is @7.3 * 3 ^ 2@.
ex :: Double
ex =
  exp'
    { coefficient = 7.3
    , base = 3
    , exponent = 2
    }

-- | Example target: @coefficient * base ^ exponent@, with each factor passed
-- as a named argument. As with '(^)', the exponent must be non-negative.
exp' :: Double :# "coefficient" -> Double :# "base" -> Int :# "exponent" -> Double
exp' (Tag coefficient) (Tag base) (Tag power) = coefficient * base ^ power

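-- Named parameters may also be interleaved with unnamed ones and supplied in
-- any order. Only the named ones are filled in; everything else remains a
-- parameter of the resulting function (see 'np3').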
-- | Demo function that mixes unnamed parameters (the leading @Int@ and the
-- @String@) with three named ones.
namedPlus3 :: Int -> Int :# "arg1" -> String -> Bool :# "arg2" -> Int :# "arg3" -> Int
namedPlus3 d (Tag n1) str (Tag flag) (Tag n3) =
  n1 + bool 42 69 flag - n3 + length str + d

-- | 'namedPlus3' with @arg2@ and @arg3@ supplied. The named argument that was
-- not supplied (@arg1@) moves to the front. It is followed by the unnamed
-- @Int@ and @String@ parameters, in their original order.
np3 :: Int :# "arg1" -> Int -> [Char] -> Int
np3 = namedPlus3 {arg2 = True, arg3 = 4}

-- | A value of type @t@ tagged at the type level with a parameter name @s@.
-- Being a newtype, it has the same runtime representation as @t@.
type (:#) :: Type -> Symbol -> Type
newtype (:#) t s = Tag {unTag :: t}
  deriving stock (Eq, Ord, Show)

-- | The @(type, name)@ pair of every @t :# name@ parameter of a function type,
-- in left-to-right order. Unannotated parameters and the final result type
-- contribute nothing.
type NamedArgs :: Type -> [(Type, Symbol)]
type family NamedArgs fn where
  -- an annotated parameter: record it, then keep scanning
  NamedArgs (arg :# label -> rest) = '(arg, label) : NamedArgs rest
  -- an unannotated parameter: skip it
  NamedArgs (_ -> rest) = NamedArgs rest
  -- no longer a function: nothing left to collect
  NamedArgs _ = '[]

-- | The function type left over once every @t :# name@ parameter is dropped.
-- Unannotated parameters are kept in order, followed by the final result type.
type RestOfFunction :: Type -> Type
type family RestOfFunction fn where
  -- an annotated parameter is dropped
  RestOfFunction (_ :# _ -> rest) = RestOfFunction rest
  -- an unannotated parameter is kept
  RestOfFunction (arg -> rest) = arg -> RestOfFunction rest
  -- the final result is kept as-is
  RestOfFunction result = result

-- | Second component of a promoted pair.
type Snd :: (k1, k2) -> k2
type family Snd pair where
  Snd '(_, y) = y

-- | First component of a promoted pair.
type Fst :: (k1, k2) -> k1
type family Fst pair where
  Fst '(x, _) = x

-- | One entry of a named-argument list. It holds a value whose type is the
-- first component of the entry; the name exists only in the type.
type Named :: (Type, Symbol) -> Type
newtype Named entry = MkNamed {unNamed :: Fst entry}

-- | Function types whose named parameters can be taken all at once.
-- 'collectNamedArgs' receives every @t :# name@ argument as one 'NP' list
-- (in order) and leaves the unnamed parameters as ordinary ones.
type RewriteToCurried :: Type -> Constraint
class RewriteToCurried fn where
  collectNamedArgs :: fn -> NP Named (NamedArgs fn) -> RestOfFunction fn

-- | A named parameter: take its value from the head of the list, converting
-- 'Named' to ':#' with 'coerce', and continue with the rest of the type.
instance (RewriteToCurried rest) => RewriteToCurried (a :# label -> rest) where
  collectNamedArgs fn = \case
    arg :* args -> collectNamedArgs @rest (fn (coerce arg)) args

-- | An unnamed parameter: keep it as an ordinary parameter of the result and
-- continue with the rest of the type. The two equalities state what
-- 'RestOfFunction' and 'NamedArgs' do for a plain @a -> rest@. GHC cannot
-- reduce those families on its own while @a@ might still be a @t :# n@,
-- because that equation is tried first.
instance
  {-# OVERLAPPABLE #-}
  ( RewriteToCurried rest
  , RestOfFunction (a -> rest) ~ (a -> RestOfFunction rest)
  , NamedArgs (a -> rest) ~ NamedArgs rest
  )
  => RewriteToCurried (a -> rest)
  where
  collectNamedArgs fn named = \arg -> collectNamedArgs @rest (fn arg) named

-- | Catch-all for the final result type: nothing is left to collect. The
-- value is returned unchanged and the argument list is ignored.
instance
  {-# OVERLAPPABLE #-}
  ( -- Not needed by the method body, which discards the list. It only
    -- records that no named arguments remain once the result is reached.
    NamedArgs r ~ '[]
  , RestOfFunction r ~ r
  )
  => RewriteToCurried r
  where
  collectNamedArgs = const

-- | @SetField name a xs@: the named-argument list @xs@ has an entry called
-- @name@, and that entry's type is @a@.
--
-- The functional dependency is @name xs -> a@: the label together with the
-- list fixes the entry's type. The list alone cannot determine the label,
-- because a list with several entries satisfies one @SetField@ instance per
-- entry. For example, with @xs = '[ '(Double, "base"), '(Int, "exponent")]@,
-- both @SetField "base" Double xs@ and @SetField "exponent" Int xs@ hold.
-- A dependency @xs -> a name@ would let fundep improvement against the
-- head-match instance force @name@ to be the label at the head of @xs@. Any
-- other field could then not be set.
type SetField :: Symbol -> Type -> [(Type, Symbol)] -> Constraint
class SetField name a xs | name xs -> a where
  -- | Given a function consuming the full list @xs@ and a value for the
  -- @name@ entry, return a function consuming @xs@ with that entry removed.
  setField' :: forall r. (NP Named xs -> r) -> a -> NP Named (Removed name a xs) -> r

-- | Delete the first entry labelled @name@ from a named-argument list. A list
-- with no such entry is returned unchanged. The type argument @a@ does not
-- affect the result.
type Removed :: Symbol -> Type -> [(Type, Symbol)] -> [(Type, Symbol)]
type family Removed name a xs where
  Removed _ _ '[] = '[]
  Removed name _ ('(_, name) : rest) = rest
  Removed name a ('(t, label) : rest) = '(t, label) : Removed name a rest

-- | The head entry carries the requested name: put the supplied value there.
-- The entry's declared type determines the value's type (@a ~ b@).
instance {-# OVERLAPPING #-} (a ~ b) => SetField name a ('(b, name) : xs) where
  setField' f val = f . (MkNamed val :*)

-- | The head entry has some other name: pass it through untouched and look
-- for @name@ further down the list. The equality tells GHC that removing
-- @name@ keeps this head entry, which it cannot work out while @other@ and
-- @name@ are unknown.
instance
  ( SetField name a xs
  , Removed name a ('(b, other) : xs) ~ '(b, other) : Removed name a xs
  )
  => SetField name a ('(b, other) : xs)
  where
  setField' f val (entry :* rest) =
    setField' @name @a @xs (f . (entry :*)) val rest

-- | Turn a function over a named-argument list back into an ordinary curried
-- function. The result has one @t :# n@ parameter per entry of @xs@, in
-- order, and then returns @r@.
type ExpandArgs :: [(Type, Symbol)] -> Type -> Constraint
class ExpandArgs xs r where
  -- | The curried function type corresponding to @xs@ and @r@.
  type Expanded xs r :: Type
  expandArgs :: (NP Named xs -> r) -> Expanded xs r

-- | No entries left: apply the function to the empty list.
instance ExpandArgs '[] r where
  type Expanded '[] r = r
  expandArgs = ($ Nil)

-- | One more entry: take it as a @typ :# name@ parameter and convert it to
-- 'Named' with 'coerce'. It is then consed onto the front of the list built
-- from the remaining parameters.
instance (ExpandArgs xs r) => ExpandArgs ('(typ, name) : xs) r where
  type Expanded ('(typ, name) : xs) r = typ :# name -> Expanded xs r
  expandArgs f arg = expandArgs @xs @r (f . (coerce arg :*))

-- | Placeholder @getField@. Under 'RebindableSyntax' with
-- 'OverloadedRecordUpdate', GHC desugars record-update syntax using whichever
-- @getField@ and @setField@ are in scope, so a @getField@ must exist. Nothing
-- in this module reads a field through it. It is not meant to be evaluated;
-- if it ever is, it fails with a message that names it, rather than the
-- anonymous @Prelude.undefined@.
getField :: a
getField =
  error
    "Corecords.getField: placeholder only; reading fields via record syntax is not supported"

-- | The @setField@ that record-update syntax in this module resolves to, via
-- 'RebindableSyntax' and 'OverloadedRecordUpdate'. It lets
-- @fn { name = v }@ supply the named argument @name@ of @fn@.
--
-- The resulting function has three parts, in this order:
--
-- * the remaining named arguments, as @t :# n@ parameters in their original
--   order;
-- * @fn@'s unnamed parameters, also in their original order;
-- * @fn@'s final result.
setField
  :: forall name infn field
   . ( ExpandArgs (Removed name field (NamedArgs infn)) (RestOfFunction infn)
     , SetField name field (NamedArgs infn)
     , RewriteToCurried infn
     )
  => infn
  -> field
  -> Expanded (Removed name field (NamedArgs infn)) (RestOfFunction infn)
setField fn value = expandArgs withValue
  where
    -- all of fn's named arguments taken at once as one NP list
    uncurried = collectNamedArgs fn
    -- the same, with the entry for `name` already filled in by `value`
    withValue = setField' @name uncurried value