# Copyright 2023 The Langfun Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Messages that are exchanged between users and agents."""

import abc
import collections
from collections.abc import Mapping
import contextlib
import functools
import inspect
import io
from typing import Annotated, Any, Callable, ClassVar, Optional, Type, Union

from langfun.core import modality
from langfun.core import natural_language
import pyglove as pg


class Message(
    natural_language.NaturalLanguageFormattable,
    pg.Object,
    pg.views.HtmlTreeView.Extension
):
  """Message.

  ``Message`` is the protocol for users and the system to interact with
  LLMs. It consists of a text in the form of natural language,
  an identifier of the sender, and a dictionary of Python values as structured
  meta-data.

  The subclasses of ``Message`` represent messages sent from different roles.
  Agents may use the roles to decide the orchestration logic.
  """

  #
  # Constants.
  #

  # PyGlove flag to allow modifications to the member attributes through
  # assignment.
  allow_symbolic_assignment = True

  # Constant that refers to the special key used to access the message itself
  # through `Message.get`.
  PATH_ROOT = ''

  # Constant that refers to the special key used to access the message text
  # through `Message.get`.
  PATH_TEXT = 'text'

  # Constant that refers to the special key for describing the structured
  # information extracted as the result of the response.
  PATH_RESULT = 'result'

  # Constant that tags an LM input.
  TAG_LM_INPUT = 'lm-input'

  # Constant that tags an LM response (before output_transform).
  TAG_LM_RESPONSE = 'lm-response'

  # Constant that tags an LM output message (after output_transform).
  TAG_LM_OUTPUT = 'lm-output'

  # Constant that tags a message that is generated by `LangFunc.render`
  TAG_RENDERED = 'rendered'

  # Constant that tags a transformed message.
  TAG_TRANSFORMED = 'transformed'

  #
  # Members
  #

  text: Annotated[str, 'The natural language representation of the message.']

  sender: Annotated[str, 'The sender of the message.']

  referred_modalities: Annotated[
      dict[str, pg.Ref[modality.Modality]],
      'The modality objects referred in the message.'
  ] = pg.Dict()

  metadata: Annotated[
      dict[str, Any],
      (
          'The metadata associated with the message, '
          'which chould carry structured data, such as tool function input. '
          'It is a `pg.Dict` object whose keys can be accessed by attributes.'
      ),
  ] = pg.Dict()

  tags: Annotated[
      list[str],
      (
          'A list of tags associated with the message. '
          'Tags are useful for filtering messages along the source chain. '
      )
  ] = []

  # NOTE(daiyip): Explicit override __init__ for allowing metadata via **kwargs.
  @pg.explicit_method_override
  def __init__(
      self,
      text: str,
      *,
      # Default sender is specified in subclasses.
      sender: str | pg.object_utils.MissingValue = pg.MISSING_VALUE,
      referred_modalities: (
          list[modality.Modality]
          | dict[str, modality.Modality]
          | None
      ) = None,
      metadata: dict[str, Any] | None = None,
      tags: list[str] | None = None,
      source: Optional['Message'] = None,
      # The rest are `pg.Object.__init__` arguments.
      allow_partial: bool = False,
      sealed: bool = False,
      root_path: pg.KeyPath | None = None,
      **kwargs
  ) -> None:
    """Constructor.

    Args:
      text: The text in the message.
      sender: The sender name of the message. If missing, the default declared
        by the subclass is used.
      referred_modalities: The modality objects referred in the message, either
        as a list (each object is keyed by its ``id``) or as a dict keyed by
        the reference name.
      metadata: Structured meta-data associated with this message. The passed
        object is not mutated by the constructor.
      tags: Tags for the message.
      source: The source message of the current message.
      allow_partial: If True, the object can be partial.
      sealed: If True, seal the object from future modification (unless under a
        `pg.seal(False)` context manager). If False, treat the object as
        unsealed.
      root_path: The symbolic path for current object. By default it's None,
        which indicates that newly constructed object does not have a parent.
      **kwargs: key/value pairs that will be inserted into metadata. They take
        precedence over entries with the same key in `metadata`.
    """
    metadata = metadata or {}
    if kwargs:
      # Merge into a fresh dict instead of calling `metadata.update(kwargs)`:
      # the caller's dict (possibly the metadata of another message) must not
      # be mutated as a side effect of constructing this message.
      metadata = {**metadata, **kwargs}
    if isinstance(referred_modalities, list):
      referred_modalities = {m.id: pg.Ref(m) for m in referred_modalities}

    super().__init__(
        text=text,
        metadata=metadata,
        referred_modalities=referred_modalities or {},
        tags=tags or [],
        sender=sender,
        allow_partial=allow_partial,
        sealed=sealed,
        root_path=root_path,
    )
    self._source = source

  @classmethod
  def from_value(
      cls,
      value: Union[str, 'Message', Any],
      *,
      format: str | None = None,    # pylint: disable=redefined-builtin
      **kwargs
  ) -> 'Message':
    """Creates a message from a str, message or object of registered format.

    Example:
      # Create a user message from a str.
      lf.UserMessage.from_value('hi')

      # Create a user message from a multi-modal object.
      lf.UserMessage.from_value(lf.Image.from_uri('...'))

      # Create a user message from OpenAI API format.
      lf.Message.from_value(
          {
              'role': 'user',
              'content': [{'type': 'text', 'text': 'hi'}],
          },
          format='openai.api',
      )

    Args:
      value: The value to create a message from. Messages are returned as is.
      format: The format ID to convert from. If None, the conversion will be
        performed according to the type of `value`. Otherwise, the converter
        registered for the format will be used.
      **kwargs: The keyword arguments passed to the __init__ of the converter.

    Returns:
      A message created from the value.

    Raises:
      ValueError: If `format` is unknown or not applicable to `value`.
      TypeError: If no single converter is registered for the type of `value`.
    """
    if isinstance(value, modality.Modality):
      return cls(f'<<[[{value.id}]]>>', referred_modalities=[value])
    if isinstance(value, Message):
      return value
    if isinstance(value, str):
      return cls(value)
    value_type = type(value)
    if format is None:
      converter = MessageConverter.get_by_type(value_type, **kwargs)
    else:
      converter = MessageConverter.get_by_format(format, **kwargs)
      if (converter.OUTPUT_TYPE is not None
          and not isinstance(value, converter.OUTPUT_TYPE)):
        raise ValueError(f'{format!r} is not applicable to {value!r}.')
    return converter.from_value(value)

  #
  # Conversion to other formats or types.
  #

  def as_format(self, format_or_type: str | type[Any], **kwargs) -> Any:
    """Converts the message to a registered format or type.

    Example:

      m = lf.Template('What is this {{image}}?').render()
      m.as_format('openai')  # Convert to OpenAI message format.
      m.as_format('gemini')  # Convert to Gemini message format.
      m.as_format('anthropic')  # Convert to Anthropic message format.

    Args:
      format_or_type: The format ID or type to convert to.
      **kwargs: The conversion arguments.

    Returns:
      The converted object according to the format or type.
    """
    return MessageConverter.get(format_or_type, **kwargs).to_value(self)

  @classmethod
  def convertible_formats(cls) -> list[str]:
    """Returns supported format IDs for message conversion."""
    return MessageConverter.convertible_formats()

  @classmethod
  def convertible_types(cls) -> list[type[Any]]:
    """Returns supported Python types for message conversion."""
    return MessageConverter.convertible_types()

  #
  # Unified interface for accessing text, result and metadata.
  #

  def set(self, key_path: str | pg.KeyPath, value: Any) -> None:
    """Sets a value by key path.

    This is the unified interface for setting values in Message.

    Examples::

      m = lf.AIMessage('foo', metadata=dict(x=dict(k=[0, 1]), y=2))
      m.set('text', 'bar')
      m.set('x.k[0]', 1)
      m.set('y', pg.MISSING_VALUE)   # Delete.
      assert m == lf.AIMessage('bar', metadata=dict(x=dict(k=[1, 1])))

    Args:
      key_path: A string or a ``pg.KeyPath`` object for locating the value to
        update. For example: `a.b`, `x[0].a`, `text` is a special key that sets
        the text of the message.
      value: The new value for the location.
    """
    if key_path == Message.PATH_TEXT:
      self.rebind({key_path: value}, raise_on_no_change=False)
    else:
      self.metadata.rebind({key_path: value}, raise_on_no_change=False)

  def get(self, key_path: str | pg.KeyPath, default: Any = None) -> Any:
    """Gets text or metadata by key path.

    This is the unified interface to query text or metadata values.

    Args:
      key_path: A string like 'a.x', 'b[0].y' to access metadata value in
        hierarchy. 'text' is a special key that returns the text of the message.
        An empty key path returns the message itself.
      default: The default value if the key path is not found in `self`.

    Returns:
      The value for the path if found, otherwise the default value.
    """
    if not key_path:
      return self
    if key_path == Message.PATH_TEXT:
      return self.text
    else:
      return self.metadata.sym_get(key_path, default, use_inferred=True)

  #
  # API for accessing the structured result and error.
  # `result` represents the structured output the message - like the return
  # value of a regular Python function. It is stored in metadata with a special
  # key `result`. By default it's None, and should be produced by transforms
  # upon the message text.

  @property
  def result(self) -> Any:
    """Gets the structured result of the message."""
    return self.get(Message.PATH_RESULT, None)

  @result.setter
  def result(self, value: Any) -> None:
    """Sets the structured result of the message."""
    self.set(Message.PATH_RESULT, value)

  #
  # Update and error tracking.
  #

  def _on_init(self):
    super()._on_init()
    self._updates = {}
    self._errors = []

  def _on_change(self, field_updates: dict[pg.KeyPath, pg.FieldUpdate]) -> None:
    super()._on_change(field_updates)
    self._updates.update(field_updates)

  @property
  def modified(self) -> bool:
    """Returns True if the message has been modified in current update scope."""
    return bool(self._updates)

  @property
  def updates(self) -> dict[pg.KeyPath, pg.FieldUpdate]:
    """Returns the updates of the message in current update scope."""
    return self._updates

  @property
  def has_errors(self) -> bool:
    """Returns True if there is an error in current update scope."""
    return bool(self._errors)

  @property
  def errors(self) -> list[Any]:
    """Returns the errors of the message in current update scope."""
    return self._errors

  @contextlib.contextmanager
  def update_scope(self):
    """Context manager to create an update scope.

    Within the scope, `updates` and `errors` only reflect changes made inside
    it; on exit (normal or exceptional) they are merged into the enclosing
    scope's records.
    """
    accumulated_updates = self._updates
    accumulated_errors = self._errors

    try:
      self._updates, self._errors = {}, []
      yield
    finally:
      accumulated_updates.update(self._updates)
      self._updates = accumulated_updates

      accumulated_errors.extend(self._errors)
      self._errors = accumulated_errors

  def apply_updates(self, updates: dict[pg.KeyPath, pg.FieldUpdate]) -> None:
    """Updates this message with delta."""
    delta = {k: v.new_value for k, v in updates.items()}

    # Rebind will trigger _on_change, which inserts the updates
    # to current message' updates.
    self.rebind(delta, raise_on_no_change=False)

  #
  # API for supporting modalities.
  #

  def modalities(
      self,
      filter: (  # pylint: disable=redefined-builtin
          Type[modality.Modality]
          | Callable[[modality.Modality], bool]
          | None
      ) = None  # pylint: disable=bad-whitespace
  ) -> list[modality.Modality]:
    """Returns the modality objects referred in the message.

    Args:
      filter: Either a modality class (keep instances of it), a predicate on
        modality objects, or None (keep all).

    Returns:
      The referred modality objects that pass the filter.
    """
    if inspect.isclass(filter) and issubclass(filter, modality.Modality):
      filter_fn = lambda v: isinstance(v, filter)  # pytype: disable=wrong-arg-types
    elif filter is None:
      filter_fn = lambda v: True
    else:
      filter_fn = filter
    return [v for v in self.referred_modalities.values() if filter_fn(v)]

  @property
  def images(self) -> list[modality.Modality]:
    """Returns the image objects referred in the message."""
    assert False, 'Overridden in core/modalities/__init__.py'

  @property
  def videos(self) -> list[modality.Modality]:
    """Returns the video objects referred in the message."""
    assert False, 'Overridden in core/modalities/__init__.py'

  @property
  def audios(self) -> list[modality.Modality]:
    """Returns the audio objects referred in the message."""
    assert False, 'Overridden in core/modalities/__init__.py'

  def get_modality(
      self,
      var_name: str,
      default: Any = None
  ) -> modality.Modality | None:
    """Gets the modality object referred in the message.

    Args:
      var_name: The referred variable name for the modality object.
      default: The value to return if `var_name` is not referred.

    Returns:
      A modality object if found, otherwise `default`.
    """
    return self.referred_modalities.get(var_name, default)

  def chunk(self, text: str | None = None) -> list[str | modality.Modality]:
    """Chunks a message into a list of str or modality objects.

    Args:
      text: The text to chunk. If None, the message text is used.

    Returns:
      Text pieces interleaved with the modality objects whose references
      appear in the text.

    Raises:
      ValueError: If the text refers to a modality that is not present in
        `referred_modalities`.
    """
    chunks = []

    def add_text_chunk(text_piece: str) -> None:
      if text_piece:
        chunks.append(text_piece)
    if text is None:
      text = self.text

    chunk_start = 0
    ref_end = 0
    while chunk_start < len(text):
      ref_start = text.find(modality.Modality.REF_START, ref_end)
      if ref_start == -1:
        add_text_chunk(text[chunk_start:].strip(' '))
        break

      var_start = ref_start + len(modality.Modality.REF_START)
      ref_end = text.find(modality.Modality.REF_END, var_start)
      if ref_end == -1:
        # Unterminated reference: keep the rest of the text as-is.
        add_text_chunk(text[chunk_start:])
        break

      var_name = text[var_start:ref_end].strip()
      var_value = self.get_modality(var_name)
      if var_value is None:
        raise ValueError(
            f'Unknown modality reference: {var_name!r}. '
            'Please make sure the modality object is present in '
            f'`referred_modalities` when creating {self.__class__.__name__}.'
        )
      add_text_chunk(text[chunk_start:ref_start].strip(' '))
      chunks.append(var_value)
      chunk_start = ref_end + len(modality.Modality.REF_END)
    return chunks

  @classmethod
  def from_chunks(
      cls, chunks: list[str | modality.Modality], separator: str = ' '
  ) -> 'Message':
    """Assembles a message from a list of string or modality objects.

    Args:
      chunks: Text pieces and modality objects, in order.
      separator: Inserted between adjacent chunks unless the previous chunk
        already ends with whitespace (or was empty).

    Returns:
      A message whose text contains modality references for the modality
      chunks, with those objects in `referred_modalities`.
    """
    fused_text = io.StringIO()
    metadata = dict()
    referred_modalities = dict()
    last_char = None
    for i, chunk in enumerate(chunks):
      if i > 0 and last_char not in ('\t', ' ', '\n', None):
        fused_text.write(separator)
      if isinstance(chunk, str):
        fused_text.write(chunk)
        if chunk:
          last_char = chunk[-1]
        else:
          last_char = None
      else:
        assert isinstance(chunk, modality.Modality), chunk
        fused_text.write(modality.Modality.text_marker(chunk.id))
        last_char = modality.Modality.REF_END[-1]
        # Make a reference if the chunk is already owned by another object
        # to avoid copy.
        referred_modalities[chunk.id] = pg.Ref(chunk)
    return cls(
        fused_text.getvalue().strip(),
        referred_modalities=referred_modalities,
        metadata=metadata,
    )

  #
  # Tagging
  #

  def tag(self, tag: str) -> None:
    """Adds `tag` to the message if absent, without change notification."""
    if tag not in self.tags:
      with pg.notify_on_change(False):
        self.tags.append(tag)

  def has_tag(self, tag: str | tuple[str, ...]) -> bool:
    """Returns True if the message has `tag` (or any of the tags in a tuple)."""
    if isinstance(tag, str):
      return tag in self.tags
    return any(t in self.tags for t in tag)

  #
  # Message source chain.
  #

  @property
  def source(self) -> Optional['Message']:
    """Returns the source message."""
    return self._source

  @source.setter
  def source(self, source: Optional['Message']) -> None:
    """Sets the source message."""
    self._source = source

  @property
  def root(self) -> 'Message':
    """Returns the root of this message."""
    root = self
    while root.source is not None:
      root = root.source
    return root

  def trace(self, tag: str | None = None) -> list['Message']:
    """Returns the chain of source messages filtered by tag.

    Args:
      tag: If not None, only messages carrying this tag are included.

    Returns:
      Matching messages ordered from the root of the chain to `self`.
    """
    message_chain = []
    current = self

    while current is not None:
      if tag is None or tag in current.tags:
        message_chain.append(current)
      current = current.source
    return list(reversed(message_chain))

  @property
  def lm_responses(self) -> list['Message']:
    """Returns a chain of LM responses starting from the first LM call."""
    return self.trace(Message.TAG_LM_RESPONSE)

  @property
  def lm_inputs(self) -> list['Message']:
    """Returns a chain of LM inputs starting from the first LM call."""
    return self.trace(Message.TAG_LM_INPUT)

  @property
  def lm_outputs(self) -> list['Message']:
    """Returns a chain of LM outputs starting from the first LM call."""
    return self.trace(Message.TAG_LM_OUTPUT)

  def last(self, tag: str) -> Optional['Message']:
    """Returns the nearest message (self first) with a certain tag, or None."""
    current = self
    while current is not None:
      if tag in current.tags:
        return current
      current = current.source
    return None

  @property
  def lm_response(self) -> Optional['Message']:
    """Returns the latest LM raw response."""
    return self.last(Message.TAG_LM_RESPONSE)

  @property
  def lm_input(self) -> Optional['Message']:
    """Returns the latest LM input."""
    return self.last(Message.TAG_LM_INPUT)

  @property
  def lm_output(self) -> Optional['Message']:
    """Returns the latest LM output."""
    return self.last(Message.TAG_LM_OUTPUT)

  #
  # Other methods.
  #

  def natural_language_format(self) -> str:
    """Returns the natural language format representation."""
    # Propagate the modality references to parent context if any.
    if capture_context := modality.get_modality_capture_context():
      for v in self.referred_modalities.values():
        capture_context.capture(v)
    return self.text

  def __eq__(self, other: Any) -> bool:
    # A message compares equal to a str with the same text; otherwise only to
    # instances of its own class with the same text, sender and metadata.
    if isinstance(other, str):
      return self.text == other
    if isinstance(other, self.__class__):
      return (self.text == other.text
              and self.sender == other.sender
              and self.metadata == other.metadata)
    return False

  def __hash__(self) -> int:
    return hash(self.text)

  def __getattr__(self, key: str) -> Any:
    # Only called for attributes not found normally: expose metadata keys as
    # attributes (e.g. `message.usage`).
    if key not in self.metadata:
      raise AttributeError(key)
    return self.metadata[key]

  def _html_tree_view_content(
      self,
      *,
      view: pg.views.HtmlTreeView,
      root_path: pg.KeyPath | None = None,
      collapse_level: int | None = None,
      extra_flags: dict[str, Any] | None = None,
      **kwargs,
  ) -> pg.Html:
    """Returns the HTML representation of the message.

    Args:
      view: The HTML tree view.
      root_path: The root path of the message.
      collapse_level: The global collapse level.
      extra_flags: Extra flags to control the rendering.
        - source_tag: tags to filter source messages. If None, the entire
          source chain will be included.
        - include_message_metadata: Whether to include the metadata of the
          message.
        - collapse_modalities_in_text: Whether to collapse the modalities in the
          message text.
        - collapse_llm_usage: Whether to collapse the usage in the message.
        - collapse_message_result_level: The level to collapse the result in the
          message.
        - collapse_message_metadata_level: The level to collapse the metadata in
          the message.
        - collapse_source_message_level: The level to collapse the source in the
          message.
        - collapse_level: The global collapse level.
      **kwargs: Omitted keyword arguments.

    Returns:
      The HTML representation of the message content.
    """
    extra_flags = extra_flags if extra_flags is not None else {}

    include_message_metadata: bool = extra_flags.get(
        'include_message_metadata', True
    )
    source_tag: str | tuple[str, ...] | None = extra_flags.get(
        'source_tag', (Message.TAG_LM_INPUT, Message.TAG_LM_OUTPUT)
    )
    collapse_modalities_in_text: bool = extra_flags.get(
        'collapse_modalities_in_text', True
    )
    collapse_llm_usage: bool = extra_flags.get(
        'collapse_llm_usage', False
    )
    collapse_message_result_level: int | None = extra_flags.get(
        'collapse_message_result_level', 1
    )
    collapse_message_metadata_level: int | None = extra_flags.get(
        'collapse_message_metadata_level', 1
    )
    collapse_source_message_level: int | None = extra_flags.get(
        'collapse_source_message_level', 1
    )
    passthrough_kwargs = view.get_passthrough_kwargs(**kwargs)
    def render_tags():
      return pg.Html.element(
          'div',
          [pg.Html.element('span', [tag]) for tag in self.tags],
          css_classes=['message-tags'],
      )

    def render_message_text():
      # Prefer a reformatted text stored in metadata, if any.
      maybe_reformatted = self.get('formatted_text')
      s = pg.Html('<div class="message-text">')
      for chunk in self.chunk(maybe_reformatted):
        if isinstance(chunk, str):
          s.write(s.escape(chunk))
        else:
          assert isinstance(chunk, modality.Modality), chunk
          s.write(
              pg.Html.element(
                  'div',
                  [
                      view.render(
                          chunk,
                          name=chunk.id,
                          root_path=chunk.sym_path,
                          collapse_level=(
                              0 if collapse_modalities_in_text else 1
                          ),
                          extra_flags=dict(
                              display_modality_when_hover=True,
                          ),
                          **passthrough_kwargs,
                      )
                  ],
                  css_classes=['modality-in-text'],
              )
          )
      s.write('</div>')
      return s

    def render_result():
      if 'result' not in self.metadata:
        return None
      child_path = pg.KeyPath(['metadata', 'result'], root_path)
      return pg.Html.element(
          'div',
          [
              view.render(
                  self.result,
                  name='result',
                  root_path=child_path,
                  collapse_level=view.get_collapse_level(
                      (collapse_level, -1),
                      collapse_message_result_level,
                  ),
                  extra_flags=extra_flags,
                  **passthrough_kwargs,
              )
          ],
          css_classes=['message-result'],
      )

    def render_usage():
      if 'usage' not in self.metadata:
        return None
      child_path = pg.KeyPath(['metadata', 'usage'], root_path)
      return pg.Html.element(
          'div',
          [
              view.render(
                  self.usage,
                  name='llm usage',
                  key_style='label',
                  root_path=child_path,
                  collapse_level=view.get_collapse_level(
                      (collapse_level, -1),
                      0 if collapse_llm_usage else 1,
                  ),
                  extra_flags=extra_flags,
                  **view.get_passthrough_kwargs(
                      remove=['key_style'], **kwargs
                  ),
              )
          ],
          css_classes=['message-usage'],
      )

    def render_source_message():
      # Walk up the source chain to the nearest ancestor carrying one of the
      # `source_tag` tags (or take the immediate source when unfiltered).
      source = self.source
      while (source is not None
             and source_tag is not None
             and not source.has_tag(source_tag)):
        source = source.source
      if source is None:
        return None
      child_path = pg.KeyPath('source', root_path)
      child_extra_flags = extra_flags.copy()
      child_extra_flags['collapse_source_message_level'] = (
          view.get_collapse_level(
              (collapse_source_message_level, -1), 0,
          )
      )
      # Render the filtered ancestor found above rather than `self.source`,
      # which may not carry any of the requested tags.
      return view.render(
          source,
          name='source',
          root_path=child_path,
          collapse_level=view.get_collapse_level(
              (collapse_level, -1),
              collapse_source_message_level,
          ),
          extra_flags=child_extra_flags,
          **passthrough_kwargs,
      )

    def render_metadata():
      if not include_message_metadata:
        return None
      child_path = pg.KeyPath('metadata', root_path)
      return pg.Html.element(
          'div',
          [
              view.render(
                  self.metadata,
                  css_classes=['message-metadata'],
                  # `usage` and `result` are rendered in their own sections.
                  exclude_keys=['usage', 'result'],
                  name='metadata',
                  root_path=child_path,
                  collapse_level=view.get_collapse_level(
                      (collapse_level, -1),
                      collapse_message_metadata_level,
                  ),
                  **view.get_passthrough_kwargs(
                      remove=['exclude_keys'], **kwargs
                  ),
              )
          ],
          css_classes=['message-metadata'],
      )

    return pg.Html.element(
        'div',
        [
            render_tags(),
            render_message_text(),
            render_result(),
            render_usage(),
            render_metadata(),
            render_source_message(),
        ],
        css_classes=['complex_value'],
    )

  @classmethod
  @functools.cache
  def _html_tree_view_config(cls) -> dict[str, Any]:
    return pg.views.HtmlTreeView.get_kwargs(
        super()._html_tree_view_config(),
        dict(
            css_classes=['lf-message'],
        )
    )

  @classmethod
  @functools.cache
  def _html_tree_view_css_styles(cls) -> list[str]:
    return super()._html_tree_view_css_styles() + [
        """
        /* Langfun Message styles.*/
        [class^="message-"] > details {
            margin: 0px 0px 5px 0px;
            border: 1px solid #EEE;
        }
        .lf-message.summary-title::after {
            content: ' 💬';
        }
        details.pyglove.ai-message {
            border: 1px solid blue;
            color: blue;
        }
        details.pyglove.user-message {
            border: 1px solid green;
            color: green;
        }
        .message-tags {
            margin: 5px 0px 5px 0px;
            font-size: .8em;
        }
        .message-tags > span {
            border-radius: 5px;
            background-color: #CCC;
            padding: 3px;
            margin: 0px 2px 0px 2px;
            color: white;
        }
        .message-text {
            padding: 20px;
            margin: 10px 5px 10px 5px;
            font-style: italic;
            white-space: pre-wrap;
            border: 1px solid #EEE;
            border-radius: 5px;
            background-color: #EEE;
        }
        .modality-in-text {
            display: inline-block;
        }
        .modality-in-text > details.pyglove {
            display: inline-block;
            font-size: 0.8em;
            border: 0;
            background-color: #A6F1A6;
            margin: 0px 5px 0px 5px;
        }
        .message-result {
            color: dodgerblue;
        }
        .message-usage {
            color: orange;
        }
        .message-usage .object-key.str {
            border: 1px solid orange;
            background-color: orange;
            color: white;
        }
        """
    ]


class _MessageConverterRegistry:
  """Message converter registry."""

  def __init__(self):
    self._name_to_converter: dict[str, Type[MessageConverter]] = {}
    self._type_to_converters: dict[Type[Any], list[Type[MessageConverter]]] = (
        collections.defaultdict(list)
    )

  def register(self, converter: Type['MessageConverter']) -> None:
    """Registers a message converter."""
    self._name_to_converter[converter.FORMAT_ID] = converter
    if converter.OUTPUT_TYPE is not None:
      self._type_to_converters[converter.OUTPUT_TYPE].append(converter)

  def get_by_type(self, t: Type[Any], **kwargs) -> 'MessageConverter':
    """Returns a message converter for the given type."""
    t = self._type_to_converters[t]
    if not t:
      raise TypeError(
          f'Cannot convert Message to {t!r}.'
      )
    if len(t) > 1:
      raise TypeError(
          f'More than one converters found for output type {t!r}. '
          f'Please specify one for this conversion: {[x.FORMAT_ID for x in t]}.'
      )
    return t[0](**kwargs)

  def get_by_format(self, format: str, **kwargs) -> 'MessageConverter':   # pylint: disable=redefined-builtin
    """Returns a message converter for the given format."""
    if format not in self._name_to_converter:
      raise ValueError(f'Unsupported format: {format!r}.')
    return self._name_to_converter[format](**kwargs)

  def get(
      self,
      format_or_type: str | Type[Any], **kwargs
  ) -> 'MessageConverter':
    """Returns a message converter for the given format or type."""
    if isinstance(format_or_type, str):
      return self.get_by_format(format_or_type, **kwargs)
    assert isinstance(format_or_type, type), format_or_type
    return self.get_by_type(format_or_type, **kwargs)

  def convertible_formats(self) -> list[str]:
    """Returns a list of converter names."""
    return sorted(list(self._name_to_converter.keys()))

  def convertible_types(self) -> list[Type[Any]]:
    """Returns a list of converter types."""
    return list(self._type_to_converters.keys())


class MessageConverter(pg.Object):
  """Interface for converting a Langfun message to other formats."""

  # A global unique identifier for the converter.
  FORMAT_ID: ClassVar[str]

  # The output type of the converter.
  # If None, the converter will not be registered to handle
  # `lf.Message.to_value(output_type)`.
  OUTPUT_TYPE: ClassVar[Type[Any] | None] = None

  _REGISTRY = _MessageConverterRegistry()

  def __init_subclass__(cls, *args, **kwargs):
    super().__init_subclass__(*args, **kwargs)
    if not inspect.isabstract(cls):
      cls._REGISTRY.register(cls)

  @abc.abstractmethod
  def to_value(self, message: Message) -> Any:
    """Converts a Langfun message to other formats."""

  @abc.abstractmethod
  def from_value(self, value: Any) -> Message:
    """Returns a Langfun message from other formats."""

  @classmethod
  def _safe_read(
      cls,
      data: Mapping[str, Any],
      key: str,
      default: Any = pg.MISSING_VALUE
  ) -> Any:
    """Safe reads a key from a mapping."""
    if not isinstance(data, Mapping):
      raise ValueError(f'Invalid data type: {data!r}.')
    if key not in data:
      if pg.MISSING_VALUE == default:
        raise ValueError(f'Missing key {key!r} in {data!r}')
      return default
    return data[key]

  @classmethod
  def get_role(cls, message: Message) -> str:
    """Returns the role of the message."""
    if isinstance(message, SystemMessage):
      return 'system'
    elif isinstance(message, UserMessage):
      return 'user'
    elif isinstance(message, AIMessage):
      return 'assistant'
    else:
      raise ValueError(f'Unsupported message type: {message!r}.')

  @classmethod
  def get_message_cls(cls, role: str) -> type[Message]:
    """Returns the message class of the message."""
    match role:
      case 'system':
        return SystemMessage
      case 'user':
        return UserMessage
      case 'assistant':
        return AIMessage
      case _:
        raise ValueError(f'Unsupported role: {role!r}.')

  @classmethod
  def get(cls, format_or_type: str | Type[Any], **kwargs) -> 'MessageConverter':
    """Returns a message converter."""
    return cls._REGISTRY.get(format_or_type, **kwargs)

  @classmethod
  def get_by_format(cls, format: str, **kwargs) -> 'MessageConverter':  # pylint: disable=redefined-builtin
    """Returns a message converter for the given format."""
    return cls._REGISTRY.get_by_format(format, **kwargs)

  @classmethod
  def get_by_type(cls, t: Type[Any], **kwargs) -> 'MessageConverter':
    """Returns a message converter for the given type."""
    return cls._REGISTRY.get_by_type(t, **kwargs)

  @classmethod
  def convertible_formats(cls) -> list[str]:
    """Returns a list of converter names."""
    return cls._REGISTRY.convertible_formats()

  @classmethod
  def convertible_types(cls) -> list[Type[Any]]:
    """Returns a list of converter types."""
    return cls._REGISTRY.convertible_types()

#
# Messages of different roles.
#


@pg.use_init_args(['text', 'sender', 'metadata'])
class UserMessage(Message):
  """A message authored by a human user (default sender: ``'User'``)."""

  # Default value for the `sender` field declared on `Message`.
  sender = 'User'


@pg.use_init_args(['text', 'sender', 'metadata'])
class AIMessage(Message):
  """A message produced by an agent/model (default sender: ``'AI'``)."""

  # Default value for the `sender` field declared on `Message`.
  sender = 'AI'


@pg.use_init_args(['text', 'sender', 'metadata'])
class SystemMessage(Message):
  """A message from the system or environment (default sender: ``'System'``)."""

  # Default value for the `sender` field declared on `Message`.
  sender = 'System'


@pg.use_init_args(['text', 'sender', 'metadata'])
class MemoryRecord(Message):
  """A message stored as a memory record (default sender: ``'Memory'``)."""

  # Default value for the `sender` field declared on `Message`.
  sender = 'Memory'
