# py-fp/lib/monads.py
# 2026-04-10 00:21:44 +02:00
#
# 649 lines
# No EOL
# 16 KiB
# Python

from __future__ import annotations
from abc import abstractmethod
from .typeclasses import (
Applicative,
Generic,
Monad,
MonadTransformer,
Monoid,
ListMonoid,
)
from .classmethod import Classmethod
class Maybe(Monad):
    """Optional value: either ``Just(value)`` or the ``Nothing`` singleton.

    Concrete behaviour lives in the ``Just`` and ``Nothing`` subclasses; this
    base class supplies constructors and combinators written purely in terms
    of the abstract primitives ``bind``, ``is_just`` and ``get_or``.
    """

    @classmethod
    def pure(cls, value) -> "Maybe":
        """Wrap *value* in ``Just`` unconditionally (``None`` included)."""
        return Just(value)

    @classmethod
    def of(cls, value) -> "Maybe":
        """Wrap *value*, mapping ``None`` to ``Nothing()``."""
        if value is None:
            return Nothing()
        return Just(value)

    @classmethod
    def from_either(cls, either: "Either") -> "Maybe":
        """Convert ``Right(x)`` to ``Just(x)`` and any ``Left`` to ``Nothing()``."""
        if not either.is_right():
            return Nothing()
        return Just(either._value)

    @abstractmethod
    def bind(self, func) -> "Maybe":
        """Chain a Maybe-returning *func* over the contained value."""

    @abstractmethod
    def is_just(self) -> bool:
        """Return True when a value is present."""

    def is_nothing(self) -> bool:
        """Return True when no value is present."""
        return not self.is_just()

    @abstractmethod
    def get_or(self, default):
        """Return the contained value, or *default* when absent."""

    def get_or_raise(self, exc=None):
        """Return the contained value or raise.

        Raises *exc* when one is given, otherwise
        ``ValueError("Nothing.get_or_raise")``.
        """
        if not self.is_just():
            raise ValueError("Nothing.get_or_raise") if exc is None else exc
        return self.get_or(None)

    def filter(self, predicate) -> "Maybe":
        """Keep the value only when *predicate* holds for it."""
        def keep_if(item):
            return Just(item) if predicate(item) else Nothing()

        return self.bind(keep_if)

    def or_else(self, alternative: "Maybe") -> "Maybe":
        """Return ``self`` if it holds a value, otherwise *alternative*."""
        if self.is_just():
            return self
        return alternative

    def to_either(self, error) -> "Either":
        """Convert to ``Right(value)``, or ``Left(error)`` when empty."""
        return Right(self.get_or(None)) if self.is_just() else Left(error)

    def to_list(self) -> "List":
        """Convert to a one-element ``List``, or an empty ``List`` when empty."""
        items = [self.get_or(None)] if self.is_just() else []
        return List(items)

    def __iter__(self):
        # Generator: yields the contained value once, or nothing at all.
        if not self.is_just():
            return
        yield self.get_or(None)

    def __bool__(self) -> bool:
        return self.is_just()

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        """Structural equality, implemented by subclasses."""

    @abstractmethod
    def __hash__(self) -> int:
        """Hash consistent with ``__eq__``, implemented by subclasses."""
class Just(Maybe):
    """A ``Maybe`` that holds a value."""

    __slots__ = ("_value",)

    def __init__(self, value):
        self._value = value

    @classmethod
    def pure(cls, value) -> "Just":
        """Wrap *value* in this class."""
        return cls(value)

    def bind(self, func) -> Maybe:
        """Return ``func(value)``; *func* is expected to return a ``Maybe``."""
        return func(self._value)

    def fmap(self, func) -> Maybe:
        """Return ``Just(func(value))``."""
        mapped = func(self._value)
        return Just(mapped)

    def apply(self, wrapped_func: Applicative) -> Maybe:
        """Apply the function held by *wrapped_func* to this value.

        Returns ``Nothing()`` when *wrapped_func* is not a ``Just``.
        """
        if not isinstance(wrapped_func, Just):
            return Nothing()
        fn = wrapped_func._value
        return Just(fn(self._value))

    def get_or(self, default):
        # The default is ignored: a value is always present here.
        return self._value

    def is_just(self) -> bool:
        return True

    def __repr__(self) -> str:
        return f"Just({self._value!r})"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Just):
            return False
        return self._value == other._value

    def __hash__(self) -> int:
        return hash(("Just", self._value))
class Nothing(Maybe):
    """The empty ``Maybe``; every ``Nothing()`` call returns one shared instance."""

    # Cached singleton, created lazily by the first ``Nothing()`` call.
    _instance = None

    def __new__(cls):
        instance = cls._instance
        if instance is None:
            instance = super().__new__(cls)
            cls._instance = instance
        return instance

    @classmethod
    def pure(cls, value) -> "Just":
        """``pure`` always produces a value, so it yields ``Just`` even here."""
        return Just(value)

    def bind(self, func) -> "Nothing":
        """Short-circuit: *func* is never called."""
        return self

    def fmap(self, func) -> "Nothing":
        """Short-circuit: *func* is never called."""
        return self

    def apply(self, wrapped_func: Applicative) -> "Nothing":
        """Short-circuit: *wrapped_func* is ignored."""
        return self

    def get_or(self, default):
        """Always return *default*."""
        return default

    def is_just(self) -> bool:
        return False

    def __repr__(self) -> str:
        return "Nothing"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Nothing)

    def __hash__(self) -> int:
        return hash("Nothing")
class Either(Monad):
    """Success-or-failure value: ``Right(value)`` or ``Left(error)``."""

    @classmethod
    def pure(cls, value) -> "Right":
        """Wrap *value* as a success."""
        return Right(value)

    @classmethod
    def try_(cls, func, *args, **kwargs) -> "Either":
        """Call ``func(*args, **kwargs)``, capturing any ``Exception`` as ``Left``."""
        try:
            result = func(*args, **kwargs)
        except Exception as exc:
            return Left(exc)
        return Right(result)

    @classmethod
    def from_maybe(cls, maybe: Maybe, error) -> "Either":
        """Convert ``Just(x)`` to ``Right(x)`` and ``Nothing`` to ``Left(error)``."""
        if not maybe.is_just():
            return Left(error)
        return Right(maybe.get_or(None))

    @abstractmethod
    def bind(self, func) -> "Either":
        """Chain an Either-returning *func* over the success value."""

    @abstractmethod
    def is_right(self) -> bool:
        """Return True for the success case."""

    def is_left(self) -> bool:
        """Return True for the failure case."""
        return not self.is_right()

    @abstractmethod
    def get_or(self, default):
        """Return the success value, or *default* for a failure."""

    def get_or_raise(self):
        """Return the success value, or raise for a failure.

        A ``Left`` holding an exception instance re-raises it; any other error
        payload is wrapped in ``ValueError``.
        """
        if self.is_right():
            return self.get_or(None)
        err = getattr(self, "_error", ValueError("Left"))
        if not isinstance(err, BaseException):
            err = ValueError(err)
        raise err

    @abstractmethod
    def map_left(self, func) -> "Either":
        """Transform the error of a failure; a success is returned unchanged."""

    @abstractmethod
    def swap(self) -> "Either":
        """Exchange the ``Left`` and ``Right`` roles."""

    def to_maybe(self) -> Maybe:
        """Convert a success to ``Just`` and a failure to ``Nothing()``."""
        if not self.is_right():
            return Nothing()
        return Just(self.get_or(None))

    def __bool__(self) -> bool:
        return self.is_right()

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        """Structural equality, implemented by subclasses."""

    @abstractmethod
    def __hash__(self) -> int:
        """Hash consistent with ``__eq__``, implemented by subclasses."""
class Right(Either):
    """The success case of ``Either``.

    ``bind``, ``fmap`` and ``apply`` convert any ``Exception`` raised by the
    user-supplied function into a ``Left`` holding that exception.
    """

    __slots__ = ("_value",)

    def __init__(self, value):
        self._value = value

    @classmethod
    def pure(cls, value) -> "Right":
        """Wrap *value* in this class."""
        return cls(value)

    def bind(self, func) -> Either:
        """Return ``func(value)``, or ``Left(exc)`` if *func* raises."""
        try:
            return func(self._value)
        except Exception as exc:
            return Left(exc)

    def fmap(self, func) -> Either:
        """Return ``Right(func(value))``, or ``Left(exc)`` if *func* raises."""
        try:
            return Right(func(self._value))
        except Exception as exc:
            return Left(exc)

    def apply(self, wrapped_func: Applicative) -> Either:
        """Apply the function held by a ``Right`` *wrapped_func* to this value.

        A non-``Right`` *wrapped_func* (e.g. a ``Left``) is returned as-is, so
        its error propagates. An exception raised by the function becomes
        ``Left(exc)``, exactly as in ``fmap`` and ``bind``.
        """
        if not isinstance(wrapped_func, Right):
            return wrapped_func
        # Delegate to fmap so exception capture matches the other combinators
        # (previously an exception here escaped instead of becoming a Left).
        return self.fmap(wrapped_func._value)

    def get_or(self, default):
        # The default is ignored: a success value is always present.
        return self._value

    def map_left(self, func) -> "Right":
        """No error to transform; return ``self`` unchanged."""
        return self

    def swap(self) -> "Left":
        """Turn the success value into an error."""
        return Left(self._value)

    def is_right(self) -> bool:
        return True

    def __repr__(self) -> str:
        return f"Right({self._value!r})"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Right) and self._value == other._value

    def __hash__(self) -> int:
        return hash(("Right", self._value))
class Left(Either):
    """The failure case of ``Either``; carries an error payload."""

    __slots__ = ("_error",)

    def __init__(self, error):
        self._error = error

    @classmethod
    def pure(cls, value) -> "Right":
        """``pure`` always builds the success case."""
        return Right(value)

    def bind(self, func) -> "Left":
        """Short-circuit: *func* is never called."""
        return self

    def fmap(self, func) -> "Left":
        """Short-circuit: *func* is never called."""
        return self

    def apply(self, wrapped_func: Applicative) -> "Left":
        """Short-circuit: *wrapped_func* is ignored."""
        return self

    def get_or(self, default):
        """Always return *default*."""
        return default

    def map_left(self, func) -> "Left":
        """Return a new ``Left`` holding ``func(error)``."""
        new_error = func(self._error)
        return Left(new_error)

    def swap(self) -> "Right":
        """Turn the error into a success value."""
        return Right(self._error)

    def is_right(self) -> bool:
        return False

    def __repr__(self) -> str:
        return f"Left({self._error!r})"

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Left):
            return self._error == other._error
        return False

    def __hash__(self) -> int:
        return hash(("Left", self._error))
class List(Monad):
    """Sequence monad: ``bind`` is flat-map, ``apply`` is the cartesian product.

    Elements are copied into a private Python list on construction; no method
    mutates it, and every combinator returns a new ``List``.
    """

    __slots__ = ("_values",)

    def __init__(self, values=None):
        # Copy so later mutation of the caller's iterable cannot leak in.
        self._values = list(values) if values is not None else []

    @classmethod
    def pure(cls, value) -> "List":
        """Return a one-element ``List``."""
        return cls([value])

    @classmethod
    def empty(cls) -> "List":
        """Return an empty ``List``."""
        return cls([])

    @classmethod
    def range(cls, *args) -> "List":
        """Build a ``List`` from ``range(*args)``."""
        # Inside the method body ``range`` is the builtin; class-scope names
        # (including this classmethod) are not visible to function bodies.
        return cls(range(*args))

    def bind(self, func) -> "List":
        """Flat-map *func* over the elements.

        A ``List`` result is spliced into the output; any other result is
        appended as a single element.
        """
        result = []
        for v in self._values:
            out = func(v)
            if isinstance(out, List):
                result.extend(out._values)
            else:
                result.append(out)
        return List(result)

    def fmap(self, func) -> "List":
        """Apply *func* to every element."""
        return List(func(v) for v in self._values)

    def apply(self, wrapped_func: "List") -> "List":
        """Apply every function in *wrapped_func* to every element.

        The outer loop runs over the functions, so results are grouped by
        function.
        """
        return wrapped_func.bind(lambda f: self.fmap(f))

    def filter(self, predicate) -> "List":
        """Keep the elements for which *predicate* is truthy."""
        return List(v for v in self._values if predicate(v))

    def concat(self, other: "List") -> "List":
        """Return the elements of ``self`` followed by those of *other*."""
        return List(self._values + other._values)

    def head(self) -> Maybe:
        """Return ``Just`` the first element, or ``Nothing()`` when empty."""
        return Just(self._values[0]) if self._values else Nothing()

    def tail(self) -> "List":
        """Return every element but the first (empty for an empty list)."""
        return List(self._values[1:])

    def last(self) -> Maybe:
        """Return ``Just`` the last element, or ``Nothing()`` when empty."""
        return Just(self._values[-1]) if self._values else Nothing()

    def take(self, n: int) -> "List":
        """Return the first *n* elements; empty when ``n <= 0``."""
        # Clamp: a raw negative slice bound counts from the end, so take(-1)
        # would have returned all but the last element instead of nothing.
        return List(self._values[:max(n, 0)])

    def drop(self, n: int) -> "List":
        """Return all but the first *n* elements; everything when ``n <= 0``."""
        # Clamp for the same reason as ``take``: drop(-1) must not keep only
        # the last element.
        return List(self._values[max(n, 0):])

    def zip_with(self, other: "List", func) -> "List":
        """Combine elements pairwise with *func*, stopping at the shorter list."""
        return List(func(a, b) for a, b in zip(self._values, other._values))

    def fold(self, func, initial):
        """Left fold: ``func(...func(func(initial, x0), x1)..., xn)``."""
        acc = initial
        for v in self._values:
            acc = func(acc, v)
        return acc

    def traverse_maybe(self, func) -> Maybe:
        """Map a Maybe-returning *func*; ``Nothing()`` if any call yields nothing.

        Stops calling *func* at the first ``Nothing``; otherwise returns
        ``Just(List(results))``.
        """
        results = []
        for v in self._values:
            m = func(v)
            if m.is_nothing():
                return Nothing()
            results.append(m.get_or(None))
        return Just(List(results))

    def sequence_maybe(self) -> Maybe:
        """Turn a ``List`` of ``Maybe`` values into a ``Maybe`` of ``List``."""
        return self.traverse_maybe(lambda x: x)

    def traverse_either(self, func) -> Either:
        """Map an Either-returning *func*; return the first ``Left`` unchanged.

        Stops calling *func* at the first ``Left``; otherwise returns
        ``Right(List(results))``.
        """
        results = []
        for v in self._values:
            e = func(v)
            if e.is_left():
                return e
            results.append(e.get_or(None))
        return Right(List(results))

    def sequence_either(self) -> Either:
        """Turn a ``List`` of ``Either`` values into an ``Either`` of ``List``."""
        return self.traverse_either(lambda x: x)

    def __len__(self) -> int:
        return len(self._values)

    def __iter__(self):
        return iter(self._values)

    def __getitem__(self, index):
        # Slices return a plain Python list, not a ``List``.
        return self._values[index]

    def __contains__(self, item) -> bool:
        return item in self._values

    def __repr__(self) -> str:
        return f"List({self._values!r})"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, List) and self._values == other._values

    def __hash__(self) -> int:
        # Raises TypeError if any element is unhashable.
        return hash(("List", tuple(self._values)))

    def __add__(self, other: "List") -> "List":
        # Returning NotImplemented lets Python try other.__radd__ and then raise
        # a proper TypeError, instead of failing with AttributeError on
        # ``other._values``.
        if not isinstance(other, List):
            return NotImplemented
        return self.concat(other)
class IO(Monad):
    """A deferred computation wrapped around a zero-argument thunk.

    Nothing runs until ``run()`` is called, and each ``run()`` call re-executes
    the thunk.
    """

    __slots__ = ("_thunk",)

    def __init__(self, thunk):
        self._thunk = thunk

    @classmethod
    def pure(cls, value) -> "IO":
        """Return an action that just yields *value*."""
        return cls(lambda: value)

    @classmethod
    def from_callable(cls, func, *args, **kwargs) -> "IO":
        """Defer ``func(*args, **kwargs)`` until the action is run."""
        return cls(lambda: func(*args, **kwargs))

    @classmethod
    def sequence(cls, ios) -> "IO":
        """Combine IO actions into one that runs them in order.

        Running the result returns a ``List`` of the individual results.
        *ios* is materialized immediately. Otherwise a one-shot iterator (e.g. a
        generator) would be exhausted by the first run, and a second run would
        silently return an empty ``List``.
        """
        actions = list(ios)

        def run_all():
            return List([io.run() for io in actions])

        return cls(run_all)

    def bind(self, func) -> "IO":
        """Run this action, pass its result to *func*, then run the IO returned."""
        return IO(lambda: func(self._thunk()).run())

    def fmap(self, func) -> "IO":
        """Run this action and return ``func`` applied to its result."""
        return IO(lambda: func(self._thunk()))

    def apply(self, wrapped_func: "IO") -> "IO":
        """Run *wrapped_func* to obtain a function, then run this action and apply it."""
        # The callee expression is evaluated before its argument, so the
        # function-producing action runs first.
        return IO(lambda: wrapped_func._thunk()(self._thunk()))

    def then(self, other: "IO") -> "IO":
        """Run this action, discard its result, then run *other* and return its result."""
        def run_both():
            self._thunk()
            return other.run()

        return IO(run_both)

    def run(self):
        """Execute the thunk and return its result."""
        return self._thunk()

    def __repr__(self) -> str:
        return "IO(<thunk>)"
class Writer(Monad):
    """A value paired with an accumulated log (a ``Monoid``)."""

    __slots__ = ("_value", "_log")

    def __init__(self, value, log: Monoid):
        self._value = value
        self._log = log

    @classmethod
    def pure(cls, value) -> "Writer":
        """Wrap *value* with an empty ``ListMonoid`` log."""
        # NOTE(review): the empty log is always a ListMonoid; whether it combines
        # correctly with logs of another Monoid type depends on ListMonoid.combine.
        return cls(value, ListMonoid([]))

    @classmethod
    def tell(cls, log: Monoid) -> "Writer":
        """Record *log* with a ``None`` value."""
        return cls(None, log)

    @classmethod
    def with_log(cls, value, log: Monoid) -> "Writer":
        """Build a writer from an explicit value and log."""
        return cls(value, log)

    def bind(self, func) -> "Writer":
        """Run *func* on the value and append its log after this writer's log."""
        new_writer = func(self._value)
        return Writer(new_writer._value, self._log.combine(new_writer._log))

    def fmap(self, func) -> "Writer":
        """Transform the value; the log is unchanged."""
        return Writer(func(self._value), self._log)

    def apply(self, wrapped_func: "Writer") -> "Writer":
        """Apply the function held by *wrapped_func* to this writer's value.

        The function's log comes first, then this writer's log. This matches
        ``wrapped_func.bind(lambda f: self.fmap(f))`` and the effect order of
        the other ``apply`` implementations in this module (function
        container first). Previously the logs were combined in the opposite
        order, which is observable for non-commutative logs such as lists.
        """
        result = wrapped_func._value(self._value)
        return Writer(result, wrapped_func._log.combine(self._log))

    def run(self):
        """Return the ``(value, log)`` pair."""
        return (self._value, self._log)

    @property
    def value(self):
        return self._value

    @property
    def log(self):
        return self._log

    def listen(self) -> "Writer":
        """Expose the log alongside the value: ``Writer((value, log), log)``."""
        return Writer((self._value, self._log), self._log)

    def censor(self, func) -> "Writer":
        """Replace the log with ``func(log)``; the value is unchanged."""
        return Writer(self._value, func(self._log))

    def __repr__(self) -> str:
        return f"Writer({self._value!r}, {self._log!r})"

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, Writer)
            and self._value == other._value
            and self._log == other._log
        )

    def __hash__(self) -> int:
        # Hashes only the value. Writers equal under __eq__ have equal values,
        # so this stays consistent with __eq__.
        return hash(("Writer", self._value))
class State(Monad):
    """Stateful computation wrapping a function ``state -> (value, new_state)``."""

    __slots__ = ("_run",)

    def __init__(self, run_func):
        self._run = run_func

    @classmethod
    def pure(cls, value) -> "State":
        """Yield *value* and leave the state untouched."""
        def keep(state):
            return (value, state)

        return cls(keep)

    @classmethod
    def get(cls) -> "State":
        """Yield the current state as the value."""
        def read(state):
            return (state, state)

        return cls(read)

    @classmethod
    def put(cls, new_state) -> "State":
        """Replace the state with *new_state*, yielding ``None``."""
        def replace(_ignored):
            return (None, new_state)

        return cls(replace)

    @classmethod
    def modify(cls, func) -> "State":
        """Replace the state with ``func(state)``, yielding ``None``."""
        def update(state):
            return (None, func(state))

        return cls(update)

    @classmethod
    def gets(cls, func) -> "State":
        """Yield ``func(state)`` without changing the state."""
        def project(state):
            return (func(state), state)

        return cls(project)

    def bind(self, func) -> "State":
        """Run this step, pass its value to *func*, and run the State returned."""
        def chained(state):
            result, intermediate = self._run(state)
            follow_up = func(result)
            return follow_up._run(intermediate)

        return State(chained)

    def fmap(self, func) -> "State":
        """Transform the yielded value; the state threading is unchanged."""
        def mapped(state):
            result, final = self._run(state)
            return (func(result), final)

        return State(mapped)

    def apply(self, wrapped_func: "State") -> "State":
        """Run *wrapped_func* first for a function, then this step, and apply it."""
        def applied(state):
            fn, after_fn = wrapped_func._run(state)
            result, after_value = self._run(after_fn)
            return (fn(result), after_value)

        return State(applied)

    def run(self, initial_state):
        """Run from *initial_state*, returning ``(value, final_state)``."""
        return self._run(initial_state)

    def eval(self, initial_state):
        """Run from *initial_state* and return only the value."""
        return self._run(initial_state)[0]

    def exec(self, initial_state):
        """Run from *initial_state* and return only the final state."""
        return self._run(initial_state)[1]

    def __repr__(self) -> str:
        return "State(<s → (a, s)>)"
class Reader(Monad):
    """Computation that reads a shared environment: wraps a function ``env -> value``."""

    __slots__ = ("_run",)

    def __init__(self, run_func):
        self._run = run_func

    @classmethod
    def pure(cls, value) -> "Reader":
        """Ignore the environment and yield *value*."""
        def constant(_env):
            return value

        return cls(constant)

    @classmethod
    def ask(cls) -> "Reader":
        """Yield the environment itself."""
        def identity(env):
            return env

        return cls(identity)

    @classmethod
    def asks(cls, func) -> "Reader":
        """Yield ``func(env)``."""
        def project(env):
            return func(env)

        return cls(project)

    def bind(self, func) -> "Reader":
        """Pass this reader's value to *func*; run the returned Reader on the same env."""
        def chained(env):
            next_reader = func(self._run(env))
            return next_reader._run(env)

        return Reader(chained)

    def fmap(self, func) -> "Reader":
        """Transform the yielded value."""
        def mapped(env):
            return func(self._run(env))

        return Reader(mapped)

    def apply(self, wrapped_func: "Reader") -> "Reader":
        """Apply the function yielded by *wrapped_func* to this reader's value."""
        def applied(env):
            fn = wrapped_func._run(env)
            return fn(self._run(env))

        return Reader(applied)

    def local(self, func) -> "Reader":
        """Run this reader against ``func(env)`` instead of ``env``."""
        def with_modified_env(env):
            return self._run(func(env))

        return Reader(with_modified_env)

    def run(self, env):
        """Execute against *env* and return the result."""
        return self._run(env)

    def __repr__(self) -> str:
        return "Reader(<env → a>)"
class MaybeT(MonadTransformer):
    """Maybe transformer: wraps a base-monad value whose results are ``Maybe``s."""

    __slots__ = ("_inner",)

    def __init__(self, inner: Monad):
        self._inner = inner

    @classmethod
    def pure(cls, value) -> "MaybeT":
        # The base monad cannot be inferred from a bare value.
        raise NotImplementedError("Use MaybeT.pure_with(base_cls, value)")

    @classmethod
    def pure_with(cls, base_cls, value) -> "MaybeT":
        """Wrap ``Just(value)`` using ``base_cls.pure``."""
        wrapped = Just(value)
        return cls(base_cls.pure(wrapped))

    @classmethod
    def nothing_with(cls, base_cls) -> "MaybeT":
        """Wrap ``Nothing()`` using ``base_cls.pure``."""
        empty = Nothing()
        return cls(base_cls.pure(empty))

    @classmethod
    def lift(cls, monad: Monad) -> "MaybeT":
        """Promote a base-monad value by wrapping its results in ``Just``."""
        return cls(monad.fmap(Just))

    def bind(self, func) -> "MaybeT":
        """Chain a MaybeT-returning *func*; a ``Nothing`` result short-circuits."""
        def step(maybe_value):
            if not maybe_value.is_nothing():
                return func(maybe_value.get_or(None))._inner
            # Re-wrap Nothing with the concrete class of the inner monad.
            return self._inner.__class__.pure(Nothing())

        return MaybeT(self._inner.bind(step))

    def fmap(self, func) -> "MaybeT":
        """Map *func* over the ``Maybe`` inside the base monad."""
        def map_maybe(maybe_value):
            return maybe_value.fmap(func)

        return MaybeT(self._inner.fmap(map_maybe))

    def apply(self, wrapped_func: "MaybeT") -> "MaybeT":
        """Apply the function produced by *wrapped_func* to this computation's value."""
        def apply_one(fn):
            return self.fmap(fn)

        return wrapped_func.bind(apply_one)

    def or_else(self, alternative: "MaybeT") -> "MaybeT":
        """Fall back to *alternative*'s inner value where this one yields ``Nothing``."""
        def step(maybe_value):
            if maybe_value.is_just():
                return self._inner.__class__.pure(maybe_value)
            return alternative._inner

        return MaybeT(self._inner.bind(step))

    def run(self) -> Monad:
        """Return the underlying base-monad value."""
        return self._inner

    def __repr__(self) -> str:
        return f"MaybeT({self._inner!r})"

    def __rshift__(self, func) -> "MaybeT":
        """``m >> f`` is shorthand for ``m.bind(f)``."""
        return self.bind(func)