291 lines
No EOL
7.7 KiB
Python
291 lines
No EOL
7.7 KiB
Python
from __future__ import annotations
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
|
class TypeVar:
    """Lightweight stand-in for ``typing.TypeVar``.

    Every instance is recorded in the class-level ``_registry`` under its
    name; creating another TypeVar with the same name replaces the earlier
    registry entry.
    """

    # Maps a TypeVar name to the most recently created TypeVar with that name.
    _registry: dict = {}

    def __init__(self, name: str, *constraints, bound=None,
                 covariant: bool = False, contravariant: bool = False):
        """Create and register a type variable.

        Args:
            name: Display name; also the key under which it is registered.
            *constraints: Optional explicit set of allowed types.
            bound: Optional upper bound.
            covariant: Mark the variable as covariant.
            contravariant: Mark the variable as contravariant.

        Raises:
            ValueError: If both ``covariant`` and ``contravariant`` are set.
        """
        if covariant and contravariant:
            raise ValueError("the type var is both covariant and contravariant. impossible!!!!!!!!!!!!")
        self.name = name
        self.constraints = constraints
        self.bound = bound
        self.covariant = covariant
        self.contravariant = contravariant
        TypeVar._registry[name] = self

    def __repr__(self) -> str:
        """Return a constructor-like representation.

        Constraints are listed first, followed by ``bound`` and any variance
        flag, so no configured attribute is omitted (previously constraints
        caused bound/variance to be dropped).
        """
        parts = [repr(self.name)]
        parts.extend(repr(c) for c in self.constraints)
        # ``is not None`` so a falsy-but-set bound (e.g. 0) is still shown.
        if self.bound is not None:
            parts.append(f"bound={self.bound!r}")
        if self.covariant:
            parts.append("covariant=True")
        if self.contravariant:
            parts.append("contravariant=True")
        return f"TypeVar({', '.join(parts)})"

    def __class_getitem__(cls, item):
        """Accept ``TypeVar[...]``; the subscript is ignored and the class returned."""
        return cls
|
|
|
|
|
|
class _GenericAlias:
|
|
__slots__ = ("__origin__", "__args__", "__parameters__")
|
|
|
|
def __init__(self, origin, args):
|
|
self.__origin__ = origin
|
|
self.__args__ = args if isinstance(args, tuple) else (args,)
|
|
self.__parameters__ = tuple(a for a in self.__args__ if isinstance(a, TypeVar))
|
|
|
|
def __repr__(self) -> str:
|
|
def _name(a):
|
|
return a.__name__ if hasattr(a, "__name__") else repr(a)
|
|
return f"{self.__origin__.__name__}[{', '.join(_name(a) for a in self.__args__)}]"
|
|
|
|
def __call__(self, *args, **kwargs):
|
|
return self.__origin__(*args, **kwargs)
|
|
|
|
def __instancecheck__(self, instance):
|
|
return isinstance(instance, self.__origin__)
|
|
|
|
def __subclasscheck__(self, subclass):
|
|
return issubclass(subclass, self.__origin__)
|
|
|
|
def __class_getitem__(cls, params):
|
|
return cls
|
|
|
|
def __getitem__(self, params):
|
|
return _GenericAlias(self.__origin__, params)
|
|
|
|
|
|
class Generic:
    """Base class that makes subclasses subscriptable (``Cls[...]``).

    Subscripting returns a ``_GenericAlias`` recording the class and the
    subscript arguments.
    """

    def __class_getitem__(cls, params):
        alias = _GenericAlias(cls, params)
        return alias

    def __init_subclass__(cls, **kwargs):
        # No bookkeeping of its own; keyword arguments go further up the MRO.
        super().__init_subclass__(**kwargs)
|
|
|
|
|
|
# Module-level type variables used by the abstractions below. They are
# created (and therefore registered in TypeVar._registry) in the order
# A, B, C, S, W, R, E, F.
A, B, C, S, W, R, E, F = (TypeVar(_letter) for _letter in "ABCSWREF")
|
|
|
|
|
|
class Semigroup(ABC):
    """A type with a binary ``combine`` operation.

    Subclasses implement :meth:`combine`; ``a + b`` is sugar for
    ``a.combine(b)``.
    """

    @abstractmethod
    def combine(self, other: "Semigroup") -> "Semigroup":
        """Return the result of combining ``self`` with ``other``."""

    def __add__(self, other: "Semigroup") -> "Semigroup":
        """Delegate ``self + other`` to :meth:`combine`."""
        combined = self.combine(other)
        return combined
|
|
|
|
|
|
class Monoid(Semigroup, ABC):
    """A semigroup with an identity element supplied by :meth:`empty`."""

    @classmethod
    @abstractmethod
    def empty(cls) -> "Monoid":
        """Return the identity element of this monoid."""

    @classmethod
    def concat(cls, xs) -> "Monoid":
        """Left-fold ``xs`` with :meth:`combine`, starting from ``cls.empty()``.

        An empty iterable yields ``cls.empty()`` unchanged.
        """
        accumulated = cls.empty()
        for element in xs:
            accumulated = accumulated.combine(element)
        return accumulated
|
|
|
|
|
|
class ListMonoid(Monoid):
    """Monoid over lists: ``combine`` concatenates and ``empty`` is ``[]``."""

    def __init__(self, values=None):
        # Copy the input so later mutation of the caller's iterable has no effect.
        self.values = [] if values is None else list(values)

    def combine(self, other: "ListMonoid") -> "ListMonoid":
        """Return a new ListMonoid holding ``self.values`` then ``other.values``.

        Raises:
            TypeError: If ``other`` is not a ListMonoid.
        """
        if isinstance(other, ListMonoid):
            return ListMonoid(self.values + other.values)
        raise TypeError(f"Cannot combine ListMonoid with {type(other).__name__}")

    @classmethod
    def empty(cls) -> "ListMonoid":
        """Return an instance wrapping an empty list."""
        return cls([])

    def __repr__(self) -> str:
        return f"ListMonoid({self.values!r})"

    def __eq__(self, other: object) -> bool:
        """Equal only to another ListMonoid with equal ``values``."""
        if not isinstance(other, ListMonoid):
            return False
        return self.values == other.values

    def __hash__(self) -> int:
        # NOTE(review): hashes a snapshot of the public, mutable ``values``
        # list; mutating it afterwards changes the hash.
        return hash(("ListMonoid", tuple(self.values)))

    def __iter__(self):
        """Iterate over the wrapped values."""
        return iter(self.values)

    def __len__(self) -> int:
        """Number of wrapped values."""
        return len(self.values)
|
|
|
|
|
|
class StringMonoid(Monoid):
    """Monoid over strings: ``combine`` concatenates and ``empty`` is ``""``."""

    def __init__(self, value: str = ""):
        self.value = value

    def combine(self, other: "StringMonoid") -> "StringMonoid":
        """Return a new StringMonoid holding ``self.value + other.value``.

        Raises:
            TypeError: If ``other`` is not a StringMonoid.
        """
        if isinstance(other, StringMonoid):
            return StringMonoid(self.value + other.value)
        raise TypeError(f"Cannot combine StringMonoid with {type(other).__name__}")

    @classmethod
    def empty(cls) -> "StringMonoid":
        """Return an instance wrapping the empty string."""
        return cls("")

    def __repr__(self) -> str:
        return f"StringMonoid({self.value!r})"

    def __eq__(self, other: object) -> bool:
        """Equal only to another StringMonoid with an equal ``value``."""
        if not isinstance(other, StringMonoid):
            return False
        return self.value == other.value

    def __hash__(self) -> int:
        return hash(("StringMonoid", self.value))
|
|
|
|
|
|
class SumMonoid(Monoid):
    """Monoid under addition; the identity element is ``0``."""

    def __init__(self, value=0):
        self.value = value

    def combine(self, other: "SumMonoid") -> "SumMonoid":
        """Return ``SumMonoid(self.value + other.value)``.

        Raises:
            TypeError: If ``other`` is not a SumMonoid.
        """
        # Type-check like ListMonoid/StringMonoid do. Without it, any object
        # with a ``.value`` (e.g. a ProductMonoid) was silently added in, and
        # non-monoids surfaced as an AttributeError instead of a TypeError.
        if not isinstance(other, SumMonoid):
            raise TypeError(f"Cannot combine SumMonoid with {type(other).__name__}")
        return SumMonoid(self.value + other.value)

    @classmethod
    def empty(cls) -> "SumMonoid":
        """Return the additive identity, ``cls(0)``."""
        return cls(0)

    def __repr__(self) -> str:
        return f"Sum({self.value})"

    def __eq__(self, other: object) -> bool:
        """Equal only to another SumMonoid with an equal ``value``."""
        return isinstance(other, SumMonoid) and self.value == other.value

    def __hash__(self) -> int:
        return hash(("Sum", self.value))
|
|
|
|
|
|
class ProductMonoid(Monoid):
    """Monoid under multiplication; the identity element is ``1``."""

    def __init__(self, value=1):
        self.value = value

    def combine(self, other: "ProductMonoid") -> "ProductMonoid":
        """Return ``ProductMonoid(self.value * other.value)``.

        Raises:
            TypeError: If ``other`` is not a ProductMonoid.
        """
        # Type-check like ListMonoid/StringMonoid do. Without it, any object
        # with a ``.value`` (e.g. a SumMonoid) was silently multiplied in, and
        # non-monoids surfaced as an AttributeError instead of a TypeError.
        if not isinstance(other, ProductMonoid):
            raise TypeError(f"Cannot combine ProductMonoid with {type(other).__name__}")
        return ProductMonoid(self.value * other.value)

    @classmethod
    def empty(cls) -> "ProductMonoid":
        """Return the multiplicative identity, ``cls(1)``."""
        return cls(1)

    def __repr__(self) -> str:
        return f"Product({self.value})"

    def __eq__(self, other: object) -> bool:
        """Equal only to another ProductMonoid with an equal ``value``."""
        return isinstance(other, ProductMonoid) and self.value == other.value

    def __hash__(self) -> int:
        return hash(("Product", self.value))
|
|
|
|
|
|
class Functor(ABC, Generic):
    """A container whose contents can be mapped over with :meth:`fmap`.

    ``functor % func`` is shorthand for ``functor.fmap(func)``.
    """

    @abstractmethod
    def fmap(self, func) -> "Functor":
        """Map ``func`` over the contents; implemented by subclasses."""

    def __mod__(self, func) -> "Functor":
        mapped = self.fmap(func)
        return mapped
|
|
|
|
|
|
class Applicative(Functor, ABC):
    """A functor that can wrap plain values and apply wrapped functions.

    Subclasses provide :meth:`pure` and :meth:`apply`; ``x * wf`` is
    shorthand for ``x.apply(wf)``.
    """

    @classmethod
    @abstractmethod
    def pure(cls, value) -> "Applicative":
        """Wrap ``value`` in this applicative."""

    @abstractmethod
    def apply(self, wrapped_func: "Applicative") -> "Applicative":
        """Apply the function(s) held in ``wrapped_func`` to the value(s) in ``self``."""

    def map2(self, other: "Applicative", func) -> "Applicative":
        """Combine ``self`` and ``other`` with the two-argument ``func``.

        ``func`` is curried and mapped over ``self``; the resulting wrapped
        one-argument functions are then applied to ``other``.
        """
        def curried(first):
            def with_second(second):
                return func(first, second)
            return with_second

        return other.apply(self.fmap(curried))

    def __mul__(self, wrapped_func: "Applicative") -> "Applicative":
        return self.apply(wrapped_func)
|
|
|
|
|
|
class Monad(Applicative, ABC):
    """An applicative with sequencing via :meth:`bind`.

    Subclasses implement ``bind`` and ``pure``; ``fmap`` and ``apply`` are
    derived from them here. Operators: ``m >> f`` is ``m.bind(f)`` and
    ``m | n`` is ``m.then(n)``.
    """

    @abstractmethod
    def bind(self, func) -> "Monad":
        """Feed the contained value(s) to ``func``, which returns a monad."""

    def fmap(self, func) -> "Monad":
        """Map ``func`` over the contents by binding and re-wrapping with ``pure``."""
        def rewrap(value):
            return self.__class__.pure(func(value))

        return self.bind(rewrap)

    def apply(self, wrapped_func: "Applicative") -> "Monad":
        """Bind the function(s) out of ``wrapped_func``, then the value(s) out of ``self``."""
        def with_function(fn):
            def with_value(value):
                return self.__class__.pure(fn(value))

            return self.bind(with_value)

        return wrapped_func.bind(with_function)

    def then(self, other: "Monad") -> "Monad":
        """Sequence ``self`` before ``other``, discarding ``self``'s value."""
        def ignore(_):
            return other

        return self.bind(ignore)

    def __rshift__(self, func) -> "Monad":
        return self.bind(func)

    def __or__(self, other: "Monad") -> "Monad":
        return self.then(other)

    @classmethod
    def do(cls, gen_func) -> "Monad":
        """Run a generator-based do-block; delegates to ``_drive_do`` (``cls`` is unused)."""
        return _drive_do(gen_func)
|
|
|
|
|
|
def _drive_do(gen_func) -> "Monad":
|
|
gen = gen_func()
|
|
|
|
def step(value):
|
|
try:
|
|
next_monad = gen.send(value)
|
|
return next_monad.bind(step)
|
|
except StopIteration as exc:
|
|
return exc.value
|
|
|
|
try:
|
|
first = next(gen)
|
|
return first.bind(step)
|
|
except StopIteration as exc:
|
|
return exc.value
|
|
|
|
|
|
class MonadTransformer(ABC):
    """Interface for monad transformers.

    Per the annotations, ``lift`` wraps an inner monad and ``run`` returns
    one; ``bind``/``pure`` sequence and wrap values. ``t >> f`` is sugar for
    ``t.bind(f)``.
    """

    @classmethod
    @abstractmethod
    def lift(cls, monad: "Monad") -> "MonadTransformer":
        """Wrap an inner ``monad`` in this transformer."""

    @classmethod
    @abstractmethod
    def pure(cls, value) -> "MonadTransformer":
        """Wrap a plain ``value`` in this transformer."""

    @abstractmethod
    def run(self) -> "Monad":
        """Unwrap the transformer, producing the underlying monad."""

    @abstractmethod
    def bind(self, func) -> "MonadTransformer":
        """Feed the contained value(s) to ``func``, which returns a transformer."""

    def __rshift__(self, func) -> "MonadTransformer":
        chained = self.bind(func)
        return chained