# Copyright (c) 2024 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""table ocr component."""

import base64
import json
import logging

from typing import Optional

from appbuilder.core import utils
from appbuilder.core.component import Component
from appbuilder.core.components.table_ocr.model import *
from appbuilder.core.message import Message
from appbuilder.core._client import HTTPClient
from appbuilder.core.constants import COMPONENT_SUPPORT_FILE_NUMBER
from appbuilder.core._exception import AppBuilderServerException, InvalidRequestArgumentError
from appbuilder.utils.trace.tracer_wrapper import components_run_trace, components_run_stream_trace


class TableOCR(Component):
    r"""
    Recognize tables in an image.

    Returns, for every table found, its header and footer text plus the text of
    each cell together with the cell's row/column position. Regular ruled tables,
    borderless tables and tables with merged cells are covered, and a single image
    may contain several tables.

    Examples:

    .. code-block:: python

        import os
        import appbuilder

        # Create an API key on the Qianfan AppBuilder console, see:
        # https://cloud.baidu.com/doc/AppBuilder/s/Olq6grrt6#1%E3%80%81%E5%88%9B%E5%BB%BA%E5%AF%86%E9%92%A5
        os.environ["APPBUILDER_TOKEN"] = '...'

        table_ocr = appbuilder.TableOCR()
        with open("./table_ocr_test.png", "rb") as f:
            out = table_ocr.run(appbuilder.Message(content={"raw_image": f.read()}))
        print(out.content)
    """

    name = "table_ocr"
    version = "v1"
    manifests = [
        {
            "name": "table_ocr",
            "description": "需要识别图片中的表格内容，使用该工具, 但不支持html后缀文件的识别",
            "parameters": {
                "type": "object",
                "properties": {
                    "file_names": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        },
                        "description": "待识别图片的文件名"
                    },
                    "file_urls": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        },
                        "description": "待识别图片的文件url"
                    }
                },
                "anyOf": [
                    {
                        "required": [
                            "file_names"
                        ]
                    },
                    {
                        "required": [
                            "file_urls"
                        ]
                    }
                ]
            }
        }
    ]

    @HTTPClient.check_param
    @components_run_trace
    def run(self, message: Message, timeout: float = None, retry: int = 0) -> Message:
        """
        Recognize the tables in an image.

        Args:
            message (Message): the image to recognize, given either as raw bytes or
                as a download URL, e.g. ``Message(content={"raw_image": b"..."})`` or
                ``Message(content={"url": "https://image/download/url"})``.
            timeout (float, optional): HTTP timeout passed to the POST request.
            retry (int, optional): number of HTTP retries.

        Returns:
            message (Message): the recognition result, for example:
                Message(name=msg, content={'tables_result': [{
                'table_location': [{'x': 15, 'y': 15}, {'x': 371, 'y': 15}, {'x': 371, 'y': 98}, {'x': 15,
                'y': 98}], 'header': [], 'body': [{'cell_location': [{'x': 15, 'y': 15}, {'x': 120, 'y': 15},
                {'x': 120, 'y': 58}, {'x': 15, 'y': 58}], 'row_start': 0, 'row_end': 1, 'col_start': 0,
                'col_end': 1, 'words': '参数'}, {'cell_location': [{'x': 120, 'y': 15}, {'x': 371, 'y': 15},
                {'x': 371, 'y': 58}, {'x': 120, 'y': 58}], 'row_start': 0, 'row_end': 1, 'col_start': 1,
                'col_end': 2, 'words': '值'}, {'cell_location': [{'x': 15, 'y': 58}, {'x': 120, 'y': 58},
                {'x': 120, 'y': 98}, {'x': 15, 'y': 98}], 'row_start': 1, 'row_end': 2, 'col_start': 0,
                'col_end': 1, 'words': 'Content-Type'}, {'cell_location': [{'x': 120, 'y': 58}, {'x': 371,
                'y': 58}, {'x': 371, 'y': 98}, {'x': 120, 'y': 98}], 'row_start': 1, 'row_end': 2, 'col_start':
                1, 'col_end': 2, 'words': 'application/x-www-form-urlencoded'}], 'footer': []}]}, mtype=dict)

        Raises:
            ValueError: if the message carries neither ``raw_image`` nor ``url``.
        """
        inp = TableOCRInMsg(**message.content)
        req = TableOCRRequest()
        if inp.raw_image:
            req.image = base64.b64encode(inp.raw_image)
        if inp.url:
            req.url = inp.url
        req.cell_contents = "false"
        result, _ = self._recognize(req, timeout, retry)
        result_dict = proto.Message.to_dict(result)
        out = TableOCROutMsg(**result_dict)
        return Message(content=out.model_dump())

    def _recognize(self, request: TableOCRRequest, timeout: float = None,
                   retry: int = 0, request_id: str = None) -> tuple:
        r"""Call the underlying table OCR HTTP API.

        Args:
            request (TableOCRRequest): request parameters; at least one of
                ``image`` or ``url`` must be set.
            timeout (float, optional): HTTP timeout passed to ``session.post``.
            retry (int, optional): retry count, written to ``http_client.retry.total``.
            request_id (str, optional): request id used to build the auth header.

        Returns:
            tuple: ``(response, data)`` where ``response`` is the parsed
            ``TableOCRResponse`` with ``request_id`` filled in, and ``data`` is the
            raw JSON body as a dict.

        Raises:
            ValueError: if neither ``image`` nor ``url`` is set on the request.
            AppBuilderServerException: if the response body carries
                ``error_code`` or ``error_msg``.
        """
        if not request.image and not request.url:
            raise ValueError(
                "request format error, one of image or url must be set")

        form = TableOCRRequest.to_dict(request)
        if self.http_client.retry.total != retry:
            self.http_client.retry.total = retry
        headers = self.http_client.auth_header(request_id)
        headers['content-type'] = 'application/x-www-form-urlencoded'
        url = self.http_client.service_url("/v1/bce/aip/ocr/v1/table")
        response = self.http_client.session.post(
            url, headers=headers, data=form, timeout=timeout)
        self.http_client.check_response_header(response)
        data = response.json()
        self.http_client.check_response_json(data)
        request_id = self.http_client.response_request_id(response)
        self.__class__._check_service_error(request_id, data)
        res = TableOCRResponse.from_json(json.dumps(data))
        res.request_id = request_id
        return res, data

    @staticmethod
    def _check_service_error(request_id: str, data: dict):
        r"""Raise if the service reported an error in the response body.

        Args:
            request_id (str): request id attached to the raised exception.
            data (dict): the JSON body returned by the table OCR service.

        Raises:
            AppBuilderServerException: if ``data`` contains ``error_code`` or
                ``error_msg``.
        """
        if "error_code" in data or "error_msg" in data:
            raise AppBuilderServerException(
                request_id=request_id,
                service_err_code=data.get("error_code"),
                service_err_message=data.get("error_msg")
            )

    @staticmethod
    def _escape_markdown_cell(text):
        """Make cell text safe to place inside a single Markdown table cell.

        A line break would end the table row and an unescaped ``|`` would open a
        new column, so line breaks become spaces and pipes are backslash-escaped.
        """
        text = text.replace("\r\n", " ").replace("\r", " ").replace("\n", " ")
        return text.replace("|", "\\|")

    def get_table_markdown(self, tables_result):
        """
        Convert table recognition results into Markdown tables.

        Each cell's text is placed at its ``(row_start, col_start)`` position.
        Positions covered only by a merged cell's span stay empty. The first row is
        used as the Markdown header row.

        Args:
            tables_result (list): table recognition results. Each element is a dict
                whose ``body`` is a list of cell dicts with ``row_start``,
                ``row_end``, ``col_start``, ``col_end`` and ``words`` keys.

        Returns:
            list: one Markdown string per input table, in input order. A table
            with an empty ``body`` yields an empty string.
        """
        markdowns = []
        for table in tables_result:
            cells = table["body"]
            if not cells:
                # max() would raise on an empty body; keep one entry per table.
                markdowns.append("")
                continue
            n_rows = max(cell["row_end"] for cell in cells)
            n_cols = max(cell["col_end"] for cell in cells)
            grid = [[""] * n_cols for _ in range(n_rows)]
            for cell in cells:
                grid[cell["row_start"]][cell["col_start"]] = \
                    self._escape_markdown_cell(cell["words"])

            lines = ["| " + " | ".join(row) + " |" for row in grid]
            # The header/body separator goes directly below the first row.
            lines.insert(1, "| " + " | ".join(["---"] * n_cols) + " |")
            markdowns.append("\n".join(lines) + "\n")
        return markdowns

    @staticmethod
    def _is_supported_file(file_name, supported_file_types):
        """Return True if ``file_name`` has one of ``supported_file_types`` as extension.

        Matching is case-insensitive and handles multi-part extensions such as
        ``nii.gz``. A bare name equal to an extension (e.g. ``"png"``) also
        matches, mirroring a plain ``split(".")[-1]`` check.
        """
        lowered = file_name.lower()
        return any(lowered == ext or lowered.endswith("." + ext)
                   for ext in supported_file_types)

    @components_run_stream_trace
    def tool_eval(self,
                  file_names: Optional[List[str]] = None,
                  file_urls: Optional[List[str]] = None,
                  **kwargs):
        """
        Recognize tables in the given files and stream the results as Markdown.

        Args:
            file_names (List[str], optional): names of uploaded files; their URLs
                are looked up in ``kwargs["_sys_file_urls"]``.
            file_urls (List[str], optional): direct URLs of files to recognize.
            **kwargs: extra options. ``_sys_traceid`` is used as the request id and
                ``_sys_file_urls`` maps uploaded file names to URLs.

        Returns:
            Generator: yields output messages. For each recognized file it yields
            an LLM-visible JSON ``{file: [markdown, ...]}`` followed by an empty
            user-visible message. For each unsupported or unknown file it yields a
            JSON hint.

        Raises:
            InvalidRequestArgumentError: if both ``file_names`` and ``file_urls``
                are empty.
        """
        if not file_names and not file_urls:
            raise InvalidRequestArgumentError(
                request_id=kwargs.get("_sys_traceid", ""),
                message="file_names and file_urls cannot both be empty"
            )
        file_names = file_names or []
        file_urls = file_urls or []
        supported_file_type = ["png", "jpg", "jpeg", "webp", "heic", "tif", "tiff", "dcm", "mha", "nii.gz"]
        # NOTE(review): COMPONENT_SUPPORT_FILE_NUMBER is imported by this module but
        # unused; confirm whether it is meant to replace this hard-coded limit.
        max_files = 10
        traceid = kwargs.get("_sys_traceid", "")
        sys_file_urls = kwargs.get("_sys_file_urls") or {}
        # URLs of uploaded files are handled via file_names, so they are skipped
        # when they also appear in file_urls. Built once instead of per URL.
        uploaded_urls = set(sys_file_urls.values())

        available_file_urls = {}
        unsupported_files = []
        unknown_files = []

        for file_name in file_names:
            if len(available_file_urls) >= max_files:
                break
            if file_name not in sys_file_urls:
                unknown_files.append(file_name)
            elif self._is_supported_file(file_name, supported_file_type):
                available_file_urls[file_name] = sys_file_urls.get(file_name, "")
            else:
                unsupported_files.append(file_name)

        for file_url in file_urls:
            if len(available_file_urls) >= max_files:
                break
            if file_url in uploaded_urls:
                continue
            # Strip the path and any query string to get the file name.
            file_name = file_url.split("/")[-1].split("?")[0]
            if self._is_supported_file(file_name, supported_file_type):
                available_file_urls[file_url] = file_url
            else:
                unsupported_files.append(file_url)

        for file_name, file_url in available_file_urls.items():
            try:
                req = TableOCRRequest()
                req.url = file_url
                req.cell_contents = "false"
                resp, raw_data = self._recognize(req, request_id=traceid)
                tables_result = proto.Message.to_dict(resp)["tables_result"]
                markdowns = self.get_table_markdown(tables_result)
                res = json.dumps({file_name: markdowns}, ensure_ascii=False)
                yield self.create_output(type="text", text=res, raw_data=raw_data, visible_scope="llm")
                yield self.create_output(type="text", text="", raw_data=raw_data, visible_scope="user")
            except Exception as e:
                # Best effort: one failing file must not abort the remaining files.
                logging.warning("%s ocr failed with exception: %s", file_name, e)

        notices = [(f, "不支持的文件类型，请确认是否为图片文件") for f in unsupported_files]
        notices += [(f, "无法获取url，请确认是否上传成功") for f in unknown_files]
        for file, notice in notices:
            res = json.dumps({file: notice}, ensure_ascii=False)
            yield self.create_output(type="text", text=res, name="text_1", visible_scope='llm')
            yield self.create_output(type="text", text="", name="text_2", visible_scope='user')
