#!/usr/bin/env python


import os, sys
import argparse
import pickle
import json

def get_diffused_residues(inp):
    """Collect the residues a pdb/trb file pair marks as not sequence-fixed.

    ``inp`` may name either file of the pair; the sibling file is located by
    swapping the ``.pdb``/``.trb`` extension. The ``.trb`` file is loaded as
    a pickle, and every position whose ``inpaint_seq`` entry is falsy is
    reported.

    Args:
        inp: Path ending in ``.pdb`` or ``.trb``.

    Returns:
        ``(pdb_path, residues)``, where ``residues`` is a list of strings of
        the form ``"A<index>"`` and ``index`` is the 0-based position in
        ``inpaint_seq``. Returns ``None`` when the extension is unsupported,
        either file is missing, the ``.trb`` cannot be unpickled, or it lacks
        the expected entries; an ``ERROR`` line is printed in each case.
    """
    if inp.endswith(".pdb"):
        pdb = inp
        trb = inp[:-4] + ".trb"
    elif inp.endswith(".trb"):
        pdb = inp[:-4] + ".pdb"
        trb = inp
    else:
        print(f"ERROR: Input file `{inp}` is neither a pdb nor trb file")
        return None

    for path in (pdb, trb):
        if not os.path.exists(path):
            print(f"ERROR: File `{path}` does not exist")
            return None

    # SECURITY: unpickling executes arbitrary code embedded in the file, so
    # only run this on .trb files from a trusted source.
    try:
        with open(trb, 'rb') as f:
            data = pickle.load(f)
    except Exception:
        # `Exception` rather than a bare `except`, so Ctrl-C and SystemExit
        # still propagate instead of being reported as a format problem.
        print(f"ERROR: Cannot load file {trb} - likely incorrect format.")
        return None

    # A file that unpickles fine but has the wrong shape is reported like any
    # other bad input, so one stray file does not abort the whole batch.
    try:
        inpaint_seq = data["inpaint_seq"]  # Likely want inpainting versus `sampled_mask`
        ref_pdb_idx = data["con_ref_pdb_idx"]
    except (KeyError, TypeError):
        print(f"ERROR: File `{trb}` lacks the expected `inpaint_seq`/`con_ref_pdb_idx` entries.")
        return None

    # NOTE(review): these are 0-based positions in `inpaint_seq`; confirm they
    # line up with the residue numbering the consumer of this list expects.
    residues = [ii for ii, val in enumerate(inpaint_seq) if not val]

    # Report the limitation once per file. The result is still returned (best
    # effort), with every residue labelled as chain A.
    if any(chain != 'A' for chain, _num in ref_pdb_idx):
        print("ERROR: This quick and dirty script can't handle situations with multiple chains.")

    return pdb, ["A" + str(r) for r in residues]

if __name__ == "__main__":
    # Command-line entry point: map each usable input's pdb path to its list
    # of residue labels and write the mapping out as indented JSON.
    cli = argparse.ArgumentParser(
        description="Convert RFDiffusionAA .trb output files into a LigandMPNN input specification"
    )
    cli.add_argument("inputs", nargs="+", help="The input pdbs (with associated .trb files) to use.")
    cli.add_argument("--output", '-o', default="residues.json", help="The specification file to write")
    options = cli.parse_args()

    spec = {}
    for path in options.inputs:
        result = get_diffused_residues(path)
        # Inputs that could not be processed come back as None and are
        # left out of the specification.
        if result is not None:
            pdb_path, residue_ids = result
            spec[pdb_path] = residue_ids

    with open(options.output, "w") as handle:
        handle.write(json.dumps(spec, indent=4))
