{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordUpdate #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ImplicitPrelude #-}
{-# OPTIONS_GHC -Wall -Wno-unused-imports
-fno-show-error-context -fprint-explicit-kinds #-}
module Corecords where
import Data.Bool (bool)
import Data.Coerce (coerce)
import Data.Kind (Constraint, Type)
import Data.SOP (NP (..), Proxy (Proxy))
import GHC.TypeLits (Symbol)
-- | Demonstrates named arguments passed to an ordinary function using
-- record-update syntax, made possible by the 'setField' machinery below.
-- The fields can be listed in any order; each one supplies the argument of
-- 'exp'' whose type is tagged with the same name.
--
-- >>> ex
-- 65.7
ex :: Double
ex = exp' {coefficient = 7.3, base = 3, exponent = 2}
-- | @coefficient * base ^^ exponent@, with every argument tagged by name so
-- that it can be supplied through record-update syntax.
--
-- Uses '(^^)' rather than '(^)': for a non-negative exponent the two agree,
-- but '(^)' fails with "Negative exponent" when the 'Int' exponent is
-- negative, whereas '(^^)' returns the reciprocal power.
exp' :: Double :# "coefficient" -> Double :# "base" -> Int :# "exponent" -> Double
exp' coefficient base power = unTag coefficient * (unTag base ^^ unTag power)
-- | Named and unnamed arguments can be interleaved. Only the arguments whose
-- types are tagged with ':#' (@arg1@, @arg2@, @arg3@) can be supplied by
-- name; the plain 'Int' and 'String' arguments stay positional.
namedPlus3 :: Int -> Int :# "arg1" -> String -> Bool :# "arg2" -> Int :# "arg3" -> Int
namedPlus3 extra tagged1 text flag tagged3 =
  unTag tagged1 + flagValue - unTag tagged3 + length text + extra
  where
    -- 69 when the arg2 flag is set, 42 otherwise
    flagValue = bool 42 69 (unTag flag)
-- | 'namedPlus3' with @arg2@ and @arg3@ filled in by name. The remaining
-- tagged argument (@arg1@) moves to the front, followed by the positional
-- arguments in their original order.
np3 :: Int :# "arg1" -> Int -> String -> Int
np3 = namedPlus3 {arg2 = True, arg3 = 4}
-- | A value of type @t@ labelled with the type-level name @s@. Function
-- arguments of this type can be supplied by name via record-update syntax.
type (:#) :: Type -> Symbol -> Type
newtype (:#) t s = Tag {unTag :: t}
  deriving stock (Eq, Ord, Show)
-- | The @'(type, name)@ pairs of a function type's tagged (':#') arguments,
-- in the order they appear. Untagged arguments are skipped, and the walk
-- stops at the first type that is not an arrow.
type NamedArgs :: Type -> [(Type, Symbol)]
type family NamedArgs f where
  -- tagged argument: record it
  NamedArgs (a :# name -> rest) = '(a, name) : NamedArgs rest
  -- untagged argument: skip it
  NamedArgs (_ -> rest) = NamedArgs rest
  -- not a function: no more arguments
  NamedArgs _ = '[]
-- | A function type with its tagged (':#') arguments removed. The untagged
-- arguments and the final result type are kept, in order.
type RestOfFunction :: Type -> Type
type family RestOfFunction f where
  -- tagged argument: drop it
  RestOfFunction (_ :# _ -> rest) = RestOfFunction rest
  -- untagged argument: keep it
  RestOfFunction (a -> rest) = a -> RestOfFunction rest
  -- not a function: this is the result type
  RestOfFunction result = result
-- | Second component of a promoted pair.
type Snd :: (a, b) -> b
type family Snd pair where
  Snd '(_, y) = y
-- | First component of a promoted pair.
type Fst :: (a, b) -> a
type family Fst pair where
  Fst '(x, _) = x
-- | The value for one tagged argument: for the pair @'(t, name)@ it holds a
-- @t@. It has the same representation as @t ':#' name@, so the two are
-- converted with 'coerce'.
type Named :: (Type, Symbol) -> Type
newtype Named pair = MkNamed {unNamed :: Fst pair}
-- | Function types whose tagged arguments can all be supplied at once as a
-- single 'NP' product, leaving the untagged arguments to be applied as usual.
class RewriteToCurried fn where
  -- | Feed @fn@ every tagged argument, in 'NamedArgs' order, and get back
  -- the function over the remaining untagged arguments ('RestOfFunction').
  collectNamedArgs :: fn -> NP Named (NamedArgs fn) -> RestOfFunction fn
-- | Tagged argument: take its value from the head of the product and apply.
instance RewriteToCurried rest => RewriteToCurried (a :# name -> rest) where
  collectNamedArgs fn (arg :* args) = collectNamedArgs applied args
    where
      -- Named '(a, name) and a :# name share a representation
      applied = fn (coerce arg)
-- | Untagged argument: it survives into the result. The product is passed on
-- unchanged, and the argument is accepted positionally afterwards.
--
-- The two equality constraints restate the 'RestOfFunction' and 'NamedArgs'
-- equations for untagged arguments. They are needed because those closed
-- families cannot reduce while @a@ is abstract: they would have to know
-- that @a@ is not a ':#' type.
instance
  {-# OVERLAPPABLE #-}
  ( RewriteToCurried rest
  , RestOfFunction (a -> rest) ~ (a -> RestOfFunction rest)
  , NamedArgs (a -> rest) ~ NamedArgs rest
  )
  => RewriteToCurried (a -> rest)
  where
  collectNamedArgs fn args = \arg -> collectNamedArgs @rest (fn arg) args
-- | Base case, used when no arrow instance applies: the type has no tagged
-- arguments, and the value is returned as-is.
instance
  {-# OVERLAPPABLE #-}
  ( -- NOTE(review): the body ignores the product, so this equality is not
    -- needed to typecheck it. It appears to serve only as a guard that the
    -- type really has no named arguments; confirm before removing.
    NamedArgs r ~ '[]
  , RestOfFunction r ~ r
  )
  => RewriteToCurried r
  where
  collectNamedArgs = const
-- | Supplying the value for the entry called @label@ in the list @fields@ of
-- named arguments.
--
-- NOTE(review): @fields@ does not really determine @label@, since a list can
-- contain several names. This dependency appears to exist so that
-- 'setField'' and 'setField' pass GHC's ambiguity check even though
-- @label@ and @val@ only appear under the 'Removed' type family and
-- AllowAmbiguousTypes is not enabled. Confirm before changing it.
type SetField :: Symbol -> Type -> [(Type, Symbol)] -> Constraint
class SetField label val fields | fields -> val label where
  -- | Given a consumer of the full product, produce a consumer of the product
  -- without the @label@ entry, which is filled with the given value.
  setField' :: forall res. (NP Named fields -> res) -> val -> NP Named (Removed label val fields) -> res
-- | Delete the first entry whose name is @name@. If no entry has that name,
-- the list is returned unchanged. The @a@ parameter is never inspected.
type Removed :: Symbol -> Type -> [(Type, Symbol)] -> [(Type, Symbol)]
type family Removed name a xs where
  Removed _ _ '[] = '[]
  Removed name _ ('(_, name) : rest) = rest
  Removed name a ('(t, other) : rest) = ('(t, other) : Removed name a rest)
-- | The head entry carries the requested name, so the value goes there.
-- The @a ~ b@ constraint means the instance is chosen on the name alone; the
-- value's type is then unified with the argument's type.
instance {-# OVERLAPPING #-} (a ~ b) => SetField name a ('(b, name) : xs) where
  setField' k value = k . (MkNamed value :*)
-- | The head entry has a different name: keep it in place and search the tail.
-- The equality constraint states the third 'Removed' equation for this case.
-- Inside the instance @name@ and @other@ are abstract, so the family cannot
-- reduce it on its own.
instance
  ( SetField name a xs
  , Removed name a ('(b, other) : xs) ~ '(b, other) : Removed name a xs
  )
  => SetField name a ('(b, other) : xs)
  where
  setField' k value (entry :* rest) =
    setField' @name @a @xs (k . (entry :*)) value rest
-- | Turn a consumer of an 'NP' of named values into a curried function. It
-- takes each value as a separate tagged (':#') argument, in list order, and
-- then yields @res@.
class ExpandArgs fields res where
  type Expanded fields res :: Type
  expandArgs :: (NP Named fields -> res) -> Expanded fields res
-- | No named arguments remain: hand the consumer the empty product.
instance ExpandArgs '[] r where
  type Expanded '[] r = r
  expandArgs = ($ Nil)
-- | Accept the first named argument as a tagged value, then expand the rest.
instance ExpandArgs xs r => ExpandArgs ('(t, name) : xs) r where
  type Expanded ('(t, name) : xs) r = t :# name -> Expanded xs r
  expandArgs k tagged = expandArgs @xs @r (\more -> k (coerce tagged :* more))
-- | Stub that must be in scope because this module combines
-- OverloadedRecordUpdate with RebindableSyntax. Field selection on these
-- pseudo-records is not implemented. If it is ever evaluated, it fails with
-- a message that explains why, instead of the uninformative
-- "Prelude.undefined".
getField :: a
getField =
  error
    "Corecords.getField: field selection is not supported; \
    \only record-update syntax (setField) is implemented"
-- | Record-update hook. Under RebindableSyntax and OverloadedRecordUpdate,
-- @fn {name = v}@ is desugared to @setField \@"name" fn v@.
--
-- @v@ is supplied as the argument of @fn@ tagged with @name@. The result is a
-- function over the remaining arguments: first the other tagged arguments
-- (in their original order), then the untagged ones ('RestOfFunction').
setField
  :: forall name infn field
   . ( ExpandArgs (Removed name field (NamedArgs infn)) (RestOfFunction infn)
     , SetField name field (NamedArgs infn)
     , RewriteToCurried infn
     )
  => infn
  -> field
  -> Expanded (Removed name field (NamedArgs infn)) (RestOfFunction infn)
setField fn value = expandArgs fillNamed
  where
    -- consumer of the product of all tagged arguments except @name@
    fillNamed = setField' @name (collectNamedArgs fn) value