#$Id: molNavigator.py 110 2011-07-07 21:15:03Z sarkiss $
"""Adds "Molecules" page to Navigator and provides mechanism for displaying molecules.
"""
import  os, wx, vtk, pickle
from enthought.tvtk import messenger
from Pmv.pmvPalettes import AtomElements
from PyBabel.babelElements import babel_elements
from MolKit.protein import Protein, Chain, Residue
from MolKit.molecule import Atom, Molecule, AtomSet, MoleculeSet
from MolKit.pdbParser import PdbqtParser
import wx.lib.customtreectrl as CT
from MolKit.pdbParser import PdbParser
from icons import molPNG, chainPNG, residuePNG, atomPNG
ID_CLEAR = wx.NewId()
# Legacy per-element RGB colours (components in 0..1) merged into Pmv's
# AtomElements table. The 'A' entry is what GenerateAssambly falls back to
# for elements that have no colour of their own.
OldAtomElements = dict(N=(0., 0., 1.), C=(.7, .7, .7), O=(1., 0., 0.),
                       H=(0., 1., 1.), S=(1., 1., 0.), A=(0., 1., 0.))
AtomElements.update(OldAtomElements)
class MolNavigator(CT.CustomTreeCtrl):
    "Molecule Navigator"
    def __init__(self, frame):
        """Create the tree, attach it as the "Molecules" page of frame.navigator.

        frame.molecules is aliased to self.molecules, which is also the PyData
        of the (hidden) root item.
        """
        agwStyle = (wx.SUNKEN_BORDER | CT.TR_HAS_BUTTONS | CT.TR_HIDE_ROOT
                    | CT.TR_AUTO_CHECK_PARENT | CT.TR_NO_LINES
                    | CT.TR_AUTO_CHECK_CHILD | CT.TR_AUTO_TOGGLE_CHILD
                    | wx.TR_MULTIPLE)
        CT.CustomTreeCtrl.__init__(self, frame.navigator, id=-1, agwStyle=agwStyle)

        self.images = []
        # Image indices used by the tree items: 0 molecule, 1 chain, 2 residue, 3 atom.
        imageList = wx.ImageList(16, 16)
        for icon in (molPNG, chainPNG, residuePNG, atomPNG):
            imageList.Add(icon)
        self.AssignImageList(imageList)

        handlers = (
            (CT.EVT_TREE_ITEM_CHECKED, self.EvtCheckListBox),
            (CT.EVT_TREE_SEL_CHANGED, self.OnSelChanged),
            (wx.EVT_RIGHT_UP, self.OnRightUp),
        )
        for eventType, handler in handlers:
            self.Bind(eventType, handler)

        self.root = self.AddRoot("Root")
        self.molecules = []
        self.SetPyData(self.root, self.molecules)
        self.moleculesNames = []
        frame.navigator.AddPage(self, "Molecules", bitmap=atomPNG)
        self.frame = frame
        frame.molecules = self.molecules
        self.selectionAssembly = None
        self.toggleSelection = False
        self.EnableSelectionGradient(False)
        self.EnableSelectionVista(True)

        
    def OnMakeLigand(self, event):
        """Turn the molecule or chain under self.item into an AutoDock ligand.

        A whole molecule (a node with 'chains') is passed to
        autodockNav.AddLigand as-is and keeps its tree position.
        A chain is wrapped in a new Protein, named after the chain or its first
        residue. Its atoms are removed from the parent molecule's allAtoms
        before the conversion.
        On success the tree entry is replaced by the molecule that AddLigand
        returns. Every modal dialog shown here is destroyed on all exit paths.
        """
        residuesHOH = [] # water residues
        erroMsg = "Make Ligand Failed."
        data = self.GetPyData(self.item)
        if hasattr(data,'chains'):
            # Whole molecule: remember its position so the ligand replaces it in place.
            index = self.root._children.index(self.item)
            mol = data
        else:
            # Single chain: build a stand-alone Protein around it.
            index = None
            prot = Protein()
            prot.chains.append(data)
            if len(data.residues) == 1:
                data.name = data.residues[0].name
            if data.name.strip():
                prot.name = data.name.strip()
            else:
                prot.name = data.parent.name + "_"

            if data.residues:
                prot.name += data.residues[0].name

            for residue in data.residues:
                if "HOH" in residue.name:
                    residuesHOH.append(residue)

            if residuesHOH ==data.residues:
                erroMsg = data.name +" contains only water (HOH)."

            prot.allAtoms = data.residues.atoms
            prot.allAtoms.top = prot
            prot.parser = data.parent.parser
            mol = prot
            # NOTE(review): these atoms stay detached from the parent even if the
            # user cancels below.
            for atom in prot.allAtoms:
                data.parent.allAtoms.remove(atom)
        try:
            if len(mol.allAtoms) > 500:
                dlg = wx.MessageDialog(self, "Warning: This molecule has " + str(len(mol.allAtoms))+
                                       " atoms and it might take long time to make a ligand out of it. \n\nWould you like to cancel this job?",
                               'Make Ligand Warning!', wx.YES_NO)
                try:
                    cancel = dlg.ShowModal() == wx.ID_YES
                finally:
                    # Destroy on every path; the dialog used to leak when the user cancelled.
                    dlg.Destroy()
                if cancel:
                    return

            mol = self.frame.autodockNav.AddLigand(mol)
        except Exception:
            if len(mol.allAtoms) == 0:
                dlg = wx.MessageDialog(self, erroMsg+"\nRemove "+mol.name +" from viewer?",
                               'Make Ligand Failed', wx.OK|wx.CANCEL,
                               )
                try:
                    remove = dlg.ShowModal() == wx.ID_OK
                finally:
                    dlg.Destroy()
                if remove:
                    self.OnRemove(None)
                return
            else:
                wx.MessageBox(erroMsg, "Make Ligand Failed")
                raise

        self.ext = 'PDBQT'
        if not mol: return # error message from AddLigand appears in Logger
        self.OnRemove(None)
        self.AddMolecule(mol, index, resetCamera=False)
        self.Rerender()
        self.frame.shell.prompt() # for adding gasteiger charges to peptide
        txt = "Wrote " + mol.name +" to " + self.frame.vsModel.ligandsFolder
        self.frame.statusBar.SetStatusText(txt, 0)
        
    def OnMakeMacromolecule(self, event):
        """Turn the molecule or chain under self.item into an AutoDock macromolecule.

        A chain is first wrapped in a new Protein named "<molecule>_<chain>".
        The file path returned by autodockNav.AddMacromolecule is read back and
        replaces the original tree entry (whole molecules keep their index).
        The modal progress dialog is destroyed on every exit path, including
        the early return when AddMacromolecule fails and any exception.
        """
        dlg = wx.ProgressDialog("Please Wait...",
                               "Making AutoDock Macromolecule...",
                               parent=self.frame,
                               style = wx.PD_APP_MODAL  )
        try:
            data = self.GetPyData(self.item)
            if hasattr(data,'chains'):
                index = self.root._children.index(self.item)
                mol = data
            else:
                index = None
                prot = Protein()
                prot.chains.append(data)
                prot.name = data.parent.name+"_"
                if data.name: # append the chain name only when the chain has one
                    prot.name += data.name
                prot.parser = data.parent.parser
                mol = prot
            macromoleculePath = self.frame.autodockNav.AddMacromolecule(mol)
            if not macromoleculePath:
                # Previously this returned without destroying the progress dialog.
                return
            self.ext = 'pdbqt'
            self.OnRemove(None)
            mol = self.ReadMolecule(macromoleculePath)[0]
            self.AddMolecule(mol, index=index, resetCamera=False)
            self.Rerender()
            self.frame.shell.prompt() # adding gasteiger charges to peptide
            txt = "Wrote " + mol.name +" to " + self.frame.vsModel.macromoleculePath
            self.frame.statusBar.SetStatusText(txt, 0)
        finally:
            dlg.Destroy()
    def OnFlexResidues(self, event):
        """Send the selected residues to AutoDock as flexible residues.

        Afterwards self.item points at the molecule owning the first selection.
        The generated flexres file is opened if vsModel reports flexible
        residues; otherwise a warning is logged.
        """
        selectedData = [self.GetPyData(node) for node in self.GetSelections()]
        flexRes = [res for res in selectedData if isResidue(res)]
        self.frame.autodockNav.AddFlexResidues(flexRes)
        self.ext = 'pdbqt'

        owner = self.GetSelections()[0].GetParent()
        if not isinstance(self.GetPyData(owner), Molecule):
            # the residue sits under a chain item; climb one more level
            owner = owner.GetParent()
        self.item = owner

        vsModel = self.frame.vsModel
        if not vsModel.flex_residues:
            self.frame.log.warn("Selected residues are not flexible: "+str(flexRes))
        else:
            self.TryOpenMolecule(vsModel.flexres_filename)
        self.Rerender()

    def TryOpenMolecule(self, filename):
        "Run OpenMolecule(filename) through frame.TryCommand so failures are logged."
        runGuarded = self.frame.TryCommand
        return runGuarded(self.OpenMolecule, filename)
        
    def OpenMolecule(self, filename):
        """Parse `filename` and add every molecule it contains to the navigator.

        Returns the list of molecules read. Returns None, after logging an
        error, when the file does not exist or yields no molecule.
        """
        frame = self.frame
        if not os.path.exists(filename):
            frame.log.error("File does not exist: "+filename)
            return None
        frame.statusBar.SetStatusText("Parsing %s. Please Wait..."%(filename), 0)
        mols = self.ReadMolecule(filename)
        if not mols:
            frame.log.error("No molecule in "+filename)
            return None
        for mol in mols:
            mol.lenAtoms = len(mol.allAtoms)
            summary = "Read %s - %d chain(s) - %d atoms" % (
                mol.name, len(mol.chains), mol.lenAtoms)
            frame.statusBar.SetStatusText(summary, 0)
            frame.fileHistory.AddFileToHistory(filename)
            self.AddMolecule(mol)
        return mols
        
#    def AddBonds(self, mol, name=None, force=False):
#        if name == None:
#            name = mol.name
#        fileName = mol.parser.filename
#        if force or self.frame.vsModel.ligandsFolder in fileName:
#            pickleFileName = os.path.join(self.frame.vsModel.etcFolder, name + ".pkl")
#            if os.path.exists(pickleFileName):
#                dict = pickle.load(open(pickleFileName))
#                if dict.has_key('CONECT'):
#                    mol.allAtoms.bonds[0].data = []#clear bond records
#                    parser = PdbParser()
#                    parser.mol = mol
#                    parser.parse_PDB_CONECT(dict['CONECT'])
#                    return
#        self.frame.TryCommand(mol.buildBondsByDistance)
        
    def AddMolecule(self, mol, index=None, resetCamera=True):
        """Insert `mol` into the tree and the 3D scene.

        mol         : MolKit molecule. Its name is extended with '.' characters
                      until it differs from every entry in self.moleculesNames.
        index       : position in self.molecules and among the root's children;
                      the molecule is appended when None.
        resetCamera : when true, reset the renderer camera and refresh the canvas.

        A molecule with several chains gets one checkable tree item and one
        vtkAssembly part per chain. A single-chain molecule is rendered as one
        assembly. self.item is left pointing at the new tree item.
        """
        if index is None:
            index = len(self.molecules)
        self.molecules.insert(index, mol)
        name = mol.name
        while name in self.moleculesNames:
            name = name+'.'
        mol.name = name
        self.moleculesNames.insert(index, name)
        molTreeID =  self.InsertItemByIndex(self.root, index, name, ct_type=1, image=0 )
        self.SetPyData(molTreeID, mol)
        if hasattr(mol, 'geomContainer'):
            assembly = mol.geomContainer.masterGeom.obj # from ePMV
        else:
            assembly = vtk.vtkAssembly()
        if len(mol.chains) > 1:
            for chain in mol.chains:
                part = self.GenerateAssambly(chain.residues.atoms)
                part.chainName = chain.name
                item = self.AppendItem(molTreeID, chain.name, ct_type=1, image=1)
                part.treeID = item
                # this makes self.object._vtk_ob in vtkSupport.TVTKBranchNode remember the chainName attribute
                chain.assembly = part
                assembly.AddPart(part)
                self.SetPyData(item, chain)
                # keep the tree check boxes in sync with VTK visibility changes
                part.AddObserver('ModifiedEvent', messenger.send)
                messenger.connect(part, 'ModifiedEvent', self.OnModified)
                self.BuildResidueTree(item, chain)
        else:
            assembly = self.GenerateAssambly(mol.allAtoms)
            self.BuildResidueTree(molTreeID, mol.chains[0])
        self.CheckItem2(molTreeID, True)
        self.CheckChilds(molTreeID, True)
        assembly.AddObserver('ModifiedEvent', messenger.send)
        messenger.connect(assembly, 'ModifiedEvent', self.OnModified)
        mol.assembly = assembly
        assembly.molName = mol.name # this is needed for our VTK Pipeline browser
        assembly.treeID = molTreeID
        self.frame.renderer3D.AddActor(assembly)
        if resetCamera:
            self.frame.renderer3D.ResetCamera()
            self.frame.canvas3D.Refresh()
        self.frame.view.SetSelection(0) # 3D Viewer
        self.frame.navigator.SetSelection(0) # Molecules
        self.item = molTreeID
        
    def BuildResidueTree(self, treeID, chain):
        """Add residue items (and their atoms) for `chain` under `treeID`.

        A chain with a single residue gets no residue item of its own; its
        atoms are attached directly under `treeID`.
        """
        residues = chain.residues
        if len(residues) <= 1:
            self.BuildAtomTree(treeID, residues[0])
            return
        for res in residues:
            resItem = self.AppendItem(treeID, res.name,  image=2)
            self.SetPyData(resItem, res)
            self.BuildAtomTree(resItem, res)
                   
    def BuildAtomTree(self, treeID, residue):
        "Append one atom item (image 3) per atom of `residue` under `treeID`."
        for atom in residue.atoms:
            atomItem = self.AppendItem(treeID, atom.name,  image=3)
            self.SetPyData(atomItem, atom)
        
    def OnModified(self, obj, evt):
        """VTK ModifiedEvent handler: mirror obj's visibility onto its tree check box.

        Objects without a treeID attribute are ignored.
        """
        if hasattr(obj, 'treeID'):
            self.CheckItem2(obj.treeID, bool(obj.GetVisibility()))
            self.Rerender()
                          
    def ReadMolecule(self, filename):
        "Delegate parsing of `filename` to the embedded Pmv viewer's readMolecule."
        viewer = self.frame.pmv.mv
        return viewer.readMolecule(filename)
        
    def GenerateAssambly(self, atoms, selection=False):
        lenAtoms = len(atoms)
        atoms[0].lenAtoms = lenAtoms # hold to this variables since it used in other places
        self.frame.progressTextSuffix = None
        assembly = vtk.vtkAssembly()
        glyph = vtk.vtkGlyph3D() # used to hold spheres
        lut = vtk.vtkLookupTable() #used for coloring
        atoms[0].assembly = assembly
        assembly.glyph = glyph
        assembly.lut = lut
        if lenAtoms > 5000:
            self.frame.progressText = "Building geometries for "+atoms.getStringRepr()
            self.frame.ConfigureProgressBar(max = lenAtoms/10 + 1)
        else:
            self.frame.progressCount = 0
            self.frame.progressMax = -1

        if lenAtoms > 100:
            resolution = 9
        else:
            resolution = 20
        atomSet = set(atoms.element)
        lenAtomSet = len(atomSet)
        atoms[0].lenAtomSet = lenAtomSet
        assembly.lutLength = lenAtomSet+2
        lut.SetNumberOfTableValues(lenAtomSet+2)
        errorTxt = ""
        for i, atom in enumerate(atomSet):  #build spheres for each type of atom
            sphere = vtk.vtkSphereSource()
            sphere.SetThetaResolution(resolution)
            sphere.SetPhiResolution(resolution)
            if len(atom) == 2:
                atom = atom[0].upper()+atom[1].lower()
            try:
                rad = babel_elements[atom]["bs_rad"]
            except KeyError, inst:
                errorTxt += atom +" "
                rad = 0.3
            rad = rad/1.3
            if selection:
                sphere.SetRadius(rad+0.2)
            else:
                sphere.SetRadius(rad)                
            glyph.SetSource(i, sphere.GetOutput())
            if atom in AtomElements:
                color = AtomElements[atom]
            else:
                color = AtomElements['A']
            if selection:
                lut.SetTableValue(i, 1, 0.5, 0.8, 0.8)
            else:
                lut.SetTableValue(i, color[0], color[1], color[2], 1)
        
        lut.SetTableValue(lenAtomSet, 1, 1, 0.5, 1) #selected
        lut.SetTableValue(lenAtomSet+1, 0, 0, 0, 0) #hide
        
        #setup the glyph
        polys = vtk.vtkCellArray() #holds lines between atoms
        points = vtk.vtkPoints()   #holds atom coordinates
        points.SetDataTypeToFloat()
        points.SetNumberOfPoints(lenAtoms)
        createIndices = vtk.vtkIdTypeArray() #holds atom type
        createIndices.SetNumberOfValues(lenAtoms)
        createCellArray = vtk.vtkCellArray() #each atom is in a unique cell
        createCellArray.Allocate(lenAtoms,1)
        atomList = list(atomSet)
        atoms.bonds #this is needed to create atom._bndIndex_
        for i, atom in enumerate(atoms):
            points.SetPoint(i, atom.coords)
            createIndices.SetValue(i, atomList.index(atom.element))
            createCellArray.InsertNextCell(1)
            createCellArray.InsertCellPoint(i) 
            for bond in atom.bonds:
                if not selection:
                    try:
                        index1 = bond.atom1._bndIndex_
                        index2 = bond.atom2._bndIndex_
                        polys.InsertNextCell(2)
                        polys.InsertCellPoint(bond.atom1._bndIndex_)
                        polys.InsertCellPoint(bond.atom2._bndIndex_)  
                    except AttributeError, inst: #this happened for PDBID:1vsn (ligand attached to molcule)
                        print inst, __file__
            if i%10:
                self.frame.UpdateProgressBar()
        polyData = vtk.vtkPolyData()
        polyData.SetPoints(points)
        polyData.GetPointData().SetScalars(createIndices)
        polyData.SetVerts(createCellArray)
        polyData.SetLines(polys)
        assembly.polyData = polyData
        assembly.points = points
        if errorTxt:
            self.frame.log.error("Can't find Bable atom element for "+ errorTxt +" from " + atoms[0].top.name )
        if lenAtoms > 200 and not selection:
            return self.DisplayLines(atoms)
        else:
            return self.DisplayBallsAndSticks(atoms)

    def DisplayBallsAndSticks(self, atoms):
        """
    Used to display atoms as balls and sticks.
    arguments: atoms - AtomSet previously passed to GenerateAssambly; the
               glyph, lookup table and polydata built there are read back
               from atoms[0].assembly.
    return   : that vtkAssembly, with a tube (sticks) actor and a glyph
               (balls) actor added as parts and stored as tubeActor/glyphActor.
        """
        lenAtomSet = atoms[0].lenAtomSet
        assembly = atoms[0].assembly
        glyph = assembly.glyph
        lut = assembly.lut
        polyData = assembly.polyData
        # pick the glyph source by the per-point scalar (the element index)
        glyph.SetScaleModeToDataScalingOff()
        glyph.SetRange(0, lenAtomSet-1)
        glyph.SetIndexModeToScalar()
        glyph.SetOrient(1)
        glyphMapper = vtk.vtkPolyDataMapper()
        glyphMapper.SetInput(glyph.GetOutput())
        glyphMapper.SetLookupTable(lut)
        glyphMapper.SetScalarRange(0, assembly.lutLength)
        glyphActor = vtk.vtkLODActor()
        glyphActor.SetMapper(glyphMapper)
        glyph.SetInput(polyData)
        glyph.GeneratePointIdsOn()
        # sticks: tubes around the bond lines of the same polydata
        tuber = vtk.vtkTubeFilter()
        tuber.SetInput(polyData)
        tuber.SetNumberOfSides(8)
        tuber.SetCapping(0)
        tuber.SetRadius(0.17)
        tuber.SetVaryRadius(0)
        tuber.SetRadiusFactor(10)
        tubeMapper = vtk.vtkPolyDataMapper()
        tubeMapper.SetInputConnection(tuber.GetOutputPort())
        tubeMapper.SetLookupTable(lut)
        tubeMapper.SetScalarRange(0, assembly.lutLength)
        tubeActor = vtk.vtkLODActor()
        tubeActor.SetMapper(tubeMapper)
        self.SetDefaultMaterials(tubeActor)
        self.SetDefaultMaterials(glyphActor)
        assembly.tuber = tuber
        assembly.AddPart(tubeActor)
        assembly.AddPart(glyphActor)
        assembly.tubeActor = tubeActor
        assembly.glyphActor = glyphActor
        self.frame.UpdateProgressBar()
        return assembly
            
    def DisplayLines(self, atoms):
        """Display atoms as lines (bonds drawn directly from the polydata).

        atoms : AtomSet previously passed to GenerateAssambly; its assembly,
                lookup table and polydata are read back from atoms[0].assembly.
        Returns that assembly, with the new actor stored as lineActor.
        """
        assembly = atoms[0].assembly
        lut = assembly.lut
        polyData = assembly.polyData
        lineMapper = vtk.vtkPolyDataMapper()
        lineMapper.SetInput(polyData)
        lineMapper.SetLookupTable(lut)
        lineMapper.SetScalarRange(0, assembly.lutLength)
        lineActor = vtk.vtkLODActor()
        lineActor.SetMapper(lineMapper)
        self.SetDefaultMaterials(lineActor)
        assembly.AddPart(lineActor)
        assembly.lineActor = lineActor
        self.frame.UpdateProgressBar()
        return assembly

    def DisplayRibbons(self, atoms):
        "Show `atoms` as extruded secondary structure via Pmv's displayExtrudedSS."
        viewer = self.frame.pmv.mv
        viewer.displayExtrudedSS(atoms)

    def SetDefaultMaterials(self, actor):
        "Apply the shared surface/Gouraud/lighting material settings to `actor`."
        prop = actor.GetProperty()
        prop.SetRepresentationToSurface()
        prop.SetInterpolationToGouraud()
        prop.SetAmbient(0.15)
        prop.SetDiffuse(0.85)
        prop.SetSpecular(0.1)
        prop.SetSpecularPower(100)
        prop.SetSpecularColor(1,1,1)
        prop.SetColor(1,1,1)

    def EvtCheckListBox(self, event):
        "EVT_TREE_ITEM_CHECKED handler: sync the checked item's visibility."
        checkedItem = event.GetItem()
        return self.ToggleItemVisibility(checkedItem)
        
    def ToggleItemVisibility(self, item):
        "Show the item's assembly if its check box is ticked, hide it otherwise."
        node = self.GetPyData(item)
        self.SetVisibility(node.assembly, 1 if self.IsItemChecked(item) else 0)
        self.Rerender()

    def ToggleSelectionsVisibility(self, event):
        """Invert the check state of every selected item.

        The visibility update itself is driven by the resulting check events.
        """
        for selected in self.GetSelections():
            wasChecked = self.IsItemChecked(selected)
            self.CheckItem(selected, not wasChecked)

    def SetVisibility(self, assembly, flag=1):
        """Set visibility on `assembly` and, recursively, on all of its parts.

        Needed to work around the nested vtkAssembly visibility bug
        http://www.vtk.org/Bug/view.php?id=3312 . Objects without GetParts
        (plain actors, other VTK props) only get their own visibility set.
        """
        assembly.SetVisibility(flag)
        if not hasattr(assembly, 'GetParts'): # to handle other vtk objects
            return
        parts = assembly.GetParts()
        # Iterate the direct parts themselves. GetNumberOfPaths() counts leaf
        # paths, which is not the number of entries in this collection.
        for i in range(parts.GetNumberOfItems()):
            part = parts.GetItemAsObject(i)
            if part is None:
                continue
            if isinstance(part, vtk.vtkAssembly):
                self.SetVisibility(part, flag) # also sets part's own visibility
            else:
                part.SetVisibility(flag)
                        
    def OnDisplayLines(self, event):
        "Menu handler for Display -> Lines (toggles/creates each chain's lineActor)."
        self.OnDisplayComand(event, actorStr='lineActor', command=self.DisplayLines)
        
    def OnDisplayBallsAndSticks(self, event):
        """Menu handler for Display -> Balls and Sticks.

        OnDisplayComand toggles (or creates) the tube actor of every chain. The
        glyph (ball) actor of every chain is then synced to its tube actor.
        The original code synced only the last chain, and only after the
        scene had already been re-rendered.
        """
        self.OnDisplayComand(event, 'tubeActor', self.DisplayBallsAndSticks)
        pyData = self.GetPyData(self.item)
        if hasattr(pyData, 'chains'):
            chains = pyData.chains
        else:
            chains = [pyData]
        for chain in chains:
            chainAssembly = chain.residues.atoms[0].assembly
            # balls and sticks have not been built for this chain (e.g. unchecked before creation)
            if not (hasattr(chainAssembly, 'glyphActor') and hasattr(chainAssembly, 'tubeActor')):
                continue
            self.SetVisibility(chainAssembly.glyphActor,
                               chainAssembly.tubeActor.GetVisibility())
        self.Rerender()
        
    def OnDisplayRibbons(self, event):
        "Menu handler for Display -> Ribbons (toggles/creates ribbons_assembly)."
        self.OnDisplayComand(event, actorStr='ribbons_assembly', command=self.DisplayRibbons)
            
    def OnDisplayComand(self, event, actorStr, command):
        """Toggle the representation `actorStr` for every chain of self.item.

        If a chain's assembly (chain.residues.atoms[0].assembly) already has the
        attribute `actorStr`, that actor's visibility is set to the menu
        check state. Otherwise, when checked, command(chain atoms) is called
        to build it.
        Returns the assembly of the last chain processed, or None if there
        are no chains.
        """
        isChecked = event.IsChecked()
        pyData = self.GetPyData(self.item)
        if hasattr(pyData, 'chains'):
            chains = pyData.chains
        else:
            chains = [pyData]
        assembly = None # previously unbound (NameError) for an empty chain list
        for chain in chains:
            assembly = chain.residues.atoms[0].assembly
            if hasattr(assembly, actorStr):
                # getattr instead of eval('assembly.' + actorStr)
                actor = getattr(assembly, actorStr)
                self.SetVisibility(actor, isChecked)
            elif isChecked:
                command(chain.residues.atoms)
        self.Rerender()
        return assembly
    
    def OnHBonds(self, event):
        """Menu handler for Display -> Hydrogen Bonds.

        Toggles an existing hbondsActor, or computes and shows the hydrogen
        bonds on first check.
        """
        checked = event.IsChecked()
        mol = self.GetPyData(self.item)
        if not hasattr(mol, 'hbondsActor'):
            if checked:
                from hbonds import show_h_bonds
                show_h_bonds(self.frame, mol, mol)
            return
        mol.hbondsActor.SetVisibility(checked)
        
    def OnDisplaySufrace(self, event):
        """Menu handler for Display -> Molecular Surface.

        Toggles the existing surface grid, or asks the Mayavi engine to build
        one on first check.
        """
        checked = event.IsChecked()
        mol = self.GetPyData(self.item)
        if not hasattr(mol, 'grid'):
            if checked:
                self.frame.mayaviEngine.DisplayMolecularSurface(mol)
            return
        mol.grid.visible = checked
    
    
    def OnRightUp(self, event):
        """Build and pop up a context menu on wx.EVT_RIGHT_UP.

        Three cases:
        * more than one item selected: the menu depends on the node under the
          cursor. A residue gets "Flexible Residues" (only if all selections
          share its top-level molecule). A node with an 'assembly' gets Toggle
          Visibility / Remove Selected / Remove All. Anything else gets no menu.
        * otherwise, cursor on a tree item: self.item is set to it, and a
          Display / Label menu is built (plus AutoDock / Save / Remove entries
          for molecules and chains, or Flexible Residues for a residue).
        * cursor on empty space: Load Molecule / Remove All Molecules.
        """
        selection = self.GetSelections()
        if len(selection) > 1:
            pt = event.GetPosition()
            item, flags = self.HitTest(pt)
            # NOTE(review): HitTest can presumably return None (empty space);
            # GetPyData(None) is not guarded here.
            node = self.GetPyData(item)
            if not hasattr(node, 'assembly'):
                if isResidue(node):
                    # every selected node must belong to the same molecule
                    for i in selection:
                        if self.GetPyData(i).top != node.top:
                            return
                    menu = wx.Menu()
                    autodockMenu = wx.Menu()
                    flexResiduesMenu = autodockMenu.Append(wx.ID_ANY, "Flexible Residues")
                    self.Bind(wx.EVT_MENU, self.OnFlexResidues, flexResiduesMenu)
                    menu.AppendMenu(wx.ID_ANY, "AutoDock", autodockMenu)
                else:
                    return
            else:
                # molecule / chain under the cursor: bulk visibility and removal
                menu = wx.Menu()
                hideMenu = menu.Append(wx.ID_ANY, "Toggle Visibility")
                self.Bind(wx.EVT_MENU, self.ToggleSelectionsVisibility, hideMenu)
                removeMenu = menu.Append(wx.ID_ANY, "Remove Selected")
                self.Bind(wx.EVT_MENU, self.RemoveSelected, removeMenu)
                removeAllMenu = menu.Append(wx.ID_ANY, "Remove All")
                self.Bind(wx.EVT_MENU, self.OnRemoveAll, removeAllMenu)
            self.PopupMenu(menu)
        else:
            pt = event.GetPosition()
            item, flags = self.HitTest(pt)
            # the menu handlers below all act on self.item
            self.item = item
            if item and item != self.root:
                node = self.GetPyData(self.item)
                menu = wx.Menu()
                displayMenu = wx.Menu()
                if hasattr(node, 'allAtoms'):# and not hasattr(node, 'hetatm'):
                    # Check states mirror the current visibility of the actors stored
                    # on the node's first-atom assembly.
                    self.displayLinesMenu = displayMenu.Append(wx.ID_ANY, "Lines", kind=wx.ITEM_CHECK)
                    self.Bind(wx.EVT_MENU, self.OnDisplayLines, self.displayLinesMenu)
                    if hasattr(node.allAtoms[0].assembly, 'lineActor') and node.allAtoms[0].assembly.lineActor.GetVisibility():
                        self.displayLinesMenu.Check()
                    self.displayBallsAndSticksMenu = displayMenu.Append(wx.ID_ANY, "Balls and Sticks", kind=wx.ITEM_CHECK)
                    self.Bind(wx.EVT_MENU, self.OnDisplayBallsAndSticks, self.displayBallsAndSticksMenu)
                    if hasattr(node.allAtoms[0].assembly, 'tubeActor') and node.allAtoms[0].assembly.tubeActor.GetVisibility():
                        self.displayBallsAndSticksMenu.Check()
                    self.displayRibbonsMenu = displayMenu.Append(wx.ID_ANY, "Ribbons", kind=wx.ITEM_CHECK)
                    self.Bind(wx.EVT_MENU, self.OnDisplayRibbons, self.displayRibbonsMenu)
                    if hasattr(node.allAtoms[0].assembly, 'ribbons_assembly') and node.allAtoms[0].assembly.ribbons_assembly.GetVisibility():
                        self.displayRibbonsMenu.Check()
                    displayMenu.AppendSeparator()
                    hbondMenu = displayMenu.Append(wx.ID_ANY, "Hydrogen Bonds", kind=wx.ITEM_CHECK)
                    self.Bind(wx.EVT_MENU, self.OnHBonds, hbondMenu)
                    if hasattr(node, 'hbondsActor') and node.hbondsActor.GetVisibility():
                            hbondMenu.Check()

                    sufraceMenu = displayMenu.Append(wx.ID_ANY, "Molecular Surface", kind=wx.ITEM_CHECK)
                    self.Bind(wx.EVT_MENU, self.OnDisplaySufrace, sufraceMenu)
                    if hasattr(node, 'grid') and node.grid.visible:
                        # drop a grid that is no longer running so the next check rebuilds it
                        if not node.grid.running:
                            del node.grid
                        else:
                            sufraceMenu.Check()

                    displayMenu.AppendSeparator()
                labelSubmenu = wx.Menu()

                displayLabelAtomsMenu = labelSubmenu.Append(wx.ID_ANY, "Atoms", kind=wx.ITEM_CHECK)
                self.Bind(wx.EVT_MENU, self.OnDisplayAtomLabels, displayLabelAtomsMenu)
                atoms = node.findType(Atom)
                # NOTE(review): presumably true only when every atom in the set has a
                # vtkLabel (set-level attribute lookup) - confirm with MolKit semantics.
                if hasattr(atoms, 'vtkLabel'):
                    displayLabelAtomsMenu.Check()

                displayMenu.AppendMenu(wx.ID_ANY, "Label", labelSubmenu)
                menu.AppendMenu(wx.ID_ANY, "Display", displayMenu)
                if hasattr(node, 'assembly') or hasattr(node, 'residues'):
                    # molecule or chain: AutoDock conversion, save and remove
                    autodockMenu = wx.Menu()
                    makeLigandMenu = autodockMenu.Append(wx.ID_ANY, "Make Ligand")
                    self.Bind(wx.EVT_MENU, self.OnMakeLigand, makeLigandMenu)
                    makeMacromoleculeMenu = autodockMenu.Append(wx.ID_ANY, "Make Macromolecule")
                    self.Bind(wx.EVT_MENU, self.OnMakeMacromolecule, makeMacromoleculeMenu)
                    menu.AppendMenu(wx.ID_ANY, "AutoDock", autodockMenu)
                    menu.AppendSeparator()
                    saveMenu = menu.Append(wx.ID_ANY, "Save as PDB")
                    self.Bind(wx.EVT_MENU, self.OnSave, saveMenu)
                    removeMenu = menu.Append(wx.ID_ANY, "Remove from Scene")
                    self.Bind(wx.EVT_MENU, self.OnRemove, removeMenu)
                else:
                    if isResidue(node):
                        autodockMenu = wx.Menu()
                        flexResiduesMenu = autodockMenu.Append(wx.ID_ANY, "Flexible Residues")
                        self.Bind(wx.EVT_MENU, self.OnFlexResidues, flexResiduesMenu)
                        menu.AppendMenu(wx.ID_ANY, "AutoDock", autodockMenu)
                self.PopupMenu(menu)

            # clicked on empty space: load / clear menu
            if item == None:
                menu = wx.Menu()
                openMenu = menu.Append(wx.ID_ANY, "Load Molecule")
                self.Bind(wx.EVT_MENU, self.OnOpen, openMenu)
                if self.GetCount() > 0:
                    removeAllMenu = menu.Append(wx.ID_ANY, "Remove All Molecules")
                    self.Bind(wx.EVT_MENU, self.OnRemoveAll, removeAllMenu)
                self.PopupMenu(menu)
            
    def OnDisplayAtomLabels(self, event):
        "Menu handler for Display -> Label -> Atoms: add or remove atom-name labels."
        node = self.GetPyData(self.item)
        checked = event.IsChecked()
        nodeAtoms = node.findType(Atom)
        if not checked:
            self.RemoveAtomLabels(nodeAtoms)
        else:
            self.DisplayAtomLabels(nodeAtoms)
            
    def RemoveAtomLabels(self, atoms):
        "Remove each atom's vtkLabel actor from the renderer and forget it."
        labelled = [atom for atom in atoms if hasattr(atom, 'vtkLabel')]
        for atom in labelled:
            self.frame.renderer3D.RemoveActor(atom.vtkLabel)
            del atom.vtkLabel
        self.frame.canvas3D.Refresh()
        
    def DisplayAtomLabels(self, atoms):
        """Add a camera-facing text label showing the atom name at each atom.

        The actor is stored as atom.vtkLabel so RemoveAtomLabels can delete it.
        Atoms that already have a label are skipped. Creating a second label
        would overwrite atom.vtkLabel and leave the old actor in the renderer,
        where RemoveAtomLabels could no longer reach it. This can happen when
        a chain is labelled first and then its whole molecule.
        """
        renderer = self.frame.renderer3D
        for atom in atoms:
            if hasattr(atom, 'vtkLabel'):
                continue
            atext = vtk.vtkVectorText()
            atext.SetText(atom.name)
            textMapper = vtk.vtkPolyDataMapper()
            textMapper.SetInputConnection(atext.GetOutputPort())
            textActor = vtk.vtkFollower()
            textActor.SetMapper(textMapper)
            textActor.SetScale(0.3, 0.3, 0.3)
            textActor.SetPosition(atom.coords[0], atom.coords[1], atom.coords[2])
            renderer.AddActor(textActor)
            # a vtkFollower keeps facing this camera
            textActor.SetCamera(renderer.GetActiveCamera())
            atom.vtkLabel = textActor
        self.frame.canvas3D.Refresh()

    def OnOpen(self, event):
        "Context-menu 'Load Molecule': forward to the frame's File -> Open handler."
        openHandler = self.frame.OnFileOpenMenu
        return openHandler(event)

    def OnRemove(self, event):
        """Remove the molecule or chain under self.item from the tree and the scene.

        Atom labels of the node are removed first.
        Molecule: dropped from self.molecules / self.moleculesNames and from the
        Pmv viewer, its assembly is removed from the renderer, and a running
        surface grid is removed from the Mayavi scene. Errors in these steps
        are logged, not raised. The tree item is then deleted.
        Chain: its assembly is detached from the parent's assembly and the chain
        from its parent. If the parent has no chains left, the whole molecule
        is removed from the lists, renderer and tree.
        `event` is unused, so this is also called directly with None.
        """
        data = self.GetPyData(self.item)
        atoms = data.findType(Atom)
        self.RemoveAtomLabels(atoms)
        if hasattr(data,'chains'):
            try:
                self.molecules.remove(data)
                self.moleculesNames.remove(data.name)
                self.frame.pmv.mv.removeObject(data)
                self.frame.renderer3D.RemoveActor(data.assembly)
                if hasattr(data,'grid') and data.grid.running:
                    self.frame.mayaviEngine.engine.scenes[0].children.remove(data.grid)
            except Exception, inst:
                self.frame.log.error(str(inst))
            self.Delete(self.item)
        else:#chain
            data.parent.assembly.RemovePart(data.assembly)
            if data in data.parent.children:
                data.parent.remove(data)
#                for atom in data.residues.atoms:
#                    try:
#                        data.parent.allAtoms.remove(atom)
#                    except:
#                        pass
            # NOTE(review): unlike the molecule branch, pmv.mv.removeObject is not
            # called here when the last chain goes - confirm this is intended.
            if len(data.parent.chains) == 0:#remove molecule
                self.molecules.remove(data.parent)
                self.moleculesNames.remove(data.parent.name)
                self.frame.renderer3D.RemoveActor(data.parent.assembly)
                self.Delete(data.parent.assembly.treeID)
            else:
                self.Delete(self.item)
        self.Rerender()
        self.OnSelChanged(None)
    
    def RemoveSelected(self, event):
        "Remove every selected tree item by pointing self.item at it and calling OnRemove."
        selected = self.GetSelections()
        for treeItem in selected:
            self.item = treeItem
            self.OnRemove(None)
            
    def Remove(self, index):
        """Remove the molecule at position `index` (called from AddDocking).

        Deletes its tree item and renderer actor, pops it from self.molecules
        and self.moleculesNames, and removes it from the Pmv viewer.
        """
        target = self.molecules[index]
        self.Delete(target.assembly.treeID)
        self.frame.renderer3D.RemoveActor(self.molecules[index].assembly)
        popped = self.molecules.pop(index)
        self.frame.pmv.mv.removeObject(popped)
        self.moleculesNames.pop(index)
    
    def OnRemoveAll(self, event):
        """Remove every top-level item (molecule) from the navigator.

        Each child of the hidden root is made current (self.item) and passed
        to OnRemove.  The children are copied first because OnRemove deletes
        the item from the tree.  That removes it from the very list that
        root.GetChildren() returned, so iterating the live list would skip
        every other molecule.

        event -- the menu event, forwarded to OnRemove.
        """
        for child in list(self.root.GetChildren() or ()):
            self.item = child
            self.OnRemove(event)

    
    def OnSave(self, event):
        """Save the MolKit nodes of the current tree item (self.item) to a file.

        The user picks the destination in a file dialog.  The nodes are
        written with MolKit's PdbWriter using ATOM, HETATM and CONECT
        records.

        event -- the menu event (unused).
        """
        nodes = self.GetPyData(self.item)
        # FD_SAVE is the current spelling of the deprecated wx.SAVE;
        # FD_OVERWRITE_PROMPT asks before clobbering an existing file.
        dlg = wx.FileDialog(self, "Choose a file", os.getcwd(), "",
                            "All files (*)|*",
                            style=wx.FD_SAVE | wx.FD_OVERWRITE_PROMPT)
        try:
            if dlg.ShowModal() == wx.ID_OK:
                fileName = dlg.GetPath()
                from MolKit.pdbWriter import PdbWriter
                writer = PdbWriter()
                writer.write(fileName, nodes,
                             records=['ATOM', 'HETATM', 'CONECT'])
        finally:
            # Destroy the dialog even if writing the file raises.
            dlg.Destroy()
                
    def Rerender(self):
        "Redraw the Mayavi scene attached to the main frame."
        scene = self.frame.mayaviEngine.scene
        scene.render()
    
    def OnSelChanged(self, event):
        """Redraw the highlight for the currently selected tree items.

        Does nothing unless self.toggleSelection is true.  Otherwise:
        1. Remove the previous highlight actor (self.selectionAssembly).
        2. Collect the atoms of every selected item, skipping items with no
           atoms and items whose first atom's assembly is hidden.
        3. If any atoms remain, build a highlight assembly with
           GenerateAssambly and add it to the 3D renderer.
        4. Re-render the scene.

        event -- the tree event (unused; internal callers pass None).
        """
        if not self.toggleSelection:
            return
        selections = self.GetSelections()
        if self.selectionAssembly:
            self.frame.renderer3D.RemoveActor(self.selectionAssembly)
            # The actor is gone; forget it so a stale actor is not kept
            # (and removed again) when nothing is selected.
            self.selectionAssembly = None
        atomSet = AtomSet()
        for selection in selections:
            node = self.GetPyData(selection)
            atoms = node.findType(Atom)
            if not atoms:
                # Nothing to highlight; atoms[0] below would raise IndexError.
                continue
            if hasattr(atoms[0], 'assembly') and \
                    not atoms[0].assembly.GetVisibility():
                # The item is currently hidden; don't highlight it.
                continue
            atomSet.extend(atoms)

        if atomSet:
            # GenerateAssambly appears to overwrite atomSet[0].assembly: the
            # value is read back below as the highlight actor and the
            # previous one is then restored (TODO confirm in GenerateAssambly).
            prevAssembly = None
            if hasattr(atomSet[0], 'assembly'):
                prevAssembly = atomSet[0].assembly
            assembly = self.GenerateAssambly(atomSet, True)
            self.frame.renderer3D.AddActor(assembly)
            self.selectionAssembly = atomSet[0].assembly
            if prevAssembly:
                atomSet[0].assembly = prevAssembly
        self.Rerender()
        
    def ToggleSelection(self, event):
        """Flip selection highlighting on or off.

        When switched on, the highlight is rebuilt via OnSelChanged.  When
        switched off, the current highlight actor (if any) is removed from
        the 3D renderer and the scene is re-rendered.

        event -- the menu/toolbar event (unused).
        """
        enabled = not self.toggleSelection
        self.toggleSelection = enabled
        if enabled:
            self.OnSelChanged(None)
            return
        if self.selectionAssembly:
            self.frame.renderer3D.RemoveActor(self.selectionAssembly)
        self.Rerender()

    def UpdateConformation(self, molecule, conformation):
        """Move a displayed molecule to the coordinates of *conformation*.

        molecule     -- MolKit molecule shown in the navigator.  With more
                        than one chain, each chain's ``assembly`` holds that
                        chain's points; otherwise ``molecule.assembly`` does.
        conformation -- object whose ``coords`` holds one xyz per atom.  The
                        multi-chain branch slices it chain by chain, so it is
                        assumed to follow molecule.allAtoms order (TODO
                        confirm).

        The coordinates are also copied to molecule.allAtoms so atom labels
        follow the new positions.  If none of the updated points is inside
        the render window, the camera is reset.  Finally, the selection
        highlight is refreshed and the 3D canvas is repainted.
        """
        coords = conformation.coords
        if len(molecule.chains) > 1:
            # One assembly per chain: consume the coordinates in chain order,
            # len(chain.residues.atoms) entries at a time.
            startIndex = 0
            for chain in molecule.chains:
                assembly = chain.assembly
                endIndex = startIndex + len(chain.residues.atoms)
                self._SetAssemblyCoords(assembly, coords[startIndex:endIndex])
                startIndex = endIndex
        else:
            assembly = molecule.assembly
            self._SetAssemblyCoords(assembly, coords)
        # Keep MolKit atoms in sync so labels are displayed properly.
        molecule.allAtoms.coords = coords

        # If none of the updated points is visible on the screen, reset the
        # camera.  NOTE(review): with several chains only the last chain's
        # assembly is tested.
        visPts = vtk.vtkSelectVisiblePoints()
        renderSize = self.frame.renderer3D.GetSize()
        visPts.SetRenderer(self.frame.renderer3D)
        visPts.SetInput(assembly.polyData)
        visPts.SelectionWindowOn()
        visPts.SetSelection(0, renderSize[0], 0, renderSize[1])
        visPts.Update()
        visCount = visPts.GetOutputDataObject(0).GetVerts().GetSize()
        if visCount == 0:
            self.frame.renderer3D.ResetCamera()
        self.OnSelChanged(None)
        self.frame.canvas3D.Refresh()

    def _SetAssemblyCoords(self, assembly, coords):
        "Write coords into assembly.points and flag the affected VTK objects as modified."
        points = assembly.points
        for i, coord in enumerate(coords):
            points.SetPoint(i, coord)
        # vtkPoints.SetPoint does not call Modified(); flag it explicitly so
        # the pipeline re-executes even for assemblies without a glyph.
        points.Modified()
        if hasattr(assembly, 'glyph'):
            assembly.glyph.Modified()
            assembly.tuber.Modified()

def isResidue(node):
    """Return True when *node* is a Residue whose parent holds more than two
    residues (what this module treats as a PDBQT residue), False otherwise."""
    if not isinstance(node, Residue):
        return False
    return len(node.parent.residues) > 2
      