"""Module to generate graphs of atomic structures, convertible to networkx."""
from jarvis.core.atoms import get_supercell_dims
from jarvis.core.specie import Specie
from jarvis.core.utils import random_colors
import numpy as np
from collections import OrderedDict
from jarvis.analysis.structure.neighbors import NeighborsAnalysis
class Graph(object):
    """Generate a graph object.

    Attributes:
        nodes: IDs of the graph nodes.
        node_attributes: node features, one row per node.
        edges: connectivity as a list of (u, v) index pairs.
        edge_attributes: one value per entry in ``edges``.
        color_map: optional per-node colors (None if not provided).
        labels: optional labels; stored and serialized, not interpreted here.
    """

    # Element-property names used by ``from_atoms(features="basic")``; each is
    # passed verbatim to ``Specie.element_property``.
    # NOTE(review): "coulmn" looks like a misspelling of "column", but it is
    # used as a lookup key -- verify against Specie's data before renaming.
    _BASIC_FEATURES = (
        "Z",
        "coulmn",
        "row",
        "X",
        "atom_rad",
        "nsvalence",
        "npvalence",
        "ndvalence",
        "nfvalence",
        "first_ion_en",
        "elec_aff",
    )

    def __init__(
        self,
        nodes=None,
        node_attributes=None,
        edges=None,
        edge_attributes=None,
        color_map=None,
        labels=None,
    ):
        """
        Initialize the graph object.

        Args:
            nodes: IDs of the graph nodes as integer array.
            node_attributes: node features as multi-dimensional array.
            edges: connectivity as a (u,v) pair where u is
                the source index and v the destination ID.
            edge_attributes: attributes for each connectivity,
                as simple as euclidean distances.
            color_map: optional per-node colors.
            labels: optional labels.

        The four list-like arguments default to a fresh empty list per
        instance (a shared ``[]`` default would leak mutations between
        graphs).
        """
        self.nodes = [] if nodes is None else nodes
        self.node_attributes = (
            [] if node_attributes is None else node_attributes
        )
        self.edges = [] if edges is None else edges
        self.edge_attributes = (
            [] if edge_attributes is None else edge_attributes
        )
        self.color_map = color_map
        self.labels = labels

    @staticmethod
    def _node_features(elements, features):
        """Return a float array of node features, one row per element symbol.

        Args:
            elements: iterable of element symbols, one per node.
            features: 'atomic_number', 'cfid', 'basic', or a list/tuple/array
                of property names accepted by ``Specie.element_property``.

        Raises:
            ValueError: if ``features`` is not one of the supported options.
        """
        # Check for a string first: comparing a numpy array to a string
        # would yield an elementwise array that cannot be used in ``if``.
        if isinstance(features, str):
            if features == "atomic_number":
                return np.array(
                    [[Specie(el).Z] for el in elements], dtype="float"
                )
            if features == "cfid":
                return np.array(
                    [np.array(Specie(el).get_descrp_arr) for el in elements],
                    dtype="float",
                )
            if features != "basic":
                raise ValueError(
                    "Please check the input options: "
                    "unsupported features={!r}".format(features)
                )
            names = Graph._BASIC_FEATURES
        elif isinstance(features, (list, tuple, np.ndarray)):
            names = list(features)
        else:
            raise ValueError(
                "Please check the input options: "
                "unsupported features={!r}".format(features)
            )
        rows = []
        for el in elements:
            specie = Specie(el)  # build once per atom, not once per feature
            rows.append([specie.element_property(name) for name in names])
        return np.array(rows, dtype="float")

    @staticmethod
    def from_atoms(
        atoms=None,
        lengthscale=0.5,
        variance=1.0,
        get_prim=False,
        zero_diag=False,
        node_atomwise_angle_dist=False,
        node_atomwise_rdf=False,
        features="basic",
        enforce_c_size=10.0,
        max_n=100,
        max_cut=5.0,
        verbose=False,
        make_colormap=True,
    ):
        """Build a fully connected Graph from a jarvis Atoms object.

        The structure is expanded to a supercell (dimensions from
        ``get_supercell_dims``). Every ordered pair of atoms, including
        self-pairs, becomes an edge. The edge attribute is
        ``variance * exp(-d / lengthscale)``, where ``d`` is the pair's entry
        in ``atoms.raw_distance_matrix``.

        Args:
            atoms: jarvis.core.atoms.Atoms object.
            lengthscale: decay length of the exponential edge kernel.
            variance: prefactor of the exponential edge kernel.
            get_prim: if True, use ``atoms.get_primitive_atoms`` before
                building the supercell.
            zero_diag: if True, set the self-pair (diagonal) kernel values
                to 0.
            node_atomwise_angle_dist: append ``NeighborsAnalysis``
                ``atomwise_angle_dist()`` columns to the node attributes.
            node_atomwise_rdf: append ``NeighborsAnalysis``
                ``atomwise_radial_dist()`` columns to the node attributes.
            features: Node features.
                'atomic_number': graph with atomic numbers only.
                'cfid': 438 chemical descriptors from CFID.
                'basic': 11 element properties (see ``_BASIC_FEATURES``).
                list/tuple/array: element property names.
                See: jarvis/core/specie.py
            enforce_c_size: minimum size of the simulation cell in Angst.
            max_n, max_cut, verbose: forwarded to ``NeighborsAnalysis``.
            make_colormap: if True, assign a random color per species.

        Returns:
            Graph whose edges are all ordered pairs of atom indices
            (row-major). ``color_map`` is None when ``make_colormap`` is
            False.

        Raises:
            ValueError: if ``features`` is not a supported option.
        """
        if get_prim:
            atoms = atoms.get_primitive_atoms
        dim = get_supercell_dims(atoms=atoms, enforce_c_size=enforce_c_size)
        atoms = atoms.make_supercell(dim)
        raw_data = np.array(atoms.raw_distance_matrix)
        adj = variance * np.exp(-raw_data / lengthscale)
        if zero_diag:
            np.fill_diagonal(adj, 0.0)
        nodes = np.arange(atoms.num_atoms)
        elements = atoms.elements

        node_attributes = Graph._node_features(elements, features)
        if node_atomwise_rdf or node_atomwise_angle_dist:
            nbr = NeighborsAnalysis(
                atoms, max_n=max_n, verbose=verbose, max_cut=max_cut
            )
            if node_atomwise_rdf:
                node_attributes = np.concatenate(
                    (node_attributes, nbr.atomwise_radial_dist()), axis=1
                )
            if node_atomwise_angle_dist:
                node_attributes = np.concatenate(
                    (node_attributes, nbr.atomwise_angle_dist()), axis=1
                )
            node_attributes = np.array(node_attributes, dtype="float")

        # Fully connected (including self-loops), in row-major order so that
        # ``adjacency_matrix`` can reshape edge_attributes back to a square.
        n_el = len(elements)
        uv = [(ii, jj) for ii in range(n_el) for jj in range(n_el)]
        edge_attributes = [adj[ii, jj] for ii, jj in uv]

        # Defined unconditionally so make_colormap=False no longer crashes.
        color_map = None
        if make_colormap:
            sps = atoms.uniq_species
            color_dict = random_colors(number_of_colors=len(sps))
            # color_dict keys are used as positions into sps.
            new_colors = {sps[i]: color for i, color in color_dict.items()}
            color_map = [new_colors[el] for el in elements]

        return Graph(
            nodes=nodes,
            edges=uv,
            node_attributes=np.array(node_attributes),
            edge_attributes=np.array(edge_attributes),
            color_map=color_map,
        )

    def to_networkx(self):
        """Get networkx representation as a ``networkx.DiGraph``.

        Each edge gets a ``weight`` taken from the matching entry of
        ``edge_attributes``. networkx is imported lazily, so it is only
        required when this method is called.
        """
        import networkx as nx

        graph = nx.DiGraph()
        graph.add_nodes_from(self.nodes)
        # Add all edges first so edges without a matching attribute still
        # exist; then attach weights pairwise.
        graph.add_edges_from(self.edges)
        for edge, weight in zip(self.edges, self.edge_attributes):
            graph.add_edge(edge[0], edge[1], weight=weight)
        return graph

    @property
    def num_nodes(self):
        """Return number of nodes in the graph."""
        return len(self.nodes)

    @property
    def num_edges(self):
        """Return number of edges in the graph."""
        return len(self.edges)

    @classmethod
    def from_dict(cls, d=None):
        """Construct class from a dictionary (e.g. the output of to_dict).

        ``nodes``, ``edges``, ``node_attributes`` and ``edge_attributes`` are
        required (KeyError otherwise). ``color_map`` and ``labels`` are
        optional and default to None.
        """
        if d is None:
            d = {}
        return cls(
            nodes=d["nodes"],
            edges=d["edges"],
            node_attributes=d["node_attributes"],
            edge_attributes=d["edge_attributes"],
            color_map=d.get("color_map"),
            labels=d.get("labels"),
        )

    def to_dict(self):
        """Provide dictionary representation of the Graph object.

        Values are converted via ``np.array(...).tolist()`` into plain Python
        lists; a None ``color_map``/``labels`` is stored as None.
        """
        info = OrderedDict()
        info["nodes"] = np.array(self.nodes).tolist()
        info["edges"] = np.array(self.edges).tolist()
        info["node_attributes"] = np.array(self.node_attributes).tolist()
        info["edge_attributes"] = np.array(self.edge_attributes).tolist()
        info["color_map"] = np.array(self.color_map).tolist()
        info["labels"] = np.array(self.labels).tolist()
        return info

    def __repr__(self):
        """Provide representation during print statements."""
        return "Graph({})".format(self.to_dict())

    @property
    def adjacency_matrix(self):
        """Provide adjacency_matrix of graph.

        Reshapes ``edge_attributes`` to (num_nodes, num_nodes). This assumes
        one attribute per ordered node pair in row-major order, which is the
        layout ``from_atoms`` produces.
        """
        return np.array(self.edge_attributes).reshape(
            self.num_nodes, self.num_nodes
        )
# NOTE: the triple-quoted string below is a disabled usage example kept for
# reference. As a bare module-level string expression it is never executed.
# NOTE(review): the final assertion ``len(g.adjacency_matrix) == 2304``
# looks inconsistent: ``adjacency_matrix`` reshapes to
# (num_nodes, num_nodes), so its len() would be num_nodes (48 per the
# preceding assert). Confirm before re-enabling this example.
"""
if __name__ == "__main__":
from jarvis.core.atoms import Atoms
from jarvis.db.figshare import get_jid_data
atoms = Atoms.from_dict(get_jid_data("JVASP-664")["atoms"])
g = Graph.from_atoms(
atoms=atoms,
features="basic",
get_prim=True,
zero_diag=True,
node_atomwise_angle_dist=True,
node_atomwise_rdf=True,
)
g = Graph.from_atoms(
atoms=atoms,
features="cfid",
get_prim=True,
zero_diag=True,
node_atomwise_angle_dist=True,
node_atomwise_rdf=True,
)
g = Graph.from_atoms(
atoms=atoms,
features="atomic_number",
get_prim=True,
zero_diag=True,
node_atomwise_angle_dist=True,
node_atomwise_rdf=True,
)
g = Graph.from_atoms(atoms=atoms, features="basic")
g = Graph.from_atoms(
atoms=atoms, features=["Z", "atom_mass", "max_oxid_s"]
)
g = Graph.from_atoms(atoms=atoms, features="cfid")
# print(g)
d = g.to_dict()
g = Graph.from_dict(d)
num_nodes = g.num_nodes
num_edges = g.num_edges
print(num_nodes, num_edges)
assert num_nodes == 48
assert num_edges == 2304
assert len(g.adjacency_matrix) == 2304
# graph, color_map = get_networkx_graph(atoms)
# nx.draw(graph, node_color=color_map, with_labels=True)
# from jarvis.analysis.structure.neighbors import NeighborsAnalysis
"""