电脑基础 · 2023年3月28日

PyG教程(2):图数据

一.概述

PyG是面向图数据的,它同时支持同构图(homogeneous graphs)异构图(heterogeneous)。同构图指只包含一种类型的节点和边的图(下图左)。而异构图指包含两种及以上类型的节点和边的图(下图右)。

PyG教程(2):图数据

在PyG中,同构图被描述为torch_geometric.data.Data类的实例,而异构图被描述为torch_geometric.data.HeteroData的实例。

本文主要介绍PyG关于同构图的的相关操作,操作环境为:

pytorch = 1.10.1
cuda = 11.3
torch_geometric = 2.0.4

二.基本图操作

2.1 图的创建

同构图是用Data类是进行描述的,因此首先查看其初始化函数的参数列表:

def __init__(self, x: OptTensor = None, edge_index: OptTensor = None,
             edge_attr: OptTensor = None, y: OptTensor = None,
             pos: OptTensor = None, **kwargs):

对应的参数说明为:

参数 说明
x 节点特征矩阵,shape为[num_nodes, num_node_features]Tensor类型
edge_index 边索引(边表),shape为[2, num_edges],在这个包含两行的数组中,第1行与第2行中对应索引位置的值分别表示一条边的源节点和目标节点,LongTensor类型。
edge_attr 边特征矩阵,shape为[num_edges, num_edges_featrues]Tensor类型
y 图级标签或节点级标签,Tensor类型
pos 节点的位置矩阵,shape为[num_nodes, num_dimensions]Tensor类型
**kwargs 用户自定义的额外属性,传入格式需为attr_name=attr_value

Data类的初始化函数中参数默认值都为None,这意味着没有哪个参数是必要的,在实际使用时需要根据待构造图的实际情况来传入相应的属性。

2.2 常用的图属性与方法

在PyG中,对于一个Data对象其包含众多属性和方法,这里列举一下常用的,更详细的请参见官网Data部分。

方法/属性 说明
num_node_features/num_features 图节点数特征(维度)数
num_edge_features 图中边的特征(维度)数
keys 图属性名列表
num_edges 图边数
num_nodes 图节点数
is_directed()/is_undirected() 是否为有向图/无向图
is_cuda 图是否存储在GPU上
has_self_loops()/contains_self_loops() 图中节点是否包含自环
has_isolated_nodes()/contains_isolated_nodes() 图中是否包含孤立节点
to(device) 将图实例放置到指定的设备(GPU或CPU)上
clone() 对图进行深拷贝

2.3 演示示例

首先创建一个包含5个顶点、12条边的无向图。需要注意的是,在edge_index中边都有有方向的,即从源节点到目标节点。若要创建从节点
v
v
v
到节点
u
u
u
的无向边,则需要在edge_index中传入两条相应的边,即(u,v), (v,u)

import torch
import torch_geometric.data as data
from torch_geometric.utils import to_networkx
import matplotlib.pyplot as plt
import networkx as nx
edge_index = torch.LongTensor([[0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 4],
                               [1, 2, 4, 0, 2, 1, 0, 3, 2, 4, 3, 0]])
x = torch.ones(5, 2)
g = data.Data(edge_index=edge_index, x=x)
print(g)
"""
Data(edge_index=[2, 12], x=[5, 2])
"""
# 转换为nextworkx格式的图并可视化
g = to_networkx(g)
nx.draw(g, with_labels=g.nodes)
plt.show()

创建的图可视化结果为:

PyG教程(2):图数据

对上述创建的Data对象应用2.2节介绍的部分方法实例代码如下:

print(g.num_nodes, g.num_edges)
# 5 12
print(g.keys)
# ['x', 'edge_index']
print(g.num_node_features)
# 2
print(g.is_undirected())
# True
print(g.has_isolated_nodes())
# False

若要将自己创建的图实例保存到本地磁盘或从本地磁盘加载保存的图数据,可以使用torch.save()torch.load()

torch.save([g], "temp/data.pt")
g = torch.load("temp/data.pt")
print(g)
# [Data(edge_index=[2, 12], x=[5, 2])]

三.进阶图操作

torch_geometric.utils模块中包含了许多对图数据的高级操作方法,下面将对其中最常用的方法进行介绍。

3.1 度的计算

通过degree(index, num_nodes=None)方法可以计算图中节点的度,其中:

  • indexedge_index中的两个维度中任意一个
  • num_nodes:节点的数量,可选参数

示例代码:

print(degree(g.edge_index[0]))
# tensor([3., 2., 3., 2., 2.])
print(degree(g.edge_index[1]))
# tensor([3., 2., 3., 2., 2.])

3.2 自环的添加与删除

自环指节点指向自身的边。在utils中处理自环的方法包括:

  • contains_self_loops(edge_index):判断图中节点是否包含自环。
  • remove_self_loops(edge_index):删除图中所有的自环。
  • add_self_loops(edge_index):为图中的节点添加自环,对于有自环的节点,它会再为该节点添加一个自环。
  • add_remaining_self_loops:为图中还没有自环的节点添加自环。

示例代码:

print(contains_self_loops(g.edge_index))
# False
edge_index, _ = add_self_loops(g.edge_index)
print(edge_index)
"""
tensor([[0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 4, 0, 1, 2, 3, 4],
        [1, 2, 4, 0, 2, 1, 0, 3, 2, 4, 3, 0, 0, 1, 2, 3, 4]])
"""
edge_index, _ = add_remaining_self_loops(edge_index)
print(edge_index)
"""
没有添加新的自环
tensor([[0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 4, 0, 1, 2, 3, 4],
        [1, 2, 4, 0, 2, 1, 0, 3, 2, 4, 3, 0, 0, 1, 2, 3, 4]])
"""
edge_index, _ = remove_self_loops(edge_index)
print(edge_index)
"""
tensor([[0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 4],
        [1, 2, 4, 0, 2, 1, 0, 3, 2, 4, 3, 0]])
"""

3.2 子图提取

utils中提供了若干方法用来在图中提取子图。

  • subgraph(subset, edge_index):根据给定的图节点集合subset来抽取图中包含这些节点的子图。
  • k_hop_subgraph(node_idx, num_hops, edge_index):提取给定节点集node_idx能经过num_hops跳到达的所有节点组成的子图(包括node_idx本身)。

sub_graph方法示例代码:

def draw(edge_index):
    graph = data.Data(edge_index=edge_index)
    graph = to_networkx(graph)
    print(graph.nodes)
    nx.draw(graph, with_labels=graph.nodes)
    plt.show()
edge_index, _ = subgraph(subset=torch.LongTensor(
    [0, 1, 2]), edge_index=g.edge_index)
draw(edge_index)

提取的子图可视化如下所示:

PyG教程(2):图数据

k_hop_subgraph方法的示例代码如下所示:

g = k_hop_subgraph(
    node_idx=[0], num_hops=1, edge_index=g.edge_index)
print(g)
"""
(tensor([0, 1, 2, 4]), tensor([[0, 0, 0, 1, 1, 2, 2, 4],
        [1, 2, 4, 0, 2, 1, 0, 0]]), tensor([0]), tensor([ True,  True,  True,  True,  True,  True,  True, False, False, False,
        False,  True]))
"""

从上图可以看出,该方法返回一个4元组,元组的4个元素依次为:子图的节点集、子图的边集、用来查询的节点集(中心节点集)、指示原始图g中的边是否在子图中的布尔数组。我们取子图的边集进行可视化结果如下:

PyG教程(2):图数据

3.4 转换为无向图

通过to_undirected(edge_index)可以将一个图转换为无向图:

edge_index = torch.LongTensor([[0, 0], [1, 2]])
edge_index = to_undirected(edge_index)
print(edge_index)
"""
tensor([[0, 0, 1, 2],
        [1, 2, 0, 0]])
"""

结语

参考资料:

  • torch_geometric.data
  • torch_geometric.utils

本文主要介绍了PyG中对单个图的相关操作方法,从上面的操作可以看出对于PyG对图结构的操作其实就是在操作edge_index(该属性本来就用来在PyG中保存图的结构信息)。