# Source code for qharv.seed.xml

# Author: Yubo "Paul" Yang
# Email: yubo.paul.yang@gmail.com
# Routines to manipulate an xml input.
#  Almost all functions are built around the lxml module's API.
#  The central object is lxml.etree.ElementTree, which is usually named "doc".
import os
import numpy as np
from copy import deepcopy
from lxml import etree
from io import StringIO

# ======================== level 0: basic io =========================
def read(fname):
    """Parse an xml file into an in-memory document.

    Thin wrapper around lxml.etree.parse, using a parser constructed with
    remove_blank_text=True.

    Args:
      fname (str): filename to read from; parse() in this module also passes
        an in-memory text stream through this argument
    Return:
      lxml.etree._ElementTree: doc, parsed xml document
    """
    return etree.parse(fname, etree.XMLParser(remove_blank_text=True))
def write(fname, doc):
    """Save an in-memory xml document to disk.

    Thin wrapper around lxml.etree._ElementTree.write with pretty printing
    switched on.

    Args:
      fname (str): filename to write to
      doc (lxml.etree._ElementTree): xml file in memory
    Effect:
      fname is written using the contents of doc
    """
    options = {'pretty_print': True}
    doc.write(fname, **options)
def parse(text):
    """Build an xml node from its text representation.

    The text is wrapped in an in-memory stream and handed to read().

    Args:
      text (str): string representation of an xml node; an input that
        provides a decode() method is decoded first
    Return:
      lxml.etree._Element: root, parsed xml node
    """
    try:
        # inputs with a decode() method (the original Python2 code path)
        content = text.decode()
    except AttributeError:
        # plain Python3 str has no decode(); use it as is
        content = text
    return read(StringIO(content)).getroot()
def str_rep(node):
    """Serialize an xml node to a pretty-printed string.

    Args:
      node (lxml.etree._Element): xml node
    Return:
      str: string representation of node
    """
    raw = etree.tostring(node, pretty_print=True)
    return raw.decode()
def show(node):
    """Print the string representation of an xml node.

    Args:
      node (lxml.etree._Element): xml node
    Effect:
      the output of str_rep(node) is printed
    """
    text = str_rep(node)
    print(text)
def ls(node, r=False, level=0, indent=" "):
    """List the child tags of an xml node.

    Similar to the Linux `ls` command, but for an xml node. One tag is
    written per line; with r=True each level of nesting is prefixed by one
    more copy of `indent`.

    Args:
      node (lxml.etree._Element): xml node
      r (bool): recursive
      level (int): level of indentation, used only if r=True
      indent (str): indent string, used only if r=True
    Return:
      str: a string representation of the directory structure,
        '' if node has no children
    """
    lines = []
    # iterate the node directly; getchildren() is deprecated in lxml
    for child in node:
        # only children whose exact type is _Element are listed; anything
        # else (presumably comments and similar nodes) is skipped
        if type(child) is not etree._Element:
            continue
        lines.append(indent*level + child.tag + '\n')
        if r:
            lines.append(ls(child, r=r, level=level+1, indent=indent))
    return ''.join(lines)
def append(root, nodes, copy=True):
    """Attach one or more nodes to the end of a root node.

    Args:
      root (lxml.etree._Element): xml node that receives the new children
      nodes (list or lxml.etree._Element): xml node(s)
      copy (bool, optional): append deep copies instead of the given
        nodes, default True
    Effect:
      root gains the node(s) as its last children
    """
    items = deepcopy(nodes) if copy else nodes
    # a single element is handled like a one-item list
    if type(items) is etree._Element:
        items = [items]
    for item in items:
        root.append(item)
def remove(nodes):
    """Detach every given node from its parent.

    Args:
      nodes (list): xml nodes; each one is looked up through getparent()
    Effect:
      each node is removed from its parent
    """
    for node in nodes:
        node.getparent().remove(node)
# ========================= level 1: node content io =========================
# node.get & node.set are sufficient for attribute manipulation
# level 1 routines are needed for node.text and node.children manipulation
def make_node(tag, attribs=None, text=None, pad=' '):
    """Create the element <tag **attribs> text </tag>.

    Args:
      tag (str): tag of the new node
      attribs (dict, optional): attributes, default None
      text (str, optional): text content, default None
      pad (str, optional): string put on both sides of text, default ' '
    Return:
      etree.Element: node
    Examples:
      >>> sim = make_node('simulation')
      >>> epset = make_node('particleset', {'name': 'e'})
      >>> lrnode = make_node('parameter', {'name': 'LR_dim_cutoff'}, str(20))
      >>> bconds = make_node('parameter', {'name': 'bconds'}, 'p p p', pad='\\n')
      >>> seed = make_node('seed', text=str(31415))
    """
    node = etree.Element(tag, attribs)
    if text is None:
        # no text requested: leave node.text untouched
        return node
    node.text = pad + text + pad
    return node
def arr2text(arr):
    """Format a vector or matrix as a text string.

    A vector becomes its entries joined by single spaces. A matrix becomes
    one such line per row, with a newline before the first row and after
    the last one.

    Args:
      arr (np.array or array-like): 1D or 2D data; converted with
        np.asarray, so nested lists are accepted as well
    Return:
      str: text representation of arr
    Raises:
      RuntimeError: if arr is neither 1D nor 2D
    """
    arr = np.asarray(arr)
    if arr.ndim == 1:  # vector
        return " ".join(arr.astype(str))
    if arr.ndim == 2:  # matrix
        rows = [arr2text(row) for row in arr]
        return "\n" + "\n".join(rows) + "\n"
    raise RuntimeError('arr2text can only convert vector or matrix.')
def text2arr(text, dtype=float, flatten=False):
    """Convert a text string into a numpy array.

    Each non-blank line of text is one row and whitespace separates the
    entries. A single row gives a 1D array; several rows give a 2D array,
    or a 1D array of all entries if flatten=True.

    Args:
      text (str or bytes): text to convert, bytes are decoded first
      dtype (type, optional): data type of the result, default float;
        honored for flattened output as well
      flatten (bool, optional): return all entries as one 1D array,
        default False
    Return:
      np.array: parsed array
    """
    if isinstance(text, bytes):
        text = text.decode()
    rows = [line.split() for line in text.strip().split('\n')]
    # blank lines carry no entries and would make the rows ragged
    rows = [row for row in rows if row]
    if not rows:
        return np.array([], dtype=dtype)
    if len(rows) == 1:
        return np.array(rows[0], dtype=dtype)
    if flatten:
        entries = [entry for row in rows for entry in row]
        return np.array(entries, dtype=dtype)
    return np.array(rows, dtype=dtype)
def text2vec(text, dtype=float):
    """Convert a text string into a 1D numpy array.

    All whitespace, including newlines, separates entries, so multi-line
    text is unfolded into a single vector.

    Args:
      text (str or bytes): text to convert, bytes are decoded first
      dtype (type, optional): data type of the result, default float
    Return:
      np.array: 1D array of all entries in text
    """
    if isinstance(text, bytes):
        text = text.decode()
    # split() without arguments already treats newlines as separators
    return np.array(text.split(), dtype=dtype)
def swap_node(node0, node1):
    """Put node1 in the place of node0.

    node0 must have a parent.

    Args:
      node0 (etree.Element): node to be swapped out
      node1 (etree.Element): replacement node
    Effect:
      node0 is removed from its parent and node1 is inserted at the index
      node0 had
    """
    owner = node0.getparent()
    # remember the slot before node0 is taken out
    slot = owner.index(node0)
    owner.remove(node0)
    owner.insert(slot, node1)
# ================= level 2: QMCPACK specialized read =================
def get_param(node, pname):
    """Return the text of the first <parameter> named pname below node.

    The parameter is expected to look like:
      <parameter name="pname"> str_rep </parameter>
    No explicit check is made for a missing parameter; a failed lookup
    surfaces as an error on the text access.

    Args:
      node (lxml.etree._Element): xml node with <parameter>.
      pname (str): name of parameter
    Return:
      str: string representation of the parameter value
    """
    query = './/parameter[@name="%s"]' % pname
    return node.find(query).text
def set_param(node, pname, pval, new=False, pad=' '):
    """Set <parameter> with name 'pname' to 'pval'.

    If new=True, then <parameter name="pname"> must not exist yet; it is
    created and appended to node.

    Args:
      node (lxml.etree._Element): xml node with children having tag
        'parameter'
      pname (str): name of parameter
      pval (str): value of parameter, converted with str()
      new (bool): create new <parameter> node, default is false
      pad (str, optional): string put on both sides of the value,
        default ' '
    Effect:
      the text of <parameter> with 'pname' will be set to 'pval'
    Raises:
      RuntimeError: if the parameter is missing and new=False, or if it
        exists and new=True
    """
    pnode = node.find('.//parameter[@name="%s"]' % pname)
    found = pnode is not None
    # adjacent string literals keep source indentation out of the message
    if (not found) and (not new):  # unintended input
        raise RuntimeError('<parameter name="%s"> not found in %s\n'
                           'please set new=True' % (pname, node.tag))
    if found and new:  # unintended input
        raise RuntimeError('<parameter name="%s"> found in %s\n'
                           'please set new=False' % (pname, node.tag))
    if not found:
        pnode = etree.Element('parameter', {'name': pname})
        node.append(pnode)
    pnode.text = pad + str(pval) + pad
def get_axes(doc):
    """Read the lattice from <simulationcell>.

    The lattice is taken from the text of <parameter name="lattice"> found
    below the first <simulationcell>; its 'units' attribute must be 'bohr'.

    Args:
      doc (lxml.etree._ElementTree): xml document or node to search
    Return:
      np.array: axes, the lattice text converted by text2arr
    Raises:
      RuntimeError: if <simulationcell> or its lattice parameter is
        missing, or if the lattice units are not 'bohr'
    """
    sc_node = doc.find('.//simulationcell')
    if sc_node is None:
        raise RuntimeError('<simulationcell> not found')
    lat_node = sc_node.find('.//parameter[@name="lattice"]')
    if lat_node is None:
        raise RuntimeError(
            '<parameter name="lattice"> not found in <simulationcell>')
    unit = lat_node.get('units')
    # explicit raise: an assert would vanish when running with -O
    if unit != 'bohr':
        raise RuntimeError('lattice units must be "bohr", got "%s"' % unit)
    return text2arr(lat_node.text)
def get_nelec(doc):
    """Count particles in the groups named 'u' and 'd'.

    For each of the two names, the first matching <group> below doc adds
    its integer 'size' attribute to the total; a name without a matching
    group adds nothing.

    Args:
      doc (lxml.etree._ElementTree): xml document or node to search
    Return:
      int: nelec, sum of the group sizes
    """
    sizes = []
    for species in ('u', 'd'):
        grp = doc.find('.//group[@name="%s"]' % species)
        if grp is None:
            continue
        sizes.append(int(grp.get('size')))
    return sum(sizes)
def get_pos(doc, pset='ion0', all_pos=True, group=None):
    """Read particle positions from a <particleset>.

    Either all groups of the particle set are read (all_pos=True, group
    must be None) or exactly one named group (all_pos=False, group given).

    Args:
      doc (lxml.etree._ElementTree): xml document or node to search
      pset (str, optional): name of the <particleset>, default 'ion0'
      all_pos (bool, optional): collect positions of every group,
        default True
      group (str, optional): name of a single group to read, default None
    Return:
      np.array: pos, the position text converted by text2arr
    Raises:
      RuntimeError: if the particle set is missing, if all_pos=False
        without a group, if a group is given together with all_pos=True,
        or if the named group does not exist
    """
    pos_query = './/attrib[@name="position"]'

    # find <particleset>
    pset_node = doc.find('.//particleset[@name="%s"]' % pset)
    if pset_node is None:
        raise RuntimeError('%s not found' % pset)

    # validate the (all_pos, group) combination against available groups
    groups = pset_node.findall('.//group')
    names = [grp.get('name') for grp in groups]
    if group is not None:
        if all_pos:
            raise RuntimeError(
                'specified group will be over-written with all_pos!\n'
                'Please set all_pos=False.'
            )
        if group not in names:
            raise RuntimeError(
                'no group with name "%s" in %s' % (group, str(names)))
    elif not all_pos:
        raise RuntimeError(
            '%d groups found, please specify particle group from %s' % (
                len(groups), str(names)))

    if all_pos:
        pieces = []
        for grp in groups:
            pos_node = grp.find(pos_query)
            if pos_node is None:  # look in parent (old-style input)
                pos_node = grp.getparent().find(pos_query)
            pieces.append(pos_node.text.strip('\n') + '\n')
        pos_text = ''.join(pieces)
    else:  # only the requested group
        grp = pset_node.find('.//group[@name="%s"]' % group)
        pos_text = grp.find(pos_query).text

    return text2arr(pos_text.strip('\n'))
# ================= level 3: QMCPACK specialized construct =================
def build_coeff(knots, **attribs):
    """Construct a <coefficients/> node.

    example:
      build_coeff([1,2]):
        <coefficients id="new" type="Array"> 1 2 </coefficients>

    Args:
      knots (list): a list of numbers
      **attribs: attributes of the node; 'id' defaults to 'new' and
        'type' defaults to 'Array' when not given
    Return:
      lxml.etree._Element: <coefficients/>
    """
    # fill in the required attributes only where the caller gave none
    attribs.setdefault('id', 'new')
    attribs.setdefault('type', 'Array')
    coeff_node = etree.Element('coefficients', attribs)
    coeff_node.text = ' %s ' % ' '.join(str(knot) for knot in knots)
    return coeff_node
def build_jr2(uuc, udc):
    """Construct a two-body Bspline <jastrow/> node named 'J2'.

    The node holds two <correlation> children, u-u followed by u-d. Each
    one wraps the <coefficients> built from the matching list ('uu' or
    'ud' as id) and records the list length in its 'size' attribute.

    Args:
      uuc (list): coefficients for the u-u correlation
      udc (list): coefficients for the u-d correlation
    Return:
      lxml.etree._Element: <jastrow/>
    """
    j2_node = etree.Element('jastrow', {
        'name': 'J2',
        'type': 'Two-Body',
        'function': 'Bspline'
    })
    for partner, coeffs in (('u', uuc), ('d', udc)):
        corr = etree.Element('correlation', {
            'speciesA': 'u',
            'speciesB': partner,
            'size': str(len(coeffs))
        })
        corr.append(build_coeff(coeffs, **{'id': 'u' + partner}))
        j2_node.append(corr)
    return j2_node
def build_jk2_iso(coeffs, kc):
    """Construct an isotropic e-e reciprocal space Jastrow node.

    example:
      build_jk2_iso([1,2], 0.4)

    Args:
      coeffs (list): a list of numbers, stored in <coefficients id="cG2">
      kc (float): k space cutoff in a.u., stored as the 'kc' attribute
    Return:
      lxml.etree._Element: <jastrow/>
    """
    jk_node = etree.Element('jastrow', {
        'name': 'Jk',
        'type': 'kSpace',
        'source': 'e'
    })
    corr_node = etree.Element('correlation', {
        'type': 'Two-Body',
        'kc': str(kc),
        'symmetry': 'isotropic'
    })
    corr_node.append(build_coeff(coeffs, id='cG2'))
    jk_node.append(corr_node)
    return jk_node
# ================= level 4: QMCPACK specialized advanced =================
def turn_off_jas_opt(wf_node):
    """Return a copy of wf_node with Jastrow optimization switched off.

    Every <coefficients> found below a <jastrow> in the copy gets
    optimize="no"; the input node is left untouched.

    Args:
      wf_node (etree.Element): node to copy and search
    Return:
      etree.Element: mywf, modified deep copy of wf_node
    """
    mywf = deepcopy(wf_node)
    coeff_nodes = [
        coeff
        for jas in mywf.findall('.//jastrow')
        for coeff in jas.findall('.//coefficients')
    ]
    for coeff in coeff_nodes:
        coeff.set('optimize', 'no')
    return mywf
def add_backflow(wf_node, bf_node):
    """Return a copy of <wavefunction> with a <backflow> block inserted.

    The backflow node becomes the first child of <determinantset>. The
    basis set builder, i.e. the <sposet_builder> if there is one and the
    <determinantset> otherwise, is then given use_old_spline="yes",
    precision="double" and truncate="no".

    Args:
      wf_node (etree.Element): <wavefunction> node, copied before editing
      bf_node (etree.Element): <backflow> node, inserted as is (no copy)
    Return:
      etree.Element: mywf, modified deep copy of wf_node
    Raises:
      AssertionError: if the input tags are not 'wavefunction' and
        'backflow', or if more than one <sposet_builder> is found
      RuntimeError: if <determinantset> is missing
    """
    # make sure inputs are not scrambled
    assert wf_node.tag == 'wavefunction'
    assert bf_node.tag == 'backflow'
    # make a copy of wavefunction
    mywf = deepcopy(wf_node)
    # insert backflow block
    dset = mywf.find('.//determinantset')
    if dset is None:
        raise RuntimeError('<determinantset> not found in <wavefunction>')
    dset.insert(0, bf_node)
    # use code path where <backflow> optimization still works
    # find basis set builder:
    #  either the single <sposet_builder> or, without one, <determinantset>
    builders = mywf.findall('.//sposet_builder')
    assert len(builders) <= 1
    bb = builders[0] if builders else dset
    bb.set('use_old_spline', 'yes')
    bb.set('precision', 'double')
    bb.set('truncate', 'no')
    return mywf
def dset2spo(wf_node, det_map):
    """Convert <wavefunction> from old style to new style in place.

    Old style keeps <basisset> inside <determinantset>; new style keeps it
    inside <sposet_builder>. The old <determinantset> is replaced by a
    <sposet_builder> carrying its attributes, its <basisset> and its
    determinants (renamed to <sposet>), followed by a new
    <determinantset> whose <slaterdeterminant> refers to the sposets.

    Args:
      wf_node (etree.Element): <wavefunction> node
      det_map (dict): determinant name -> particle group name
        e.g. {'updet':'u','downdet':'d'}
    Returns:
      None
    Raises:
      RuntimeError: if a determinant id is not a key of det_map
    """
    # construct <sposet_builder> using nodes from <determinantset>
    old_dset = wf_node.find('.//determinantset')
    builder = etree.Element('sposet_builder', old_dset.attrib)
    builder.append(old_dset.find('.//basisset'))

    # turn every <determinant> into an <sposet> of the builder
    spo2det = {}  # save spo_name -> det_id
    for det in old_dset.findall('.//determinant'):
        det.tag = 'sposet'
        det_id = det.get('id')
        if det_id not in det_map:
            raise RuntimeError('%s not in det_map' % det_id)
        spo_name = 'spo_' + det_map[det_id]
        spo2det[spo_name] = det_id
        det.set('name', spo_name)
        det.attrib.pop('id')
        builder.append(det)

    # replace <determinantset> with <sposet_builder>
    idx = wf_node.index(old_dset)
    wf_node.insert(idx, builder)
    wf_node.remove(old_dset)

    # rewrite <determinantset>, one <determinant> per sposet
    new_dset = etree.Element('determinantset')
    slater = etree.Element('slaterdeterminant')
    for spo_name, det_id in spo2det.items():
        slater.append(etree.Element('determinant', {
            'id': det_id,
            'group': det_map[det_id],
            'sposet': spo_name
        }))
    new_dset.append(slater)
    wf_node.insert(idx + 1, new_dset)