import numpy as np
from Bio import SeqIO
import random
import sys
import os

def generate_negative_set(positive_sequences, genome_sequences, n_negatives=None):
    """Build a negative set whose sequence lengths mirror the positive set.

    For each positive sequence (in dict iteration order) one negative sequence
    of exactly the same length is produced. The negative is preferably a
    random window cut from a randomly chosen genome sequence that is at least
    that long; if no genome sequence is long enough (or the genome is empty),
    a synthetic sequence from ``generate_random_sequence`` is used instead.

    Note: sampled windows are not checked for overlap with the positive
    sequences themselves.

    Args:
        positive_sequences: Mapping of sequence ID -> sequence string.
        genome_sequences: Mapping of sequence ID -> sequence string used as
            the background to sample from.
        n_negatives: Maximum number of negatives to generate. Defaults to the
            number of positive sequences; larger values have no extra effect
            because each negative is paired with one positive.

    Returns:
        Dict mapping generated IDs (``negative_<i>_len<L>`` for genome windows,
        ``negative_<i>_random_len<L>`` for synthetic sequences) to sequences.
    """
    if n_negatives is None:
        n_negatives = len(positive_sequences)

    # Materialise the genome once instead of rebuilding a key list per attempt.
    genome_pool = list(genome_sequences.values())
    # Cache of "genome sequences long enough for length L", keyed by L, so
    # positives sharing a length do not rescan the genome.
    eligible_by_length = {}
    negative_sequences = {}

    for i, pseq in enumerate(positive_sequences.values()):
        if i >= n_negatives:
            break

        seq_length = len(pseq)
        if seq_length not in eligible_by_length:
            eligible_by_length[seq_length] = [
                seq for seq in genome_pool if len(seq) >= seq_length
            ]
        candidates = eligible_by_length[seq_length]

        if candidates:
            # Choosing only among eligible sequences is equivalent to
            # rejection sampling over the whole genome, but can never miss an
            # eligible sequence by bad luck.
            source = random.choice(candidates)
            start_pos = random.randint(0, len(source) - seq_length)
            neg_id = f"negative_{i+1}_len{seq_length}"
            negative_sequences[neg_id] = source[start_pos:start_pos + seq_length]
        else:
            print(f"Failed to extract from genome for sequence {i+1}. Generating random sequence of length {seq_length}")
            neg_id = f"negative_{i+1}_random_len{seq_length}"
            negative_sequences[neg_id] = generate_random_sequence(seq_length, genome_sequences)

    return negative_sequences

def generate_random_sequence(length, genome_sequences, k=3):
    """Generate a random DNA sequence matching the background base composition.

    Every position is drawn independently from A/T/C/G with probabilities
    equal to the single-nucleotide frequencies observed in
    ``genome_sequences``. Bases are counted case-insensitively so that
    soft-masked (lowercase) regions contribute; any other symbol (e.g. ``N``)
    is ignored. If no A/T/C/G is found at all, a uniform distribution is used.

    Args:
        length: Number of bases to generate.
        genome_sequences: Mapping of sequence ID -> sequence string providing
            the background composition.
        k: Currently unused; no k-mer model is implemented. Kept so existing
            callers that pass it keep working.

    Returns:
        An uppercase string of ``length`` characters over the alphabet ATCG.
    """
    nucleotides = ['A', 'T', 'C', 'G']

    # Count per sequence rather than joining everything into one giant string,
    # which would duplicate the whole genome in memory.
    nucleotide_counts = dict.fromkeys(nucleotides, 0)
    for seq in genome_sequences.values():
        for nt in nucleotides:
            nucleotide_counts[nt] += seq.count(nt) + seq.count(nt.lower())
    total_nucleotides = sum(nucleotide_counts.values())

    if total_nucleotides == 0:
        nucleotide_probs = [0.25, 0.25, 0.25, 0.25]
    else:
        nucleotide_probs = [nucleotide_counts[nt] / total_nucleotides for nt in nucleotides]

    return ''.join(np.random.choice(nucleotides, size=length, p=nucleotide_probs))

def read_fasta(file_path):
    """Load a FASTA file into a dict mapping record ID to sequence string.

    Records are parsed with ``SeqIO.parse`` in file order; when two records
    share an ID, the later one replaces the earlier one.
    """
    with open(file_path, 'r') as handle:
        return {rec.id: str(rec.seq) for rec in SeqIO.parse(handle, "fasta")}

def write_fasta(sequences, output_file):
    """Write a dict of sequences to ``output_file`` in FASTA format.

    Each entry becomes a ``>ID`` header line followed by the whole sequence on
    a single line (no line wrapping). An existing file is overwritten.
    """
    with open(output_file, 'w') as handle:
        # Build the generator after opening so the file is created/truncated
        # before the mapping is touched.
        handle.writelines(
            f">{name}\n{residues}\n" for name, residues in sequences.items()
        )

def _print_length_stats(label, lengths):
    """Print min/max/mean of *lengths*, or a placeholder when it is empty."""
    if not lengths:
        # min()/max() raise ValueError on an empty list (e.g. an empty FASTA).
        print(f"{label} - N/A (no sequences)")
        return
    print(f"{label} - 最小: {min(lengths)}, 最大: {max(lengths)}, 平均: {np.mean(lengths):.2f}")


def main():
    """Command-line entry point.

    Expects three arguments: the positive FASTA, the genome FASTA and the
    output FASTA path. Reads both inputs, generates one length-matched
    negative per positive sequence, writes them to the output file and prints
    a short summary of counts and length statistics. Exits with status 1 when
    the argument count is wrong.
    """
    # Take the three file paths from the command line.
    if len(sys.argv) != 4:
        print("Usage: python generate_negative_set.py <positive_fasta> <genome_fasta> <output_fasta>")
        sys.exit(1)

    positive_file = sys.argv[1]
    genome_file = sys.argv[2]
    output_file = sys.argv[3]

    # Make sure the output directory exists (a bare filename has no dirname).
    output_dir = os.path.dirname(output_file)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        print(f"Created output directory: {output_dir}")

    print("=== Reading fasta files ===")
    positive_seqs = read_fasta(positive_file)
    genome_seqs = read_fasta(genome_file)

    print("=== Generating negative set ===")
    negative_seqs = generate_negative_set(positive_seqs, genome_seqs)

    write_fasta(negative_seqs, output_file)

    print("=== Successfully generating negative set ===")
    print(f"阳性序列总数: {len(positive_seqs)}")
    print(f"阴性序列总数: {len(negative_seqs)}")

    # Length summaries; safe even when an input FASTA contained no records.
    _print_length_stats("阳性序列长度", [len(s) for s in positive_seqs.values()])
    _print_length_stats("阴性序列长度", [len(s) for s in negative_seqs.values()])

# Run the command-line workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
