from __future__ import annotations

from abc import abstractmethod

from .typeclasses import (
    Applicative,
    Generic,
    Monad,
    MonadTransformer,
    Monoid,
    ListMonoid,
)
from .classmethod import Classmethod


class Maybe(Monad):
    """Optional value: either ``Just(value)`` or ``Nothing()``.

    Concrete subclasses provide ``bind``, ``is_just``, ``get_or``, ``__eq__``
    and ``__hash__``; the remaining helpers here are expressed through those.
    """

    @classmethod
    def pure(cls, value) -> "Maybe":
        """Wrap *value* in ``Just`` unconditionally, even when it is ``None``."""
        return Just(value)

    @classmethod
    def of(cls, value) -> "Maybe":
        """Lift a possibly-``None`` value; ``None`` maps to ``Nothing()``."""
        if value is None:
            return Nothing()
        return Just(value)

    @classmethod
    def from_either(cls, either: "Either") -> "Maybe":
        """Keep a ``Right`` payload as ``Just``; a ``Left`` becomes ``Nothing()``."""
        if not either.is_right():
            return Nothing()
        return Just(either._value)

    @abstractmethod
    def bind(self, func) -> "Maybe":
        """Monadic bind, supplied by each concrete subclass."""
        ...

    @abstractmethod
    def is_just(self) -> bool:
        """Return True when a value is present."""
        ...

    def is_nothing(self) -> bool:
        """Return True when no value is present."""
        return not self.is_just()

    @abstractmethod
    def get_or(self, default):
        """Return the wrapped value, or *default* when there is none."""
        ...

    def get_or_raise(self, exc=None):
        """Return the wrapped value, or raise when empty.

        Raises *exc* if one was supplied, otherwise
        ``ValueError("Nothing.get_or_raise")``.
        """
        if not self.is_just():
            raise exc if exc is not None else ValueError("Nothing.get_or_raise")
        return self.get_or(None)

    def filter(self, predicate) -> "Maybe":
        """Keep the value only when *predicate* returns a truthy result for it."""

        def _keep_if(item):
            return Just(item) if predicate(item) else Nothing()

        return self.bind(_keep_if)

    def or_else(self, alternative: "Maybe") -> "Maybe":
        """Return ``self`` when it holds a value, otherwise *alternative*."""
        if self.is_just():
            return self
        return alternative

    def to_either(self, error) -> "Either":
        """Convert to ``Right(value)``, or ``Left(error)`` when empty."""
        return Right(self.get_or(None)) if self.is_just() else Left(error)

    def to_list(self) -> "List":
        """Convert to a one-element ``List``, or an empty ``List`` when empty."""
        items = [self.get_or(None)] if self.is_just() else []
        return List(items)

    def __iter__(self):
        """Yield the wrapped value once, or yield nothing when empty."""
        if not self.is_just():
            return
        yield self.get_or(None)

    def __bool__(self) -> bool:
        """Truthy exactly when a value is present."""
        return self.is_just()

    @abstractmethod
    def __eq__(self, other: object) -> bool: ...

    @abstractmethod
    def __hash__(self) -> int: ...
class Just(Maybe):
    """``Maybe`` case that holds a value."""

    __slots__ = ("_value",)

    def __init__(self, value):
        self._value = value

    @classmethod
    def pure(cls, value) -> "Just":
        return cls(value)

    def bind(self, func) -> Maybe:
        """Return ``func(value)``; *func* is expected to return a ``Maybe``."""
        return func(self._value)

    def fmap(self, func) -> Maybe:
        return Just(func(self._value))

    def apply(self, wrapped_func: Applicative) -> Maybe:
        """Apply the function held in *wrapped_func* to this value.

        Returns ``Nothing()`` unless *wrapped_func* is a ``Just``.
        """
        if isinstance(wrapped_func, Just):
            return Just(wrapped_func._value(self._value))
        return Nothing()

    def get_or(self, default):
        # A value is present, so *default* is ignored.
        return self._value

    def is_just(self) -> bool:
        return True

    def __repr__(self) -> str:
        return f"Just({self._value!r})"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Just) and self._value == other._value

    def __hash__(self) -> int:
        return hash(("Just", self._value))


class Nothing(Maybe):
    """``Maybe`` case with no value; ``Nothing()`` always returns the same instance."""

    _instance = None

    def __new__(cls):
        # Create the instance on first use and return the cached one afterwards,
        # so ``Nothing() is Nothing()`` holds.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @classmethod
    def pure(cls, value) -> "Just":
        return Just(value)

    # bind / fmap / apply short-circuit: there is no value to pass along.
    def bind(self, func) -> "Nothing":
        return self

    def fmap(self, func) -> "Nothing":
        return self

    def apply(self, wrapped_func: Applicative) -> "Nothing":
        return self

    def get_or(self, default):
        return default

    def is_just(self) -> bool:
        return False

    def __repr__(self) -> str:
        return "Nothing"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Nothing)

    def __hash__(self) -> int:
        return hash("Nothing")


class Either(Monad):
    """Result that is either a success (``Right``) or an error (``Left``)."""

    @classmethod
    def pure(cls, value) -> "Right":
        return Right(value)

    @classmethod
    def try_(cls, func, *args, **kwargs) -> "Either":
        """Call ``func(*args, **kwargs)``; return ``Right(result)`` or ``Left(exc)``.

        Only ``Exception`` subclasses are captured; other ``BaseException``s
        (e.g. ``KeyboardInterrupt``) propagate.
        """
        try:
            return Right(func(*args, **kwargs))
        except Exception as exc:
            return Left(exc)

    @classmethod
    def from_maybe(cls, maybe: Maybe, error) -> "Either":
        """``Just(x)`` becomes ``Right(x)``; an empty ``Maybe`` becomes ``Left(error)``."""
        return Right(maybe.get_or(None)) if maybe.is_just() else Left(error)

    @abstractmethod
    def bind(self, func) -> "Either": ...

    @abstractmethod
    def is_right(self) -> bool: ...

    def is_left(self) -> bool:
        return not self.is_right()

    @abstractmethod
    def get_or(self, default): ...

    def get_or_raise(self):
        """Return the ``Right`` payload, or raise for a ``Left``.

        A ``Left`` whose error is an exception re-raises it unchanged; any other
        error value is wrapped in ``ValueError``.
        """
        if self.is_right():
            return self.get_or(None)
        err = self._error if hasattr(self, "_error") else ValueError("Left")
        if isinstance(err, BaseException):
            raise err
        raise ValueError(err)

    @abstractmethod
    def map_left(self, func) -> "Either": ...

    @abstractmethod
    def swap(self) -> "Either": ...

    def to_maybe(self) -> Maybe:
        """``Right(x)`` becomes ``Just(x)``; a ``Left`` becomes ``Nothing()``."""
        return Just(self.get_or(None)) if self.is_right() else Nothing()

    def __bool__(self) -> bool:
        return self.is_right()

    @abstractmethod
    def __eq__(self, other: object) -> bool: ...

    @abstractmethod
    def __hash__(self) -> int: ...


class Right(Either):
    """Successful ``Either`` carrying a value.

    ``bind``, ``fmap`` and ``apply`` turn an ``Exception`` raised by the supplied
    function into ``Left(exc)`` instead of propagating it.
    """

    __slots__ = ("_value",)

    def __init__(self, value):
        self._value = value

    @classmethod
    def pure(cls, value) -> "Right":
        return cls(value)

    def bind(self, func) -> Either:
        try:
            return func(self._value)
        except Exception as exc:
            return Left(exc)

    def fmap(self, func) -> Either:
        try:
            return Right(func(self._value))
        except Exception as exc:
            return Left(exc)

    def apply(self, wrapped_func: Applicative) -> Either:
        """Apply the function held in *wrapped_func* to this value.

        A non-``Right`` *wrapped_func* (a ``Left``) is returned unchanged. If the
        function raises, the result is ``Left(exc)``, exactly as with ``fmap``.
        """
        if isinstance(wrapped_func, Right):
            # Route through fmap so exceptions are captured the same way as in
            # fmap/bind; Right(x).apply(Right(f)) must equal Right(x).fmap(f).
            return self.fmap(wrapped_func._value)
        return wrapped_func

    def get_or(self, default):
        # A value is present, so *default* is ignored.
        return self._value

    def map_left(self, func) -> "Right":
        # Nothing to transform on the success side.
        return self

    def swap(self) -> "Left":
        return Left(self._value)

    def is_right(self) -> bool:
        return True

    def __repr__(self) -> str:
        return f"Right({self._value!r})"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Right) and self._value == other._value

    def __hash__(self) -> int:
        return hash(("Right", self._value))


class Left(Either):
    """Failed ``Either`` carrying an error value (``try_`` stores the caught exception here)."""

    __slots__ = ("_error",)

    def __init__(self, error):
        self._error = error

    @classmethod
    def pure(cls, value) -> "Right":
        return Right(value)

    # A Left short-circuits: bind / fmap / apply return it untouched.
    def bind(self, func) -> "Left":
        return self

    def fmap(self, func) -> "Left":
        return self

    def apply(self, wrapped_func: Applicative) -> "Left":
        return self

    def get_or(self, default):
        return default

    def map_left(self, func) -> "Left":
        """Transform the error value; exceptions raised by *func* propagate."""
        return Left(func(self._error))

    def swap(self) -> "Right":
        return Right(self._error)

    def is_right(self) -> bool:
        return False

    def __repr__(self) -> str:
        return f"Left({self._error!r})"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Left) and self._error == other._error

    def __hash__(self) -> int:
        return hash(("Left", self._error))


class List(Monad):
    """Immutable-style list monad; every operation returns a new ``List``."""

    __slots__ = ("_values",)

    def __init__(self, values=None):
        # Copy the input so later mutation of the caller's iterable has no effect.
        self._values = list(values) if values is not None else []

    @classmethod
    def pure(cls, value) -> "List":
        return cls([value])

    @classmethod
    def empty(cls) -> "List":
        return cls([])

    @classmethod
    def range(cls, *args) -> "List":
        """Build a ``List`` from ``range(*args)``."""
        return cls(range(*args))

    def bind(self, func) -> "List":
        """Map *func* over the elements and flatten one level.

        A ``List`` result is spliced in; any other result is appended as-is.
        """
        result = []
        for v in self._values:
            out = func(v)
            if isinstance(out, List):
                result.extend(out._values)
            else:
                result.append(out)
        return List(result)

    def fmap(self, func) -> "List":
        return List(func(v) for v in self._values)

    def apply(self, wrapped_func: "List") -> "List":
        """Apply every function in *wrapped_func* to every element (functions outermost)."""
        return wrapped_func.bind(lambda f: self.fmap(f))

    def filter(self, predicate) -> "List":
        return List(v for v in self._values if predicate(v))

    def concat(self, other: "List") -> "List":
        return List(self._values + other._values)

    def head(self) -> Maybe:
        """First element as ``Just``, or ``Nothing()`` when empty."""
        return Just(self._values[0]) if self._values else Nothing()

    def tail(self) -> "List":
        return List(self._values[1:])

    def last(self) -> Maybe:
        """Last element as ``Just``, or ``Nothing()`` when empty."""
        return Just(self._values[-1]) if self._values else Nothing()

    def take(self, n: int) -> "List":
        return List(self._values[:n])

    def drop(self, n: int) -> "List":
        return List(self._values[n:])

    def zip_with(self, other: "List", func) -> "List":
        """Combine pairwise with *func*; stops at the shorter list."""
        return List(func(a, b) for a, b in zip(self._values, other._values))

    def fold(self, func, initial):
        """Left fold: ``func(...func(func(initial, v0), v1)..., vn)``."""
        acc = initial
        for v in self._values:
            acc = func(acc, v)
        return acc

    def traverse_maybe(self, func) -> Maybe:
        """Apply *func* (returning ``Maybe``) to each element.

        Returns ``Just(List(results))``, or ``Nothing()`` at the first empty result.
        """
        results = []
        for v in self._values:
            m = func(v)
            if m.is_nothing():
                return Nothing()
            results.append(m.get_or(None))
        return Just(List(results))

    def sequence_maybe(self) -> Maybe:
        return self.traverse_maybe(lambda x: x)

    def traverse_either(self, func) -> Either:
        """Apply *func* (returning ``Either``) to each element.

        Returns ``Right(List(results))``, or the first ``Left`` encountered.
        """
        results = []
        for v in self._values:
            e = func(v)
            if e.is_left():
                return e
            results.append(e.get_or(None))
        return Right(List(results))

    def sequence_either(self) -> Either:
        return self.traverse_either(lambda x: x)

    def __len__(self) -> int:
        return len(self._values)

    def __iter__(self):
        return iter(self._values)

    def __getitem__(self, index):
        return self._values[index]

    def __contains__(self, item) -> bool:
        return item in self._values

    def __repr__(self) -> str:
        return f"List({self._values!r})"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, List) and self._values == other._values

    def __hash__(self) -> int:
        # Raises TypeError if any element is unhashable.
        return hash(("List", tuple(self._values)))

    def __add__(self, other: "List") -> "List":
        return self.concat(other)


class IO(Monad):
    """Deferred computation wrapping a zero-argument thunk; nothing runs until ``run()``."""

    __slots__ = ("_thunk",)

    def __init__(self, thunk):
        self._thunk = thunk

    @classmethod
    def pure(cls, value) -> "IO":
        return cls(lambda: value)

    @classmethod
    def from_callable(cls, func, *args, **kwargs) -> "IO":
        return cls(lambda: func(*args, **kwargs))

    @classmethod
    def sequence(cls, ios) -> "IO":
        """Combine *ios* into one action that runs each in order, collecting results in a ``List``.

        *ios* is materialized immediately so the combined action can be run
        repeatedly even when a one-shot iterator (such as a generator) is passed.
        """
        actions = list(ios)

        def run_all():
            return List([io.run() for io in actions])

        return cls(run_all)

    def bind(self, func) -> "IO":
        return IO(lambda: func(self._thunk()).run())

    def fmap(self, func) -> "IO":
        return IO(lambda: func(self._thunk()))

    def apply(self, wrapped_func: "IO") -> "IO":
        # Python evaluates the callee first, so wrapped_func's effect runs before self's.
        return IO(lambda: wrapped_func._thunk()(self._thunk()))

    def then(self, other: "IO") -> "IO":
        """Run ``self``, discard its result, then run *other* and return its result."""
        return IO(lambda: (self._thunk(), other.run())[1])

    def run(self):
        return self._thunk()

    def __repr__(self) -> str:
        return "IO()"


class Writer(Monad):
    """Value paired with a log; logs are merged via the log object's ``combine``.

    ``pure`` starts from an empty ``ListMonoid``.
    """

    __slots__ = ("_value", "_log")

    def __init__(self, value, log: Monoid):
        self._value = value
        self._log = log

    @classmethod
    def pure(cls, value) -> "Writer":
        return cls(value, ListMonoid([]))

    @classmethod
    def tell(cls, log: Monoid) -> "Writer":
        """Writer that only records *log* (its value is ``None``)."""
        return cls(None, log)

    @classmethod
    def with_log(cls, value, log: Monoid) -> "Writer":
        return cls(value, log)

    def bind(self, func) -> "Writer":
        """Feed the value to *func*; the result's log is appended after this log."""
        new_writer = func(self._value)
        return Writer(new_writer._value, self._log.combine(new_writer._log))

    def fmap(self, func) -> "Writer":
        return Writer(func(self._value), self._log)

    def apply(self, wrapped_func: "Writer") -> "Writer":
        """Apply the function held in *wrapped_func* to this value.

        The function's log comes first, then this writer's log. This matches
        ``wrapped_func.bind(lambda f: self.fmap(f))`` and the effect order of
        ``State.apply`` / ``IO.apply``.
        """
        return Writer(
            wrapped_func._value(self._value),
            wrapped_func._log.combine(self._log),
        )

    def run(self):
        """Return ``(value, log)``."""
        return (self._value, self._log)

    @property
    def value(self):
        return self._value

    @property
    def log(self):
        return self._log

    def listen(self) -> "Writer":
        """Expose the log alongside the value: value becomes ``(value, log)``."""
        return Writer((self._value, self._log), self._log)

    def censor(self, func) -> "Writer":
        """Replace the log with ``func(log)``, keeping the value."""
        return Writer(self._value, func(self._log))

    def __repr__(self) -> str:
        return f"Writer({self._value!r}, {self._log!r})"

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, Writer)
            and self._value == other._value
            and self._log == other._log
        )

    def __hash__(self) -> int:
        # Hashing only the value stays consistent with __eq__ (equal writers share a value).
        return hash(("Writer", self._value))


class State(Monad):
    """Stateful computation wrapping ``run_func: state -> (value, new_state)``."""

    __slots__ = ("_run",)

    def __init__(self, run_func):
        self._run = run_func

    @classmethod
    def pure(cls, value) -> "State":
        return cls(lambda s: (value, s))

    @classmethod
    def get(cls) -> "State":
        """Return the current state as the value, leaving it unchanged."""
        return cls(lambda s: (s, s))

    @classmethod
    def put(cls, new_state) -> "State":
        """Replace the state with *new_state*; the value is ``None``."""
        return cls(lambda _: (None, new_state))

    @classmethod
    def modify(cls, func) -> "State":
        """Replace the state with ``func(state)``; the value is ``None``."""
        return cls(lambda s: (None, func(s)))

    @classmethod
    def gets(cls, func) -> "State":
        """Return ``func(state)`` as the value, leaving the state unchanged."""
        return cls(lambda s: (func(s), s))

    def bind(self, func) -> "State":
        def _run(state):
            value, new_state = self._run(state)
            return func(value)._run(new_state)

        return State(_run)

    def fmap(self, func) -> "State":
        def _run(state):
            value, new_state = self._run(state)
            return (func(value), new_state)

        return State(_run)

    def apply(self, wrapped_func: "State") -> "State":
        def _run(state):
            # The function-producing computation runs first, then self on its output state.
            f, state1 = wrapped_func._run(state)
            value, state2 = self._run(state1)
            return (f(value), state2)

        return State(_run)

    def run(self, initial_state):
        """Return ``(value, final_state)``."""
        return self._run(initial_state)

    def eval(self, initial_state):
        """Return only the resulting value."""
        return self._run(initial_state)[0]

    def exec(self, initial_state):
        """Return only the final state."""
        return self._run(initial_state)[1]

    def __repr__(self) -> str:
        return "State()"


class Reader(Monad):
    """Computation that reads a shared environment: wraps ``run_func: env -> value``."""

    __slots__ = ("_run",)

    def __init__(self, run_func):
        self._run = run_func

    @classmethod
    def pure(cls, value) -> "Reader":
        return cls(lambda _: value)

    @classmethod
    def ask(cls) -> "Reader":
        """Return the environment itself."""
        return cls(lambda env: env)

    @classmethod
    def asks(cls, func) -> "Reader":
        """Return ``func(env)``."""
        return cls(lambda env: func(env))

    def bind(self, func) -> "Reader":
        return Reader(lambda env: func(self._run(env))._run(env))

    def fmap(self, func) -> "Reader":
        return Reader(lambda env: func(self._run(env)))

    def apply(self, wrapped_func: "Reader") -> "Reader":
        return Reader(lambda env: wrapped_func._run(env)(self._run(env)))

    def local(self, func) -> "Reader":
        """Run this reader against ``func(env)`` instead of ``env``."""
        return Reader(lambda env: self._run(func(env)))

    def run(self, env):
        return self._run(env)

    def __repr__(self) -> str:
        return "Reader()"


class MaybeT(MonadTransformer):
    """``Maybe`` transformer: wraps an inner monad whose payload is a ``Maybe``."""

    __slots__ = ("_inner",)

    def __init__(self, inner: Monad):
        self._inner = inner

    @classmethod
    def pure(cls, value) -> "MaybeT":
        # The base monad is unknown here, so a plain ``pure`` cannot be built.
        raise NotImplementedError("Use MaybeT.pure_with(base_cls, value)")

    @classmethod
    def pure_with(cls, base_cls, value) -> "MaybeT":
        return cls(base_cls.pure(Just(value)))

    @classmethod
    def nothing_with(cls, base_cls) -> "MaybeT":
        return cls(base_cls.pure(Nothing()))

    @classmethod
    def lift(cls, monad: Monad) -> "MaybeT":
        """Wrap each value of *monad* in ``Just``."""
        return cls(monad.fmap(Just))

    def bind(self, func) -> "MaybeT":
        """Chain *func* (returning ``MaybeT``) over a present value; ``Nothing`` short-circuits."""

        def _step(maybe_value):
            if maybe_value.is_nothing():
                # Re-wrap Nothing using the inner monad's concrete class; assumes
                # that class exposes ``pure``.
                return self._inner.__class__.pure(Nothing())
            return func(maybe_value.get_or(None))._inner

        return MaybeT(self._inner.bind(_step))

    def fmap(self, func) -> "MaybeT":
        return MaybeT(self._inner.fmap(lambda m: m.fmap(func)))

    def apply(self, wrapped_func: "MaybeT") -> "MaybeT":
        return wrapped_func.bind(lambda f: self.fmap(f))

    def or_else(self, alternative: "MaybeT") -> "MaybeT":
        """Keep a present value; otherwise continue with *alternative*'s inner computation."""

        def _step(maybe_value):
            return self._inner.__class__.pure(maybe_value) if maybe_value.is_just() else alternative._inner

        return MaybeT(self._inner.bind(_step))

    def run(self) -> Monad:
        """Return the wrapped inner monad."""
        return self._inner

    def __repr__(self) -> str:
        return f"MaybeT({self._inner!r})"

    def __rshift__(self, func) -> "MaybeT":
        """``m >> f`` is shorthand for ``m.bind(f)``."""
        return self.bind(func)