#!/usr/bin/env python3

import os
import sys
import argparse
import numpy as np
from colabdesign.af import mk_af_model

def add_cyclic_offset(af_model, bug_fix=True):
    """Store a head-to-tail (cyclic) relative-position matrix on the model.

    Reads ``af_model._inputs["residue_index"]``, builds the pairwise index
    difference matrix, and writes the result to ``af_model._inputs["offset"]``.
    When the model protocol is "fixbb", "partial" or "hallucination", each
    diagonal block (one per entry of ``af_model._lengths``) is replaced by a
    ring offset in which the N- and C-terminal residues are neighbours.

    Args:
        af_model: Model object exposing ``_inputs``, ``protocol`` and
            ``_lengths``.
        bug_fix: When True, pairs whose shortest path goes through the
            N/C junction get the opposite sign to their linear offset.
            When False, the sign always follows the linear offset.
    """
    def _ring_offsets(L):
        positions = np.arange(L)
        # Every residue appears twice: at its own index and one full turn later.
        images = np.stack([positions, positions + L], axis=-1)
        linear = positions[:, None] - positions[None, :]
        # Smallest separation over all image pairs == distance around the ring.
        spread = np.abs(images[:, None, :, None] - images[None, :, None, :])
        ring = spread.min(axis=(2, 3))
        if bug_fix:
            # The wrap-around route is shorter here, so reverse the direction.
            wraps = ring < np.abs(linear)
            ring = np.where(wraps, -ring, ring)
        return ring * np.sign(linear)

    residue_index = af_model._inputs["residue_index"]
    offset = np.array(residue_index[:, None] - residue_index[None, :])

    if af_model.protocol in ("fixbb", "partial", "hallucination"):
        start = 0
        for chain_len in af_model._lengths:
            stop = start + chain_len
            block = slice(start, stop)
            # Only within-chain pairs are made cyclic; cross-chain entries stay linear.
            offset[block, block] = _ring_offsets(chain_len)
            start = stop

    af_model._inputs["offset"] = offset

def predict_cyclic_peptide(sequence, output_dir, params_path, num_recycles=3, use_multimer=False, model_num='1'):
    """Predict the structure of a cyclic peptide and write the outputs to disk.

    Builds a "hallucination"-protocol model, applies the cyclic offset, runs a
    prediction for ``sequence`` and writes ``cyclic_peptide.pdb`` and
    ``results.txt`` into ``output_dir``.

    Args:
        sequence: Amino acid sequence passed to the model's ``predict``.
        output_dir: Directory for the outputs; created if it does not exist.
        params_path: Passed to the model as ``data_dir``.
        num_recycles: Number of recycles passed to ``predict``.
        use_multimer: Selects the ``model_<n>_multimer_v3`` weights instead of
            ``model_<n>_ptm``.
        model_num: Model number interpolated into the model name.

    Returns:
        The model's ``aux["log"]`` mapping; its "plddt", "ptm" and "pae"
        entries are the ones reported.
    """
    os.makedirs(output_dir, exist_ok=True)

    if use_multimer:
        model_name = f"model_{model_num}_multimer_v3"
    else:
        model_name = f"model_{model_num}_ptm"

    print("Setting up AlphaFold model...")
    af_model = mk_af_model(
        protocol="hallucination",
        initial_guess=False,
        best_metric="rmsd",
        data_dir=params_path,
        use_multimer=use_multimer,
        model_names=[model_name],
    )

    n_res = len(sequence)

    # Hallucination needs no input PDB, only the target length.
    # NOTE(review): cysteine is disabled here even though a caller may pass a
    # sequence containing C - confirm this does not affect predict(seq=...).
    af_model.prep_inputs(length=n_res, rm_aa="C")

    print("Adding cyclic offset to connect N and C termini...")
    add_cyclic_offset(af_model)

    print(f"Predicting structure for sequence: {sequence}")
    print(f"Length: {n_res} residues")

    af_model.predict(seq=sequence, num_recycles=num_recycles, verbose=True)
    log = af_model.aux["log"]

    # (label shown to the user, key in the model log)
    metrics = (("pLDDT", "plddt"), ("pTM", "ptm"), ("PAE", "pae"))

    print("\nResults:")
    for label, key in metrics:
        print(f"{label}: {log[key]:.3f}")

    pdb_path = os.path.join(output_dir, "cyclic_peptide.pdb")
    af_model.save_pdb(pdb_path)
    print(f"Structure saved to: {pdb_path}")

    # Plain-text summary alongside the structure.
    report_path = os.path.join(output_dir, "results.txt")
    with open(report_path, 'w') as report:
        report.write(f"Sequence: {sequence}\n")
        report.write(f"Length: {n_res}\n")
        for label, key in metrics:
            report.write(f"{label}: {log[key]:.3f}\n")
        report.write(f"Recycles: {num_recycles}\n")

    return log

def main():
    """Command-line entry point: validate the arguments and run a prediction.

    Parses ``sys.argv``, normalises the sequence (upper-cased, surrounding
    whitespace stripped) and checks it before calling
    ``predict_cyclic_peptide``. All diagnostics (errors and warnings) are
    written to stderr so stdout carries only normal progress output.

    Exits with status 1 when the sequence contains characters outside the 20
    standard one-letter codes, is shorter than 3 residues, when the model
    number is not one of 1-5, when the recycle count is negative, or when the
    prediction raises.
    """
    parser = argparse.ArgumentParser(description='Predict cyclic peptide structure using AlphaFold')
    parser.add_argument('sequence', help='Amino acid sequence (single letter codes)')
    parser.add_argument('-o', '--output', required=True, help='Output directory')
    parser.add_argument('--params_path', default="/programs/x86_64-linux/colabdesign/1.1.2/", help='Path to AF params')
    parser.add_argument('--model_num', type=str, default='1', help='Number for AF model')
    parser.add_argument('--recycles', type=int, default=3, help='Number of recycles (default: 3)')
    parser.add_argument('--multimer', action='store_true', help='Use AlphaFold multimer')

    args = parser.parse_args()

    valid_aa = set('ACDEFGHIKLMNPQRSTVWY')
    sequence = args.sequence.upper().strip()

    # Collect the first failing check so there is a single exit path.
    problem = None
    if not set(sequence) <= valid_aa:
        problem = "Invalid amino acid characters in sequence"
    elif len(sequence) < 3:
        problem = "Sequence too short (minimum 3 residues)"
    elif args.model_num not in ('1', '2', '3', '4', '5'):
        problem = f"Model num {args.model_num} is not in 1-5"
    elif args.recycles < 0:
        # A negative recycle count is meaningless; reject it before any model setup.
        problem = f"Number of recycles must be non-negative (got {args.recycles})"

    if problem is not None:
        print(f"Error: {problem}", file=sys.stderr)
        sys.exit(1)

    if len(sequence) > 50:
        print("Warning: Long sequences may take a very long time to fold", file=sys.stderr)

    # Top-level boundary: report any failure and turn it into a non-zero exit.
    try:
        predict_cyclic_peptide(
            sequence=sequence,
            output_dir=args.output,
            params_path=args.params_path,
            num_recycles=args.recycles,
            use_multimer=args.multimer,
            model_num=args.model_num,
        )
    except Exception as e:
        print(f"Error during prediction: {e}", file=sys.stderr)
        sys.exit(1)

    print("\nPrediction completed successfully!")

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
