#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""V11: 添加 S6（续借时间泄露）并调整策略顺序

策略顺序：S1 -> S6 -> S3 -> S1.5 -> S2 -> S4 -> S5
"""

import pandas as pd
import re
from typing import Optional, Tuple
from pathlib import Path
from datetime import timedelta

# Dataset locations and submission outputs. These are plain module-level
# strings that main() and load_target_users() read at call time, so a caller
# can point them elsewhere by reassigning the module attributes beforehand.
REEVAL_INTER = r"C:\Users\timberman\Downloads\数据集\复赛-选手使用数据\inter_reevaluation.csv"
FINAL_INTER = r"C:\Users\timberman\Downloads\数据集\决赛-选手使用数据\inter_final.csv"
TARGET_USER_FILE = r"C:\Users\timberman\Downloads\数据集\决赛-选手使用数据\user.csv"
OUT_CSV = r"C:\Users\timberman\Downloads\数据集\submitV11_renew_leak.csv"
OUT_WITH_STR = r"C:\Users\timberman\Downloads\数据集\submit_with_strategyV11_renew_leak.csv"

# Width (5 minutes) of the "near future" window after a reference time.
NEAR_DELTA = timedelta(seconds=300)


def read_csv_smart(path: str,
                   encodings: Tuple[str, ...] = ('utf-8-sig', 'utf-8', 'gbk', 'cp936')) -> pd.DataFrame:
    """Read a CSV file, trying each candidate text encoding in order.

    Args:
        path: Path of the CSV file.
        encodings: Encodings to try, first to last. The default covers UTF-8
            (with or without BOM) and the GBK family.

    Returns:
        The DataFrame produced by the first encoding that decodes the file.

    Raises:
        UnicodeDecodeError: If no candidate encoding can decode the file (the
            error from the last attempt is re-raised).
        ValueError: If ``encodings`` is empty.
        Any other error from ``pd.read_csv`` (missing file, parser error, ...)
        is raised immediately, because switching encodings cannot fix it.
    """
    last_err = None
    for enc in encodings:
        try:
            return pd.read_csv(path, encoding=enc)
        except UnicodeDecodeError as e:
            # Only a decoding failure is worth retrying with another encoding.
            last_err = e
    if last_err is None:
        raise ValueError('read_csv_smart: no encodings to try')
    raise last_err


def fix_ts_str(s: str):
    """Normalise one raw timestamp cell before datetime parsing.

    - Missing values (per ``pd.isna``) are returned unchanged.
    - Blank strings and the literals 'nan' / 'none' / 'null' (any letter case,
      surrounding whitespace ignored) become ``None``.
    - A date glued directly to a time, e.g. '2020-01-0112:30:00', gets a
      separating space: '2020-01-01 12:30:00'.
    - Anything else is returned as its stripped string form.
    """
    if pd.isna(s):
        return s
    text = str(s).strip()
    if not text or text.lower() in {'nan', 'none', 'null'}:
        return None
    glued = re.match(r'^(\d{4}-\d{2}-\d{2})(\d{2}:\d{2}:\d{2})$', text)
    if glued is None:
        return text
    return f'{glued.group(1)} {glued.group(2)}'


def parse_time_columns(df: pd.DataFrame, cols):
    """Clean and convert the listed timestamp columns of ``df`` in place.

    Names in ``cols`` that are not columns of ``df`` are skipped. Each present
    column is passed through ``fix_ts_str`` and then ``pd.to_datetime`` with
    ``errors='coerce'``, so unparseable values become NaT. The same (mutated)
    DataFrame object is returned for convenience.
    """
    present = [name for name in cols if name in df.columns]
    for name in present:
        cleaned = df[name].map(fix_ts_str)
        df[name] = pd.to_datetime(cleaned, errors='coerce')
    return df


def compute_user_time_refs(observe_df: pd.DataFrame) -> Tuple[pd.Series, pd.Series, pd.Series]:
    """Derive each user's reference point from an observation log.

    Returns a 3-tuple of Series indexed by user_id:
      - user_t0: a copy of the latest 借阅时间 (borrow time) per user;
      - last_borrow_time: the latest 借阅时间 per user (NaT if the user has none);
      - last_borrow_max_inter_id: the largest inter_id among the rows sitting
        exactly at that latest borrow time. Users whose latest borrow time is
        NaT have no entry here, since NaT never compares equal.
    """
    latest = observe_df.groupby('user_id')['借阅时间'].max()

    annotated = observe_df[['user_id', '借阅时间', 'inter_id']].join(
        latest.rename('last_borrow_time'), on='user_id'
    )
    at_latest = annotated.loc[annotated['借阅时间'] == annotated['last_borrow_time']]
    max_iid_at_latest = at_latest.groupby('user_id')['inter_id'].max()

    return latest.copy(), latest, max_iid_at_latest


def load_target_users(default_users: pd.Series, path: Optional[str] = None) -> pd.Series:
    """Return the user ids to predict for, preferring the final-round user file.

    The goal is to cover the whole final-round roster (the original note says
    1451 users; that figure is not checked here). Falls back to
    ``default_users`` when the file is missing, has no recognised user column
    (tried in order: 借阅人, user_id, uid), or contains no parseable ids.

    Args:
        default_users: Fallback ids (in main(): the re-evaluation users).
        path: User file to read; defaults to the module-level TARGET_USER_FILE.

    Returns:
        A Series of unique, nullable-Int64 user ids, or ``default_users``.
    """
    path = TARGET_USER_FILE if path is None else path
    if not Path(path).exists():
        print(f"⚠️ 未找到 {path}，沿用复赛用户列表，共 {len(default_users)} 人")
        return default_users
    df = read_csv_smart(path)
    col = next((c for c in ('借阅人', 'user_id', 'uid') if c in df.columns), None)
    if col is None:
        print(f"⚠️ {path} 中没有用户列，沿用复赛用户列表，共 {len(default_users)} 人")
        return default_users
    user_ids = pd.to_numeric(df[col], errors='coerce').dropna().astype('Int64').unique()
    if len(user_ids) == 0:
        # An all-garbage column would otherwise yield an empty submission.
        print(f"⚠️ {path} 的 {col} 列没有有效的用户ID，沿用复赛用户列表，共 {len(default_users)} 人")
        return default_users
    print(f"Loaded {len(user_ids)} target users from {path} (column: {col})")
    return pd.Series(user_ids)


def is_renew_anomaly(row) -> bool:
    """Tell whether a record's 续借时间 (renew time) is inconsistent.

    ``row`` is anything with a ``.get`` method (a pandas row or a dict).
    Without a renew time the record is never anomalous. With one, it is
    anomalous when the 续借次数 (renew count) is missing or zero, when the
    renew time precedes the 借阅时间 (borrow time), or when it follows the
    还书时间 (return time). Missing borrow/return times skip their check.
    """
    renewed_at = row.get('续借时间')
    if pd.isna(renewed_at):
        return False

    count = row.get('续借次数')
    if pd.isna(count) or int(count) == 0:
        return True

    borrowed_at = row.get('借阅时间')
    if pd.notna(borrowed_at) and renewed_at < borrowed_at:
        return True

    returned_at = row.get('还书时间')
    return bool(pd.notna(returned_at) and renewed_at > returned_at)


def find_global_pivot_T(reeval: pd.DataFrame, quantiles=(0.60, 0.70, 0.80, 0.85, 0.90, 0.95, 0.97, 0.98, 0.99)):
    """Pick a global cut-off time T* from candidate quantiles of all borrow times.

    For each quantile q, T = quantile q of the non-missing 借阅时间 values. The
    chosen T is the one maximising the number of users with at least one
    borrow at or before T and at least one borrow strictly after T. Ties keep
    the earliest q in ``quantiles``.

    Returns:
        The chosen Timestamp, or None when there are no borrow times or when no
        candidate T splits any user's history.
    """
    valid_borrows = reeval['借阅时间'].dropna()
    if valid_borrows.empty:
        return None

    # A user has borrows on both sides of T exactly when their earliest borrow
    # is <= T and their latest is > T. NaT compares False, so users without any
    # borrow time never count (same as counting rows per side).
    per_user = reeval.groupby('user_id')['借阅时间']
    first_borrow = per_user.min()
    last_borrow = per_user.max()

    best_q = None
    best_T = None
    max_both = 0
    for q in quantiles:
        T = valid_borrows.quantile(q)
        both = int(((first_borrow <= T) & (last_borrow > T)).sum())
        if both > max_both:
            max_both, best_q, best_T = both, q, T

    if best_T is None:
        # Previously this case crashed on formatting best_q=None with ':.2f'.
        print(f"⚠️ No candidate quantile splits any user's borrows; global pivot T* disabled")
        return None
    print(f'Found global pivot T* at q={best_q:.2f}, T={best_T}, users_both={max_both}')
    return best_T


def pick_first_future_in_final(user_df: pd.DataFrame, t0: pd.Timestamp, last_borrow_time: pd.Timestamp,
                               last_borrow_max_inter_id: Optional[int]) -> Optional[int]:
    """S1: pick the user's first final-round record after their reference point.

    A row qualifies when its 借阅时间 is strictly after ``t0``, or when it is at
    exactly ``last_borrow_time`` with an inter_id greater than
    ``last_borrow_max_inter_id``. Rows without a book_id are ignored.
    Candidates are ordered by (借阅时间, inter_id). The first candidate that
    ``is_renew_anomaly`` does not flag wins; if all are flagged, the earliest
    candidate is used.

    Returns:
        The chosen book_id as int, or None when nothing qualifies.
    """
    if user_df.empty or (pd.isna(t0) and pd.isna(last_borrow_time)):
        return None
    borrow = user_df['借阅时间']

    mask = pd.Series(False, index=user_df.index)
    if pd.notna(t0):
        # Rows inside the t0 + NEAR_DELTA window are also "> t0", so no separate
        # near-window condition is needed.
        mask |= borrow > t0
    if pd.notna(last_borrow_time) and pd.notna(last_borrow_max_inter_id):
        tie_later = (borrow == last_borrow_time) & (user_df['inter_id'] > last_borrow_max_inter_id)
        # A missing inter_id yields <NA>; treat it as "not later" so the mask stays
        # purely boolean (pandas refuses to index with an NA-containing mask).
        mask |= tie_later.fillna(False).astype(bool)

    # Skip missing book_ids: int(pd.NA) would raise.
    cand = user_df[mask & user_df['book_id'].notna()].sort_values(['借阅时间', 'inter_id'])
    if cand.empty:
        return None
    for _, row in cand.iterrows():
        if not is_renew_anomaly(row):
            return int(row['book_id'])
    return int(cand.iloc[0]['book_id'])


def pick_next_after_global_T(user_df: pd.DataFrame, T_global: pd.Timestamp) -> Optional[int]:
    """S1.5: pick the user's first re-evaluation record after the global pivot T*.

    Only rows with 借阅时间 strictly after ``T_global`` and a non-missing
    book_id qualify, ordered by (借阅时间, inter_id). The first candidate that
    ``is_renew_anomaly`` does not flag wins; otherwise the earliest candidate.

    Returns:
        The chosen book_id as int, or None if ``T_global`` is missing/None or
        nothing qualifies.
    """
    if user_df.empty or pd.isna(T_global):
        return None
    # The former "same time as T with a larger inter_id" test compared against the
    # max inter_id of those very rows, so it could never match (it only raised on
    # NA inter_ids). The NEAR_DELTA window after T is a subset of "> T".
    cand = user_df[(user_df['借阅时间'] > T_global) & user_df['book_id'].notna()]
    if cand.empty:
        return None
    cand = cand.sort_values(['借阅时间', 'inter_id'])
    for _, row in cand.iterrows():
        if not is_renew_anomaly(row):
            return int(row['book_id'])
    return int(cand.iloc[0]['book_id'])


def pick_first_future_in_reeval(user_df: pd.DataFrame, t0: pd.Timestamp, last_borrow_time: pd.Timestamp,
                                last_borrow_max_inter_id: Optional[int]) -> Optional[int]:
    """S2: pick the user's earliest re-evaluation record after their reference point.

    A row qualifies when its 借阅时间 is strictly after ``t0``, or when it is at
    exactly ``last_borrow_time`` with an inter_id greater than
    ``last_borrow_max_inter_id``. Rows without a book_id are ignored. The
    earliest candidate by (借阅时间, inter_id) wins.

    Returns:
        The chosen book_id as int, or None when nothing qualifies.
    """
    if user_df.empty or (pd.isna(t0) and pd.isna(last_borrow_time)):
        return None
    borrow = user_df['借阅时间']

    mask = pd.Series(False, index=user_df.index)
    if pd.notna(t0):
        # Rows inside the t0 + NEAR_DELTA window are also "> t0", so no separate
        # near-window condition is needed.
        mask |= borrow > t0
    if pd.notna(last_borrow_time) and pd.notna(last_borrow_max_inter_id):
        tie_later = (borrow == last_borrow_time) & (user_df['inter_id'] > last_borrow_max_inter_id)
        # A missing inter_id yields <NA>; treat it as "not later" so the mask stays
        # purely boolean (pandas refuses to index with an NA-containing mask).
        mask |= tie_later.fillna(False).astype(bool)

    # Skip missing book_ids: int(pd.NA) would raise.
    cand = user_df[mask & user_df['book_id'].notna()]
    if cand.empty:
        return None
    return int(cand.sort_values(['借阅时间', 'inter_id']).iloc[0]['book_id'])


def pick_abnormal_same_book(reeval_user_df: pd.DataFrame) -> Optional[int]:
    """S3: pick the book of the user's earliest "abnormal" re-evaluation record.

    A record is abnormal when any of these holds:
      - it has no 还书时间 (return time);
      - it has a 续借时间 (renew time) but a missing or zero 续借次数 (renew count);
      - its renew time is before the 借阅时间 (borrow time) or after the return time.
    Records without a book_id are ignored. The earliest abnormal record by
    (借阅时间, inter_id) wins.

    Returns:
        The chosen book_id as int, or None when no record is abnormal.
    """
    if reeval_user_df.empty:
        return None
    df = reeval_user_df
    renew = df['续借时间']
    borrow = df['借阅时间']
    returned = df['还书时间']
    count = pd.to_numeric(df['续借次数'], errors='coerce')

    not_returned = returned.isna()
    # Missing count counts as zero; fillna keeps the nullable comparison boolean.
    count_missing_or_zero = count.isna() | count.eq(0).fillna(False).astype(bool)
    renew_without_count = renew.notna() & count_missing_or_zero
    renew_outside = renew.notna() & (
        (borrow.notna() & (renew < borrow)) | (returned.notna() & (renew > returned))
    )

    abnormal = not_returned | renew_without_count | renew_outside
    # Skip missing book_ids: int(pd.NA) would raise.
    abn = df[abnormal & df['book_id'].notna()]
    if abn.empty:
        return None
    abn = abn.sort_values(['借阅时间', 'inter_id'], ascending=[True, True])
    return int(abn.iloc[0]['book_id'])


def pick_renew_leak_future(user_df: pd.DataFrame, t0: pd.Timestamp) -> Optional[int]:
    """S6: treat a 续借时间 (renew time) as a leaked future borrow.

    A record qualifies when it has both a renew time and a 借阅时间 (borrow
    time), the renew time is not before the borrow time, and either the
    record has no 还书时间 (return time) or the renew time is after it.
    Records without a book_id are ignored. The record with the latest renew
    time wins; ties go to the latest borrow time, then the smallest inter_id.

    Args:
        user_df: The user's re-evaluation records.
        t0: The user's reference time. Kept for interface compatibility. It
            cannot change the pick: if any qualifying renew time is after t0,
            the latest renew time is necessarily one of them.

    Returns:
        The chosen book_id as int, or None when nothing qualifies.
    """
    if user_df.empty:
        return None
    renew = user_df['续借时间']
    borrow = user_df['借阅时间']
    returned = user_df['还书时间']

    plausible = renew.notna() & borrow.notna() & (renew >= borrow)
    looks_future = returned.isna() | (renew > returned)
    # Skip missing book_ids: int(pd.NA) would raise.
    df = user_df[plausible & looks_future & user_df['book_id'].notna()]
    if df.empty:
        return None

    df = df.sort_values(['续借时间', '借阅时间', 'inter_id'], ascending=[False, False, True])
    return int(df.iloc[0]['book_id'])


def pick_user_hist_fallback(reeval_user_df: pd.DataFrame, t0: pd.Timestamp) -> Optional[int]:
    """S4: pick the user's most frequently borrowed book.

    History is restricted to records borrowed at or before ``t0`` when ``t0``
    is given. If that leaves nothing, the full history is used instead. Among
    the most frequent book_ids, the one from the most recently borrowed record
    wins (ties on time go to the smallest inter_id).

    Returns:
        The chosen book_id as int, or None when the history is empty or has no
        valid book_id.
    """
    if reeval_user_df.empty:
        return None
    hist = reeval_user_df
    if pd.notna(t0):
        hist = hist[hist['借阅时间'].notna() & (hist['借阅时间'] <= t0)]
    if hist.empty:
        hist = reeval_user_df

    counts = hist['book_id'].value_counts()  # missing book_ids are not counted
    if counts.empty:
        return None
    top_books = counts[counts == counts.max()].index
    cand = hist[hist['book_id'].isin(top_books)].sort_values(['借阅时间', 'inter_id'], ascending=[False, True])
    return int(cand.iloc[0]['book_id'])


def _pick_for_user(fin_u: pd.DataFrame, rev_u: pd.DataFrame, t0, lbT, lbI: Optional[int], T_global):
    """Run strategies S1 -> S6 -> S3 -> S1.5 -> S2 -> S4 for one user.

    Returns (book_id, strategy_name) from the first strategy that yields a
    book, or (None, None) when every strategy comes up empty.
    """
    bid = pick_first_future_in_final(fin_u, t0, lbT, lbI)
    if bid is not None:
        return bid, 'S1_final_future'
    bid = pick_renew_leak_future(rev_u, t0)
    if bid is not None:
        return bid, 'S6_renew_leak_future'
    bid = pick_abnormal_same_book(rev_u)
    if bid is not None:
        return bid, 'S3_abnormal_same'
    bid = pick_next_after_global_T(rev_u, T_global)
    if bid is not None:
        return bid, 'S1.5_reeval_global_T'
    bid = pick_first_future_in_reeval(rev_u, t0, lbT, lbI)
    if bid is not None:
        return bid, 'S2_reeval_future'
    bid = pick_user_hist_fallback(rev_u, t0)
    if bid is not None:
        return bid, 'S4_hist_fallback'
    return None, None


def main():
    """Build the V11 submission and write it to OUT_CSV and OUT_WITH_STR.

    Steps:
      1. Load both interaction logs and check the required columns.
      2. Parse the timestamp columns and coerce id/count columns to nullable Int64.
      3. Derive per-user reference times and the global pivot T* from the
         re-evaluation log.
      4. Run each target user through the strategy cascade (_pick_for_user).
         The re-evaluation log's most frequent book is the last resort (S5).

    Returns:
        DataFrame with columns user_id, book_id, strategy (one row per user).

    Raises:
        ValueError: If either log lacks a required column, or if S5 is needed
            but the re-evaluation log has no valid book_id.
    """
    print("加载数据...")
    reeval = read_csv_smart(REEVAL_INTER)
    final = read_csv_smart(FINAL_INTER)

    required_cols = ['inter_id', 'user_id', 'book_id', '借阅时间', '还书时间', '续借时间', '续借次数']
    for col in required_cols:
        if col not in reeval.columns:
            raise ValueError(f'复赛文件缺少列: {col}')
        if col not in final.columns:
            raise ValueError(f'决赛文件缺少列: {col}')

    time_cols = ['借阅时间', '还书时间', '续借时间']
    reeval = parse_time_columns(reeval, time_cols)
    final = parse_time_columns(final, time_cols)

    for df in (reeval, final):
        for col in ('inter_id', 'user_id', 'book_id', '续借次数'):
            df[col] = pd.to_numeric(df[col], errors='coerce').astype('Int64')

    reeval_users = pd.Series(reeval['user_id'].dropna().astype('Int64').unique())
    users = load_target_users(reeval_users)
    user_t0, user_last_borrow_time, user_last_borrow_max_iid = compute_user_time_refs(reeval)

    T_global = find_global_pivot_T(reeval)

    reeval_g = reeval.sort_values(['user_id', '借阅时间', 'inter_id']).groupby('user_id', group_keys=False)
    final_g = final.sort_values(['user_id', '借阅时间', 'inter_id']).groupby('user_id', group_keys=False)

    # Stand-ins for users absent from a log; every picker returns None on an empty frame.
    empty_reeval = pd.DataFrame(columns=reeval.columns)
    empty_final = pd.DataFrame(columns=final.columns)

    # S5 does not depend on the user, so compute the global most frequent book once
    # instead of once per fallback user.
    mode_books = reeval['book_id'].mode()
    global_mode_bid = int(mode_books.iloc[0]) if not mode_books.empty else None

    stats = {
        'S1_final_future': 0,
        'S6_renew_leak_future': 0,
        'S3_abnormal_same': 0,
        'S1.5_reeval_global_T': 0,
        'S2_reeval_future': 0,
        'S4_hist_fallback': 0,
        'S5_global_mode': 0,
    }

    rows = []
    for uid in users:
        t0 = user_t0.get(uid, pd.NaT)
        lbT = user_last_borrow_time.get(uid, pd.NaT)
        lbI = user_last_borrow_max_iid.get(uid, pd.NA)
        lbI = None if pd.isna(lbI) else int(lbI)

        fin_u = final_g.get_group(uid) if uid in final_g.groups else empty_final
        rev_u = reeval_g.get_group(uid) if uid in reeval_g.groups else empty_reeval

        bid, strategy = _pick_for_user(fin_u, rev_u, t0, lbT, lbI, T_global)
        if bid is None:
            if global_mode_bid is None:
                raise ValueError('S5 fallback impossible: no valid book_id in the re-evaluation data')
            bid, strategy = global_mode_bid, 'S5_global_mode'

        stats[strategy] += 1
        rows.append((int(uid), int(bid), strategy))

    sub = pd.DataFrame(rows, columns=['user_id', 'book_id', 'strategy']).drop_duplicates(subset=['user_id'])
    Path(OUT_CSV).parent.mkdir(parents=True, exist_ok=True)
    sub[['user_id', 'book_id']].to_csv(OUT_CSV, index=False, encoding='utf-8-sig')
    sub.to_csv(OUT_WITH_STR, index=False, encoding='utf-8-sig')

    total = len(sub)
    print(f'\nWrote {total} predictions to {OUT_CSV}')
    print('Strategy breakdown:')
    for k, v in stats.items():
        pct = (v / total * 100) if total else 0.0
        print(f'  {k:24s}: {v:6d} ({pct:5.2f}%)')
    print(f'Detailed file with strategies: {OUT_WITH_STR}')

    return sub


if __name__ == '__main__':
    main()