# This only tests the command-line interface

import fmcs
import unittest
import time

import sys
from cStringIO import StringIO
# Remember the real output streams so the capture helpers below can put them
# back after temporarily redirecting output into StringIO buffers.
orig_stdout, orig_stderr = sys.stdout, sys.stderr

def run(args, expect_sysexit=False):
    """Invoke ``fmcs.main`` with captured output.

    sys.stdout and sys.stderr are replaced by in-memory buffers for the
    duration of the call and are always restored afterwards.

    :param args: command-line arguments as one string, split on whitespace.
    :param expect_sysexit: when True a SystemExit is tolerated, and its
        message plus a newline is appended to the captured stderr text.
    :returns: a ``(captured_stdout, captured_stderr)`` pair of strings.
    :raises AssertionError: on a SystemExit when ``expect_sysexit`` is False,
        after echoing the captured output to the real stdout/stderr.
    """
    captured_out, captured_err = StringIO(), StringIO()
    try:
        sys.stdout, sys.stderr = captured_out, captured_err
        try:
            fmcs.main(args.split())
        finally:
            # Put the real streams back even if main() raised.
            sys.stdout, sys.stderr = orig_stdout, orig_stderr
    except SystemExit as exit_exc:
        if not expect_sysexit:
            # Show what the program printed before failing the test.
            orig_stdout.write(captured_out.getvalue() + "\n")
            orig_stderr.write(captured_err.getvalue() + "\n")
            raise AssertionError("Unexpected SystemExit: %r" % (exit_exc,))
        captured_err.write(str(exit_exc) + "\n")
    return captured_out.getvalue(), captured_err.getvalue()

def confirm(args, expected_num_atoms, expected_num_bonds, expected_smarts=None):
    """Run fmcs on ``args`` and check the size of the MCS it reports.

    The captured stdout must be either exactly "No MCS found\\n" or seven
    whitespace-separated fields. The first field is the SMARTS, the second
    the atom count and the fourth the bond count.

    :param args: command-line arguments as one whitespace-separated string.
    :param expected_num_atoms: expected atom count. Pass None together with
        ``expected_num_bonds=None`` to require that no MCS is found.
    :param expected_num_bonds: expected bond count, or None (see above).
    :param expected_smarts: if given, the exact SMARTS string expected.
    :raises AssertionError: on an unexpected SystemExit, on output that does
        not have the expected shape, or on any mismatch with expectations.
    """
    # run() restores the real streams. On an unexpected SystemExit it echoes
    # the captured stdout/stderr to the matching real streams, then fails.
    text = run(args)[0]
    if text == "No MCS found\n":
        # "No MCS" is only acceptable when BOTH expected counts are None.
        if expected_num_atoms is None and expected_num_bonds is None:
            return
        raise AssertionError(text)
    fields = text.split()
    if len(fields) != 7:
        # Report the raw output rather than an opaque unpacking ValueError.
        raise AssertionError("Unexpected fmcs output: %r" % (text,))
    smarts, num_atoms, _, num_bonds, _, _, _ = fields
    if expected_smarts is not None and smarts != expected_smarts:
        # This isn't a good test; it's highly dependent on the SMARTS output code,
        # and nothing promises that the output will be an invariant, canonical SMARTS
        raise AssertionError((smarts, expected_smarts))
    # Explicit raises (not bare asserts) so the checks survive python -O.
    actual = (int(num_atoms), int(num_bonds))
    expected = (expected_num_atoms, expected_num_bonds)
    if actual != expected:
        raise AssertionError((actual, expected))

def sysexit(args):
    """Run fmcs on ``args`` and return the message of the SystemExit it raises.

    Output written by fmcs is captured and discarded. The real streams are
    restored before returning.

    :param args: command-line arguments as one whitespace-separated string.
    :returns: ``str()`` of the SystemExit raised by ``fmcs.main``.
    :raises AssertionError: if ``fmcs.main`` returns without raising SystemExit.
    """
    sys.stdout = StringIO()
    sys.stderr = StringIO()
    try:
        try:
            fmcs.main(args.split())
        finally:
            sys.stdout, sys.stderr = orig_stdout, orig_stderr
    except SystemExit as exit_exc:
        return str(exit_exc)
    else:
        raise AssertionError("should have raised SystemExit")

class TestFormats(unittest.TestCase):
    """Input-file handling: missing files, unsupported formats, bad records."""

    def test_missing_smiles_file(self):
        self.assertEqual("Unable to open SMILES file 'does_not_exist.smi'",
                         sysexit("does_not_exist.smi"))

    def test_missing_smiles_file_uppercase(self):
        # The extension check must be case-insensitive.
        self.assertEqual("Unable to open SMILES file 'does_not_exist.SMI'",
                         sysexit("does_not_exist.SMI"))

    def test_missing_sd_file(self):
        self.assertEqual("Unable to open SD file 'does_not_exist.sdf'",
                         sysexit("does_not_exist.sdf"))

    def test_missing_sd_file_uppercase(self):
        self.assertEqual("Unable to open SD file 'does_not_exist.SDF'",
                         sysexit("does_not_exist.SDF"))

    def test_compressed(self):
        self.assertEqual("gzip compressed files not yet supported",
                         sysexit("does_not_exist.sdf.gz"))

    def test_unsupported_extension(self):
        self.assertEqual("Only SMILES (.smi) and SDF (.sdf) files are supported",
                         sysexit("does_not_exist.pdb"))

    def test_smiles(self):
        confirm("simple.smi", expected_num_atoms=2, expected_num_bonds=1)

    def test_sdf(self):
        confirm("simple.sdf", expected_num_atoms=2, expected_num_bonds=1)

    def test_bad_record(self):
        _, err_text = run("bad_record.smi")
        self.assertIn("Skipping unreadable structure #2\n", err_text, err_text)

    def test_bad_format(self):
        _, err_text = run("not_really_an_sdf.sdf", expect_sysexit=True)
        expected = ("Input file 'not_really_an_sdf.sdf' must contain "
                    "at least two structures\n")
        self.assertIn(expected, err_text, err_text)
        

class TestMinAtoms(unittest.TestCase):
    """The --min-num-atoms option (simple.smi has a 2-atom, 1-bond MCS)."""

    def test_min_atoms_2(self):
        confirm("simple.smi --min-num-atoms 2", expected_num_atoms=2, expected_num_bonds=1)

    def test_min_atoms_3(self):
        # A 3-atom minimum rules out the only available MCS.
        confirm("simple.smi --min-num-atoms 3", expected_num_atoms=None, expected_num_bonds=None)

    def test_min_atoms_1(self):
        _, err_text = run("simple.smi --min-num-atoms 1", expect_sysexit=True)
        self.assertIn("--min-num-atoms: must be at least 2, not 1", err_text, err_text)

class TextMaximize(unittest.TestCase):
    # NOTE(review): "Text" looks like a typo for "Test". The name is kept
    # unchanged; unittest still loads it because it subclasses TestCase.
    #
    # Structures compared:
    #   C12CCC1CC2OCCCCCCC 2-rings-and-chain-with-O
    #   C12CCC1CC2SCCCCCCC 2-rings-and-chain-with-S

    def test_maximize_default(self):
        # Without --maximize, the bond count is what gets maximized.
        confirm("maximize.smi", expected_num_atoms=6, expected_num_bonds=7)

    def test_maximize_atoms(self):
        confirm("maximize.smi --maximize atoms", expected_num_atoms=7, expected_num_bonds=6)

    def test_maximize_bonds(self):
        confirm("maximize.smi --maximize bonds", expected_num_atoms=6, expected_num_bonds=7)
        


class TestAtomTypes(unittest.TestCase):
    """--atom-compare / --bond-compare on two small structures.

    Structures compared:
      c1ccccc1O
      CCCCCCOn1cccc1
    """

    def test_atom_compare_default(self):
        confirm("atomtypes.smi", expected_num_atoms=4, expected_num_bonds=3)  # 'cccc'

    def test_atom_compare_elements(self):
        confirm("atomtypes.smi --atom-compare elements",
                expected_num_atoms=4, expected_num_bonds=3)  # 'cccc'

    def test_atom_compare_any(self):
        # Bond aromaticities must still match: 'cccccO' matches 'ccccnO'.
        confirm("atomtypes.smi --atom-compare any",
                expected_num_atoms=6, expected_num_bonds=5)

    def test_atom_compare_any_bond_compare_any(self):
        # A linear chain of 7 atoms.
        confirm("atomtypes.smi --atom-compare any --bond-compare any",
                expected_num_atoms=7, expected_num_bonds=6)

    def test_bond_compare_any(self):
        # A linear chain of 7 atoms.
        confirm("atomtypes.smi --bond-compare any",
                expected_num_atoms=7, expected_num_bonds=6)

class TestIsotopes(unittest.TestCase):
    """--atom-compare isotopes, with and without --complete-rings-only.

    Structures compared:
      C1C[0N]CC[5C]1[1C][2C][2C][3C] C1223
      C1CPCC[4C]1[2C][2C][1C][3C] C2213
    """

    def test_without_isotope(self):
        # Everything except the N/P ring atom.
        confirm("isotopes.smi", expected_num_atoms=9, expected_num_bonds=8)

    def test_isotopes(self):
        # The 5 ring atoms whose isotope label is 0.
        confirm("isotopes.smi --atom-compare isotopes",
                expected_num_atoms=5, expected_num_bonds=4)

    def test_isotope_complete_ring_only(self):
        # Only the shared 1-2-2 stretch of the chain survives.
        confirm("isotopes.smi --atom-compare isotopes --complete-rings-only",
                expected_num_atoms=3, expected_num_bonds=2)



class TestBondTypes(unittest.TestCase):
    """--bond-compare (combined with --atom-compare) on ring/chain structures.

    Structures compared:
      C1CCCCC1OC#CC#CC#CC#CC#CC
      c1ccccc1ONCCCCCCCCCC second
    """

    def test_bond_compare_default(self):
        # The 'CCCCCC' of the first ring against the second's tail.
        confirm("bondtypes.smi", expected_num_atoms=6, expected_num_bonds=5)

    def test_bond_compare_bondtypes(self):
        # Explicitly selecting the default must give the same answer.
        confirm("bondtypes.smi --bond-compare bondtypes",
                expected_num_atoms=6, expected_num_bonds=5)

    def test_bond_compare_any(self):
        # The CC#CC chain lines up with the CCCC tail.
        confirm("bondtypes.smi --bond-compare any",
                expected_num_atoms=10, expected_num_bonds=9)

    def test_atom_compare_elements_bond_compare_any(self):
        confirm("bondtypes.smi --atom-compare elements --bond-compare any",
                expected_num_atoms=10, expected_num_bonds=9)

    def test_atom_compare_any_bond_compare_any(self):
        # A complete match of both structures.
        confirm("bondtypes.smi --atom-compare any --bond-compare any",
                expected_num_atoms=18, expected_num_bonds=18)

class TestCompareOption(unittest.TestCase):
    """The --compare shorthand; expected sizes mirror TestBondTypes."""

    def test_compare_topology(self):
        confirm("bondtypes.smi --compare topology", expected_num_atoms=18, expected_num_bonds=18)

    def test_compare_elements(self):
        confirm("bondtypes.smi --compare elements", expected_num_atoms=10, expected_num_bonds=9)

    def test_compare_types(self):
        confirm("bondtypes.smi --compare types", expected_num_atoms=6, expected_num_bonds=5)


class TestRingMatchesRingOnly(unittest.TestCase):
    """--ring-matches-ring-only on subsets of rings.smi (selected with --select)."""
    # C12CCCC(N2)CCCC1 6-and-7-bridge-rings-with-N
    # C1CCCCN1 6-ring
    # C1CCCCCN1 7-ring
    # C1CCCCCCCC1 9-ring
    # NC1CCCCCC1 N+7-ring
    # C1CC1CCCCCC 3-ring-with-tail
    # C12CCCC(O2)CCCC1 6-and-7-bridge-rings-with-O
    def test_default(self):
        # Should match 'CCCCC'
        confirm("rings.smi", 5, 4)
    def test_ring_only(self):
        # Should match "CCC"
        confirm("rings.smi --ring-matches-ring-only", 3, 2)
    def test_ring_only_select_1_2(self):
        # Should match the whole 6-ring "C1CCCCN1"
        confirm("rings.smi --ring-matches-ring-only --select 1,2", 6, 6)
    def test_ring_only_select_1_3(self):
        # Should match the whole 7-ring "C1CCCCCN1"
        confirm("rings.smi --ring-matches-ring-only --select 1,3", 7, 7)
    def test_ring_only_select_1_4(self):
        # Should match "C1CCCCCCCC1"
        confirm("rings.smi --ring-matches-ring-only --select 1,4", 9, 9)
    def test_select_1_5(self):
        # Should match "NCCCCCCC" (the N plus a 7-carbon path)
        confirm("rings.smi --select 1,5", 8, 7)
    def test_ring_only_select_1_5(self):
        # Should match "CCCCCCC" (7 ring carbons as an open chain)
        confirm("rings.smi --ring-matches-ring-only --select 1,5", 7, 6)
    def test_select_1_6(self):
        # Should match "CCCCCCCCC" by breaking one of the 3-carbon ring bonds
        confirm("rings.smi --select 1,6", 9, 8)
    def test_ring_only_select_1_6(self):
        # Should match "CCC" from the three atom ring
        confirm("rings.smi --ring-matches-ring-only --select 1,6", 3, 2)
    def test_ring_only_select_1_7(self):
        # Should match the outer ring "C1CCCCCCCC1"
        confirm("rings.smi --ring-matches-ring-only --select 1,7", 9, 9)
    def test_ring_only_select_1_7_any_atoms(self):
        # Should match everything
        confirm("rings.smi --ring-matches-ring-only --select 1,7 --atom-compare any", 10, 11)

class TestCompleteRingsOnly(unittest.TestCase):
    """--complete-rings-only on subsets of rings.smi (selected with --select)."""
    # C12CCCC(N2)CCCC1 6-and-7-bridge-rings-with-N
    # C1CCCCN1 6-ring
    # C1CCCCCN1 7-ring
    # C1CCCCCCCC1 9-ring
    # NC1CCCCCC1 N+7-ring
    # C1CC1CCCCCC 3-ring-with-tail
    # C12CCCC(O2)CCCC1 6-and-7-bridge-rings-with-O
    def test_ring_only(self):
        # No match: "CCC" is not in a ring
        confirm("rings.smi --complete-rings-only", None, None)
    def test_ring_only_select_1_2(self):
        # Should match the whole 6-ring "C1CCCCN1"
        confirm("rings.smi --complete-rings-only --select 1,2", 6, 6)
    def test_ring_only_select_1_3(self):
        # Should match the whole 7-ring "C1CCCCCN1"
        confirm("rings.smi --complete-rings-only --select 1,3", 7, 7)
    def test_ring_only_select_1_4(self):
        # Should match "C1CCCCCCCC1"
        confirm("rings.smi --complete-rings-only --select 1,4", 9, 9)
    def test_ring_only_select_1_5(self):
        # No match: the shared 7-carbon chain "CCCCCCC" is not a complete ring
        confirm("rings.smi --complete-rings-only --select 1,5", None, None)
    def test_ring_only_select_1_7(self):
        # Should match the outer ring "C1CCCCCCCC1"
        confirm("rings.smi --complete-rings-only --select 1,7", 9, 9)
    def test_ring_only_select_1_7_any_atoms(self):
        # Should match everything
        confirm("rings.smi --complete-rings-only --select 1,7 --atom-compare any", 10, 11)

    def test_ring_to_nonring_bond(self):
        # Should allow the cO in phenol to match the CO in the other structure
        confirm("atomtypes.smi --complete-rings-only", 2, 1)

class TestSelect(unittest.TestCase):
    """The --select option: ranges, open-ended ranges and syntax errors.

    Some selections are also covered by the ring tests above.
    """
    def test_select_range(self):
        # CCCCCN
        confirm("rings.smi --select 1-3", 6, 5)
    def test_select_multirange(self):
        # CCCCC
        confirm("rings.smi --select 1-3,6", 5, 4)
    def test_to_end(self):
        # CCC (3 atoms, 2 bonds)
        confirm("rings.smi --ring-matches-ring-only --select 6-", 3, 2)

    # Each malformed selection must exit and report the 1-based position
    # of the first bad character.
    def test_range_start_not_integer(self):
        _, stderr = run("rings.smi --select A-9", expect_sysexit=True)
        self.assertIn("Unknown character at position 1 of 'A-9'", stderr, stderr)

    def test_unkown_after_first_index(self):
        _, stderr = run("rings.smi --select 3A-9", expect_sysexit=True)
        self.assertIn("Unknown character at position 2 of '3A-9'", stderr, stderr)

    def test_range_end_not_integer(self):
        _, stderr = run("rings.smi --select 3-A9", expect_sysexit=True)
        self.assertIn("Unknown character at position 3 of '3-A9'", stderr, stderr)

    def test_unknown_after_range_end(self):
        _, stderr = run("rings.smi --select 3-9A", expect_sysexit=True)
        self.assertIn("Unknown character at position 4 of '3-9A'", stderr, stderr)

    def test_multiple_commas(self):
        _, stderr = run("rings.smi --select 1,,2", expect_sysexit=True)
        self.assertIn("Unknown character at position 3 of '1,,2'", stderr, stderr)

class TestTimeout(unittest.TestCase):
    """The --timeout option: enforcement, 'none', and negative values."""

    def test_timeout(self):
        # lengthy.smi should take 12+ seconds to process; allow only 0.1.
        start = time.time()
        out_text, _ = run("lengthy.smi --timeout 0.1")
        elapsed = time.time() - start
        self.assertLess(elapsed, 0.5, elapsed)
        self.assertIn("(timed out)", out_text, out_text)

    def test_timeout_none(self):
        # "--timeout none" must be accepted as meaning no limit.
        confirm("atomtypes.smi --timeout none", expected_num_atoms=4, expected_num_bonds=3)

    def test_timeout_negative(self):
        # Only non-negative timeouts are valid.
        _, err_text = run("atomtypes.smi --timeout -1.0", expect_sysexit=True)
        self.assertIn("Must be a non-negative value, not '-1.0'", err_text)

class TestTimes(unittest.TestCase):
    """--times writes a timing report to stderr."""

    def test_times(self):
        _, err_text = run("atomtypes.smi --times")
        expected_phrases = ("Total time", "load", "fragment", "select", "enumerate", "MCS found after")
        for phrase in expected_phrases:
            self.assertIn(phrase, err_text, (phrase, err_text))

class TestVerbosity(unittest.TestCase):
    """--verbose (-v) levels and the diagnostics each level enables."""

    def _assert_times_enabled(self, err_text):
        # Any --verbose level also turns on the --times report.
        for phrase in ("Total time", "load", "fragment", "select", "enumerate", "MCS found after"):
            self.assertIn(phrase, err_text, (phrase, err_text))

    def test_verbose_1(self):
        _, err_text = run("atomtypes.smi --verbose")
        self.assertIn("Loaded ", err_text, err_text)
        # Level-2 progress messages must not appear at level 1.
        for absent in ("Best after ", "unique SMARTS, cache: ", "subgraphs enumerated, "):
            self.assertNotIn(absent, err_text, err_text)
        self._assert_times_enabled(err_text)

    def test_verbose_2(self):
        _, err_text = run("atomtypes.smi --verbose -v")
        for present in ("Loaded ", "Best after ", "unique SMARTS, cache: ", "subgraphs enumerated, "):
            self.assertIn(present, err_text, err_text)
        self._assert_times_enabled(err_text)


class TestOutputFormatFragmentSDF(unittest.TestCase):
    """Regression test for ``--output-format fragment-sdf``.

    RDKit fragmentation sometimes (but when?) requires a
    Chem.FastFindRings() or SSSR() perception to get the right
    chemistry on the newly created molecule. The original
    implementation didn't do this, and failed with
    "RuntimeError: Pre-condition Violation":
      Pre-condition Violation
      not initialized
      Violation occurred on line 67 in file .... Code/GraphMol/RingInfo.cpp
      Failed Expression: df_init
    """

    def test_failing_ring(self):
        # Passes as long as the run completes without raising.
        run("../sample_files/ace.sdf --output-format fragment-sdf")



if __name__ == "__main__":
    # Run every TestCase defined in this module (same as the default module).
    unittest.main(module="__main__")
