py-fp/lib/typeclasses.py
2026-04-10 00:21:44 +02:00

291 lines
No EOL
7.7 KiB
Python

from __future__ import annotations
from abc import ABC, abstractmethod
class TypeVar:
    """Minimal stand-in for ``typing.TypeVar`` used by this library.

    Each instance stores its name, optional value constraints, an optional
    ``bound`` and its variance flags. Every instance is also recorded in the
    class-level ``TypeVar._registry`` keyed by name; creating another TypeVar
    with the same name replaces the earlier registry entry.
    """

    # Shared by all instances: name -> most recently created TypeVar of that name.
    _registry: dict = {}

    def __init__(self, name: str, *constraints, bound=None,
                 covariant: bool = False, contravariant: bool = False):
        """Create and register a type variable.

        Args:
            name: Identifier of the type variable; also its registry key.
            *constraints: Optional explicit types the variable may take.
            bound: Optional upper bound of the variable.
            covariant: Mark the variable as covariant.
            contravariant: Mark the variable as contravariant.

        Raises:
            ValueError: If both ``covariant`` and ``contravariant`` are true.
        """
        if covariant and contravariant:
            raise ValueError("the type var is both covariant and contravariant. impossible!!!!!!!!!!!!")
        self.name = name
        self.constraints = constraints
        self.bound = bound
        self.covariant = covariant
        self.contravariant = contravariant
        TypeVar._registry[name] = self

    def __repr__(self) -> str:
        """Return a constructor-like representation.

        Positional parts (name, then constraints) come first and keyword
        options follow. Earlier versions dropped the keyword options
        whenever constraints were present.
        """
        parts = [repr(self.name)]
        parts.extend(repr(c) for c in self.constraints)
        # ``is not None`` so a falsy-but-set bound is still reported.
        if self.bound is not None:
            parts.append(f"bound={self.bound!r}")
        if self.covariant:
            parts.append("covariant=True")
        if self.contravariant:
            parts.append("contravariant=True")
        return f"TypeVar({', '.join(parts)})"

    def __class_getitem__(cls, item):
        # Subscripting the TypeVar class itself is a no-op: TypeVar[...] is TypeVar.
        return cls
class _GenericAlias:
    """Parameterised view of a class, e.g. the result of ``Functor[A]``.

    ``__origin__`` is the unsubscripted class, ``__args__`` the tuple of type
    arguments, and ``__parameters__`` the distinct TypeVar arguments in order
    of first appearance. Calling the alias constructs an instance of the
    origin, and ``isinstance``/``issubclass`` checks are delegated to it.
    """

    __slots__ = ("__origin__", "__args__", "__parameters__")

    def __init__(self, origin, args):
        self.__origin__ = origin
        # A single subscript (``X[int]``) arrives as a bare object, not a tuple.
        self.__args__ = args if isinstance(args, tuple) else (args,)
        # dict.fromkeys de-duplicates while keeping first-seen order, so
        # ``X[A, A]`` exposes a single parameter.
        self.__parameters__ = tuple(
            dict.fromkeys(a for a in self.__args__ if isinstance(a, TypeVar))
        )

    def __repr__(self) -> str:
        def _name(a):
            return a.__name__ if hasattr(a, "__name__") else repr(a)
        return f"{self.__origin__.__name__}[{', '.join(_name(a) for a in self.__args__)}]"

    def __eq__(self, other):
        # Two aliases are equal when they subscript the same origin with the same args.
        if not isinstance(other, _GenericAlias):
            return NotImplemented
        return self.__origin__ == other.__origin__ and self.__args__ == other.__args__

    def __hash__(self):
        try:
            return hash((self.__origin__, self.__args__))
        except TypeError:
            # Unhashable args (e.g. a list): the origin alone is still a valid hash
            # because equal aliases always share the same origin.
            return hash(self.__origin__)

    def __call__(self, *args, **kwargs):
        return self.__origin__(*args, **kwargs)

    def __instancecheck__(self, instance):
        return isinstance(instance, self.__origin__)

    def __subclasscheck__(self, subclass):
        return issubclass(subclass, self.__origin__)

    def __mro_entries__(self, bases):
        """Make ``class Foo(Bar[A])`` subclass ``Bar`` rather than the alias object.

        Without this hook Python would try to use the alias instance as a base
        class and fail. If another base already is, or derives from, the
        origin, nothing is added so the MRO stays consistent.
        """
        origin = self.__origin__
        for base in bases:
            if isinstance(base, type) and issubclass(base, origin):
                return ()
        return (origin,)

    def __class_getitem__(cls, params):
        return cls

    def __getitem__(self, params):
        """Substitute this alias's TypeVar parameters with ``params``.

        ``X[int, A][str]`` yields ``X[int, str]``. An alias with no TypeVar
        parameters keeps the historical behaviour of being re-subscripted with
        ``params`` as its new arguments.

        Raises:
            TypeError: If the number of ``params`` differs from the number of
                TypeVar parameters.
        """
        params = params if isinstance(params, tuple) else (params,)
        if not self.__parameters__:
            return _GenericAlias(self.__origin__, params)
        expected = len(self.__parameters__)
        if len(params) != expected:
            raise TypeError(
                f"Wrong number of arguments for {self!r}; "
                f"expected {expected}, got {len(params)}"
            )
        mapping = dict(zip(self.__parameters__, params))
        new_args = tuple(
            mapping[a] if isinstance(a, TypeVar) else a for a in self.__args__
        )
        return _GenericAlias(self.__origin__, new_args)
class Generic:
    """Base class whose subscription ``Cls[...]`` produces a _GenericAlias over ``Cls``."""

    def __class_getitem__(cls, params):
        # Both ``Cls[X]`` and ``Cls[X, Y]`` are wrapped; the alias normalises params.
        alias = _GenericAlias(cls, params)
        return alias

    @classmethod
    def __init_subclass__(cls, **kwargs):
        # Forward any class keyword arguments unchanged to the next class in the MRO.
        super().__init_subclass__(**kwargs)
# Library-wide type variables. They are created (and so registered in
# TypeVar._registry) in the order listed here.
A, B, C, S, W, R, E, F = (
    TypeVar(_tv_name) for _tv_name in ("A", "B", "C", "S", "W", "R", "E", "F")
)
class Semigroup(ABC):
    """Abstract type with a binary ``combine`` operation.

    Subclasses are expected to make ``combine`` associative; this is not
    enforced here. ``a + b`` is shorthand for ``a.combine(b)``.
    """

    @abstractmethod
    def combine(self, other: "Semigroup") -> "Semigroup":
        """Return the combination of ``self`` followed by ``other``."""

    def __add__(self, other: "Semigroup") -> "Semigroup":
        combined = self.combine(other)
        return combined
class Monoid(Semigroup, ABC):
    """Semigroup that also provides an identity element via :meth:`empty`."""

    @classmethod
    @abstractmethod
    def empty(cls) -> "Monoid":
        """Return the identity element of this monoid."""

    @classmethod
    def concat(cls, xs) -> "Monoid":
        """Fold the iterable ``xs`` left-to-right with ``combine``.

        The fold starts from ``cls.empty()``, so an empty ``xs`` yields the
        identity element itself.
        """
        acc = cls.empty()
        for item in xs:
            acc = acc.combine(item)
        return acc
class ListMonoid(Monoid):
    """Monoid over lists: ``combine`` concatenates and ``empty`` holds ``[]``."""

    def __init__(self, values=None):
        # Always store a fresh list so the caller's iterable is never aliased.
        self.values = [] if values is None else list(values)

    def combine(self, other: "ListMonoid") -> "ListMonoid":
        """Return a new ListMonoid with ``self.values`` followed by ``other.values``.

        Raises:
            TypeError: If ``other`` is not a ListMonoid.
        """
        if isinstance(other, ListMonoid):
            merged = self.values + other.values
            return ListMonoid(merged)
        raise TypeError(f"Cannot combine ListMonoid with {type(other).__name__}")

    @classmethod
    def empty(cls) -> "ListMonoid":
        """Return an instance wrapping an empty list."""
        return cls([])

    def __repr__(self) -> str:
        return f"ListMonoid({self.values!r})"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ListMonoid):
            return False
        return self.values == other.values

    def __hash__(self) -> int:
        # Hash a tuple snapshot, since the underlying list itself is unhashable.
        snapshot = tuple(self.values)
        return hash(("ListMonoid", snapshot))

    def __iter__(self):
        return iter(self.values)

    def __len__(self) -> int:
        return len(self.values)
class StringMonoid(Monoid):
    """Monoid over strings: ``combine`` concatenates and ``empty`` holds ``""``."""

    def __init__(self, value: str = ""):
        self.value = value

    def combine(self, other: "StringMonoid") -> "StringMonoid":
        """Return a new StringMonoid with ``other.value`` appended to ``self.value``.

        Raises:
            TypeError: If ``other`` is not a StringMonoid.
        """
        if isinstance(other, StringMonoid):
            joined = self.value + other.value
            return StringMonoid(joined)
        raise TypeError(f"Cannot combine StringMonoid with {type(other).__name__}")

    @classmethod
    def empty(cls) -> "StringMonoid":
        """Return an instance wrapping the empty string."""
        return cls("")

    def __repr__(self) -> str:
        return f"StringMonoid({self.value!r})"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, StringMonoid):
            return False
        return self.value == other.value

    def __hash__(self) -> int:
        key = ("StringMonoid", self.value)
        return hash(key)
class SumMonoid(Monoid):
    """Monoid under addition: ``combine`` adds values and ``empty`` holds ``0``."""

    def __init__(self, value=0):
        self.value = value

    def combine(self, other: "SumMonoid") -> "SumMonoid":
        """Return a new SumMonoid holding ``self.value + other.value``.

        Raises:
            TypeError: If ``other`` is not a SumMonoid. This matches ListMonoid
                and StringMonoid. Previously a ProductMonoid was silently
                added, and other objects failed with an unrelated
                AttributeError.
        """
        if not isinstance(other, SumMonoid):
            raise TypeError(f"Cannot combine SumMonoid with {type(other).__name__}")
        return SumMonoid(self.value + other.value)

    @classmethod
    def empty(cls) -> "SumMonoid":
        """Return the additive identity, an instance holding ``0``."""
        return cls(0)

    def __repr__(self) -> str:
        return f"Sum({self.value})"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, SumMonoid) and self.value == other.value

    def __hash__(self) -> int:
        return hash(("Sum", self.value))
class ProductMonoid(Monoid):
    """Monoid under multiplication: ``combine`` multiplies and ``empty`` holds ``1``."""

    def __init__(self, value=1):
        self.value = value

    def combine(self, other: "ProductMonoid") -> "ProductMonoid":
        """Return a new ProductMonoid holding ``self.value * other.value``.

        Raises:
            TypeError: If ``other`` is not a ProductMonoid. This matches
                ListMonoid and StringMonoid. Previously a SumMonoid was
                silently multiplied in, and other objects failed with an
                unrelated AttributeError.
        """
        if not isinstance(other, ProductMonoid):
            raise TypeError(f"Cannot combine ProductMonoid with {type(other).__name__}")
        return ProductMonoid(self.value * other.value)

    @classmethod
    def empty(cls) -> "ProductMonoid":
        """Return the multiplicative identity, an instance holding ``1``."""
        return cls(1)

    def __repr__(self) -> str:
        return f"Product({self.value})"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, ProductMonoid) and self.value == other.value

    def __hash__(self) -> int:
        return hash(("Product", self.value))
class Functor(ABC, Generic):
    """Abstract container that can map a function over its contents.

    ``container % func`` is shorthand for ``container.fmap(func)``.
    """

    @abstractmethod
    def fmap(self, func) -> "Functor":
        """Apply ``func`` inside the container; implementations return a new Functor."""

    def __mod__(self, func) -> "Functor":
        mapped = self.fmap(func)
        return mapped
class Applicative(Functor, ABC):
    """Functor that can lift plain values (``pure``) and apply wrapped functions (``apply``).

    ``x * wrapped_func`` is shorthand for ``x.apply(wrapped_func)``.
    """

    @classmethod
    @abstractmethod
    def pure(cls, value) -> "Applicative":
        """Wrap a plain ``value`` in this applicative."""

    @abstractmethod
    def apply(self, wrapped_func: "Applicative") -> "Applicative":
        """Apply the function(s) held in ``wrapped_func`` to the value(s) in ``self``."""

    def map2(self, other: "Applicative", func) -> "Applicative":
        """Combine ``self`` and ``other`` with the two-argument function ``func``.

        ``func`` is curried and mapped over ``self``. The resulting wrapped
        one-argument function is then fed to ``other.apply``.
        """
        def curried(a):
            return lambda b: func(a, b)

        return other.apply(self.fmap(curried))

    def __mul__(self, wrapped_func: "Applicative") -> "Applicative":
        return self.apply(wrapped_func)
class Monad(Applicative, ABC):
    """Applicative with ``bind``; ``fmap`` and ``apply`` are derived from ``bind``/``pure``.

    Operators: ``m >> f`` is ``m.bind(f)`` and ``m | n`` is ``m.then(n)``.
    """

    @abstractmethod
    def bind(self, func) -> "Monad":
        """Pass the value(s) in ``self`` to ``func``, which should return a monad."""

    def fmap(self, func) -> "Monad":
        """Map ``func`` over the value by binding into ``pure`` of the same class."""
        def lift(x):
            return self.__class__.pure(func(x))

        return self.bind(lift)

    def apply(self, wrapped_func: "Applicative") -> "Monad":
        """Bind ``wrapped_func`` first, then map the extracted function over ``self``."""
        def with_func(f):
            def lift(x):
                return self.__class__.pure(f(x))

            return self.bind(lift)

        return wrapped_func.bind(with_func)

    def then(self, other: "Monad") -> "Monad":
        """Sequence ``other`` after ``self``, ignoring the value produced by ``self``."""
        def ignore(_):
            return other

        return self.bind(ignore)

    def __rshift__(self, func) -> "Monad":
        return self.bind(func)

    def __or__(self, other: "Monad") -> "Monad":
        return self.then(other)

    @classmethod
    def do(cls, gen_func) -> "Monad":
        """Run generator-based do-notation by delegating to the module-level ``_drive_do``.

        ``cls`` is not used. The generator's return value is passed through
        unchanged rather than wrapped with ``cls.pure``.
        """
        return _drive_do(gen_func)
def _drive_do(gen_func) -> "Monad":
gen = gen_func()
def step(value):
try:
next_monad = gen.send(value)
return next_monad.bind(step)
except StopIteration as exc:
return exc.value
try:
first = next(gen)
return first.bind(step)
except StopIteration as exc:
return exc.value
class MonadTransformer(ABC):
    """Abstract interface for a monad transformer wrapping an inner monad.

    Subclasses supply ``lift``, ``run``, ``bind`` and ``pure``. ``t >> f`` is
    shorthand for ``t.bind(f)``.
    """

    @classmethod
    @abstractmethod
    def lift(cls, monad: "Monad") -> "MonadTransformer":
        """Embed the inner ``monad`` into this transformer."""

    @abstractmethod
    def run(self) -> "Monad":
        """Unwrap the transformer and return the underlying monad."""

    @abstractmethod
    def bind(self, func) -> "MonadTransformer":
        """Sequence ``func`` after this computation."""

    @classmethod
    @abstractmethod
    def pure(cls, value) -> "MonadTransformer":
        """Wrap a plain ``value`` in this transformer."""

    def __rshift__(self, func) -> "MonadTransformer":
        bound = self.bind(func)
        return bound