Causal Graphs

Causal graphs are fairly easy to create in Python. We just need to recall the definitions of a node and an edge. A node is a point on the graph (in causal graphs, each node represents a variable); an edge is a line that connects two nodes. In causal diagrams, edges are directed arrows with a head (the node the arrow points to) and a tail (the node the arrow points from).
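To make this vocabulary concrete, the plain-Python sketch below (no plotting library involved) stores a directed graph as a list of (tail, head) tuples, one tuple per arrow, and looks up each node's parents (nodes with an arrow pointing into it) and children (nodes it points to). The helper names parents and children are illustrative only, not taken from any library.

```python
# A directed graph stored as a list of (tail, head) tuples:
# each tuple is one arrow pointing from the tail node to the head node.
edges = [("x", "y"), ("u", "y")]

# The node set is every variable that appears at either end of an edge.
nodes = sorted({node for edge in edges for node in edge})


def parents(node, edges):
    """Nodes with an arrow pointing INTO `node` (tails of edges whose head is `node`)."""
    return [tail for tail, head in edges if head == node]


def children(node, edges):
    """Nodes that `node` points to (heads of edges whose tail is `node`)."""
    return [head for tail, head in edges if tail == node]


print(nodes)                 # ['u', 'x', 'y']
print(parents("y", edges))   # ['x', 'u']
print(children("x", edges))  # ['y']
```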

The graphviz module makes plotting graphs easy. We’ll create a Digraph (short for directed graph).

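import graphviz
# Create an empty directed graph named 'G'
g1 = graphviz.Digraph('G')
# edge(tail, head) adds an arrow pointing from tail to head
g1.edge('x', 'y')
g1.edge('u', 'y')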
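Under the hood, the graphviz package assembles a text description in Graphviz's DOT language and hands it to the separately installed Graphviz programs for rendering; g1.source shows that text, and evaluating g1 in a notebook cell displays the rendered diagram. The sketch below uses a hypothetical helper, to_dot, to build comparable DOT text from a list of (tail, head) tuples, so you can see how each edge() call maps to one "tail -> head" line. It approximates, rather than reproduces, the package's exact output.

```python
def to_dot(name, edges):
    """Build DOT source for a directed graph (hypothetical helper, not part of graphviz).

    Assumes simple identifier node names; DOT requires quoting for names
    containing spaces or punctuation, which this sketch does not handle.
    """
    lines = [f"digraph {name} {{"]
    for tail, head in edges:
        # DOT writes a directed edge as `tail -> head`.
        lines.append(f"\t{tail} -> {head}")
    lines.append("}")
    return "\n".join(lines)


dot_source = to_dot("G", [("x", "y"), ("u", "y")])
print(dot_source)
```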

Rather than using the graphviz module directly, we can use the causalgraphicalmodels package, which provides a suite of causal-analysis tools built on top of graphviz (for drawing) and networkx (for graph algorithms).

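To install the package from a notebook cell, we prefix the pip command with !, which runs it in the shell. (IPython's %pip magic is an alternative that installs into the notebook kernel's own environment.)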

!pip install causalgraphicalmodels --user

The syntax of a causalgraphicalmodels graph is slightly different from that of the graphviz module. Here, we pass a list of nodes as well as a list of directed edges. Each directed edge is written as a (tail, head) tuple of nodes.

from causalgraphicalmodels import CausalGraphicalModel
g2 = CausalGraphicalModel(
    nodes=["x", "y", "u"],
    edges=[
        ("x", "y"),
        ("u", "y")
    ]
)

The advantage of using this module is that it has causal-analysis tools built in. For example, we can list the backdoor paths from \(x\) to \(y\) (paths that begin with an arrow pointing into \(x\)) with the get_all_backdoor_paths() method.

print(g2.get_all_backdoor_paths("x", "y"))
[]

An empty list indicates that there is no backdoor path. If the causal model is correct, we can therefore estimate the causal effect of \(x\) on \(y\) without adjusting for any other variable.
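To see what get_all_backdoor_paths is checking, here is a minimal, standard-library-only sketch of the definition: a backdoor path from x to y is any path between them (following edges regardless of arrow direction) whose first edge points into x. The function name backdoor_paths is hypothetical and this is not the library's implementation; it enumerates every simple path, which is fine for small diagrams like these but grows exponentially on large graphs.

```python
def backdoor_paths(edges, x, y):
    """Return every simple path from x to y whose first edge points INTO x.

    A minimal sketch of the backdoor-path definition, not the
    causalgraphicalmodels implementation.
    """
    # Build the undirected skeleton: paths may traverse arrows in either direction.
    neighbors = {}
    for tail, head in edges:
        neighbors.setdefault(tail, set()).add(head)
        neighbors.setdefault(head, set()).add(tail)
    directed = set(edges)

    paths = []

    def extend(path):
        node = path[-1]
        if node == y:
            paths.append(path)
            return
        for nxt in sorted(neighbors.get(node, ())):
            if nxt not in path:  # simple paths only: no repeated nodes
                extend(path + [nxt])

    extend([x])
    # Keep only paths whose first step runs against an arrow, i.e. path[1] -> x.
    return [p for p in paths if (p[1], p[0]) in directed]


# The graph above: x -> y, u -> y
print(backdoor_paths([("x", "y"), ("u", "y")], "x", "y"))              # []
# Adding an arrow u -> x opens the path x <- u -> y
print(backdoor_paths([("x", "y"), ("u", "y"), ("u", "x")], "x", "y"))  # [['x', 'u', 'y']]
```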

Let’s modify the above graph to include a fork \(x \leftarrow u \rightarrow y\).

g3 = CausalGraphicalModel(
    nodes=["x", "y", "u"],
    edges=[
        ("x", "y"),
        ("u", "y"),
        ("u", "x")
    ]
)

The above graph has a backdoor path \(x \leftarrow u \rightarrow y\).

print(g3.get_all_backdoor_paths("x", "y"))
[['x', 'u', 'y']]
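Why does an open backdoor path matter? As a preview of the regression material later on, the sketch below simulates data from this graph under assumed linear equations with Gaussian noise (the coefficients 1.0, 2.0 and 3.0 are arbitrary choices, not from the text) and treats u as observed. The naive slope of y on x mixes the direct effect with the backdoor path and should land near 3.5 (that is, 2 + 3 · 1/2 under these coefficients), while adjusting for u, here by partialling u out of both x and y (the Frisch–Waugh–Lovell approach), should recover the assumed direct effect of about 2.0.

```python
import random


def slope(a, b):
    """Least-squares slope from regressing b on a: cov(a, b) / var(a)."""
    n = len(a)
    mean_a, mean_b = sum(a) / n, sum(b) / n
    cov = sum((ai - mean_a) * (bi - mean_b) for ai, bi in zip(a, b))
    var = sum((ai - mean_a) ** 2 for ai in a)
    return cov / var


def residuals(a, b):
    """Part of b left over after removing its linear dependence on a."""
    n = len(a)
    mean_a, mean_b = sum(a) / n, sum(b) / n
    s = slope(a, b)
    return [(bi - mean_b) - s * (ai - mean_a) for ai, bi in zip(a, b)]


rng = random.Random(0)
n = 100_000
TRUE_EFFECT = 2.0  # assumed coefficient on the x -> y arrow

# Simulate u -> x, u -> y and x -> y (linear equations, Gaussian noise assumed).
u = [rng.gauss(0, 1) for _ in range(n)]
x = [ui + rng.gauss(0, 1) for ui in u]
y = [TRUE_EFFECT * xi + 3.0 * ui + rng.gauss(0, 1) for xi, ui in zip(x, u)]

# Naive estimate: biased by the open backdoor path x <- u -> y.
naive = slope(x, y)

# Adjusted estimate: remove u from both x and y, then regress residual on residual.
adjusted = slope(residuals(u, x), residuals(u, y))

print(round(naive, 2), round(adjusted, 2))
```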

Let’s reverse the arrow between \(x\) and \(u\) so that \(x\) now sits in the middle of the fork \(u \leftarrow x \rightarrow y\).

g4 = CausalGraphicalModel(
    nodes=["x", "y", "u"],
    edges=[
        ("x", "y"),
        ("u", "y"),
        ("x", "u")
    ]
)

Here, \(x\) is correlated with the unobserved variable \(u\). Note, however, that in this model \(x\) causes \(u\) rather than being caused by it.

There is no backdoor path from \(x\) to \(y\).

print(g4.get_all_backdoor_paths("x", "y"))
[]

With \(u\) unobserved, the effect of \(x\) on \(y\) that we estimate under this model (if the model is correct) is the total causal effect of \(x\) on \(y\). If \(u\) were observed, we could also split that total effect into the direct effect \(x \rightarrow y\) and the indirect effect mediated through \(u\), \(x \rightarrow u \rightarrow y\).
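To make the total, direct and indirect effects concrete, assume (this is not stated in the text) that the variables in this graph follow linear structural equations with coefficients \(\alpha\), \(\beta\), \(\gamma\) and noise terms \(\eta\), \(\varepsilon\) that are independent of \(x\) and of each other. Substituting the equation for \(u\) into the equation for \(y\) shows what a regression of \(y\) on \(x\) alone picks up:

```latex
\begin{aligned}
u &= \alpha x + \eta \\
y &= \beta x + \gamma u + \varepsilon \\
  &= \beta x + \gamma(\alpha x + \eta) + \varepsilon \\
  &= (\beta + \alpha\gamma)\, x + \gamma\eta + \varepsilon
\end{aligned}
```

Under these assumptions, the total effect is \(\beta + \alpha\gamma\): the direct effect \(\beta\) along \(x \rightarrow y\) plus the indirect effect \(\alpha\gamma\) carried along \(x \rightarrow u \rightarrow y\). Observing \(u\) would let us estimate \(\alpha\) and \(\beta\) separately, which is what makes the decomposition possible.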

We’ll review these graphs after learning how to perform regression.

Concept check: replicate the graph in the code cell below.

[Figure: directed graph with edges x -> z, z -> y, u -> x, u -> y]

You should have edges \(x \rightarrow z, z \rightarrow y, u \rightarrow x, u \rightarrow y\).

import graphviz
g5 = graphviz.Digraph('G', format='svg')
g5.edge('x', 'z')
g5.edge('z', 'y')
g5.edge('u', 'y')
g5.edge('u', 'x')