图卷积神经网络(GCN)-matlab代码 (2024)

1.GCN的发展历程

图卷积神经网络,英文全称 Graph Convolutional Neural Networks,以下简写为 GCN,提出于 2017 年。GCN 在 CCN 与 RNN 的基础之上发展而来。CCN 与 RNN 是深度学习中的经典模型。CCN 的处理对象是图片,图片是一个二维的结构。CNN 用于提取图片的特征,其关键在于卷积核 kernel,kernel 是一个小窗口,其大小一般是 3×3,这个小窗口在图片上的平移,通过卷积的方式来提取图片的特征。RNN 的处理对象是自然语言序列这样的图片信息,这种信息是一维的结构。在 RNN 当中,通过各种门的操作,使得序列前后的信息相互影响,从而很好地捕捉序列的特征。上面所讲的图片或者是语言都是欧氏空间的数据,所以结构很规则。但是在现实生活中,有很多不规则的数据结构,比如本文所关注的图结构,或者说是拓扑图的结构,再比如化学分子结构。这种图的结构十分不规则,有学者认为这是无限维的一种数据,这让 CNN、RNN 瞬间失效。为了处理这类数据,很多学者从上个世纪就开始研究处理这种数据的方法。这里涌现出了许多方法,例如 GNN、DeepWalk、node2vec 等等,GCN 就是其中一种。

2.GCN的应用

GCN 的作用实际上和 CNN 差不多,都是一个特征提取器,只不过 GCN 的处理对象更加复杂。GCN 设计了一种从图数据当中提取特征的方法,从而可以应用到对图数据进行节点分类、图分类、边预测、还可以得到图的嵌入表示。

3.GCN的数学原理

下图是GCN的架构图,从图中可以看出GCN有三层:输入层、隐藏层、输出层。其中输入层只有一层,隐藏层可以有若干层,输出层也只有一层。在输入层,邻接矩阵表示图中每个节点之间的连接关系,特征值表示每个节点的特征。在隐藏层,节点首先聚合自己与邻居节点的特征,而后通过非线性运算(ReLU、Sigmoid是非线性激活函数,可以灵活选取),如此反复进行,隐藏层有多少层,这样的过程就执行多少次。在输出层,输出GCN的预测结果,或者分类结果。

图卷积神经网络(GCN)-matlab代码 (1)

在这里,以节点分类为例来叙述GCN的应用。在实际应用中,GCN可分为3个过程,训练、验证、测试。在本篇所要介绍的代码中,每训练300次就进行一次验证,可以说训练与验证过程是同一阶段进行的。GCN的输入参数是:邻接矩阵、特征值、层间的权重系数。GCN的训练与验证过程:GCN通过邻接矩阵与特征值预测图中节点的类别,将预测的节点类别与节点标签(即实际上的节点类别)进行对比,得出误差,更新权重以使误差更小,如此往复,直到迭代次数用尽。GCN的测试过程:将测试数据(即就是需要预测分类的数据)输入训练好的GCN中,这样就能比较准确地预测节点的分类。

图卷积神经网络的计算过程如下:

首先是 GCN 模型的基础数学公式:

图卷积神经网络(GCN)-matlab代码 (2)
图卷积神经网络(GCN)-matlab代码 (3)

对于邻接矩阵的定义是矩阵中的值为对应位置节点与节点之间的关系,比如节点 1与节点 2 相连,则邻接矩阵中 1 行 2 列和 2 行 1 列的值置为 1,倘若节点 1 与节点 2 不相连,则邻接矩阵中 1 行 2 列和 2 行 1 列的值置为 0。然而节点中对角线的位置是节点与自身的关系。而节点与自身并没有边相连,所以邻接矩阵中的对角线自然都为 0。这样一来的话,在后续的计算过程中就无法区别邻接矩阵中的“自身节点”与“无连接节点”。

图卷积神经网络(GCN)-matlab代码 (4)
图卷积神经网络(GCN)-matlab代码 (5)

4.GCN的matlab例程

(1)在matlab的“附加功能”中有“获取附加功能”的选项,点击它。

图卷积神经网络(GCN)-matlab代码 (6)

(2)然后搜索gcn,在Deep Learning Toolbox中点击Node Classification Using Graph Convolutional Network。

图卷积神经网络(GCN)-matlab代码 (7)

(3)而后点击Open Live Script。

图卷积神经网络(GCN)-matlab代码 (8)

(4)按照提示进行下一步操作。

图卷积神经网络(GCN)-matlab代码 (9)
图卷积神经网络(GCN)-matlab代码 (10)

5.GCN的matlab代码介绍-利用GCN进行节点分类

接下来我将简单介绍一下这篇代码。Node Classification Using Graph Convolutional Network的代码。

这个例程说明了如何使用GCN进行节点分类。在此示例中,GCN必须预测图中未标记节点的标签,节点的标签就是节点的类别,预测标签就是 节点分类。本例是一个给分子中的原子进行分类的例程。在本例中,图形由分子表示。分子中的原子表示图中的节点, 原子之间的化学键表示图中的边。邻接矩阵所表示的连接关系就是分子中原子的连接关系,特征值就是原子序数, 输出的节点标签是“C、H、O、N、S”等原子类型。因此,GCN的输入是分子图,输出是对分子中每个未标记原子的原子类型的预测。

为了给图的每个节点分配一个分类标签,GCN对函数 f(X,A) 在图 G(V,E) 上进行建模。其中 V 表示节点, E 表示边。

  • X是特征值矩阵,维度是 N*C,N是图G(V,E)的节点数,即N=|V|,C是每个节点的特征数目
  • A是邻接矩阵,维度是N*N,矩阵中元素为0表示两个接待无连接,为1表示两个节点有连接。
  • Z是输出矩阵,维度是N*F,F是节点的类别数目,矩阵元素为0表示节点不属于这个类别,矩阵元素为1表示节点属于这个类别。

(1)下载 QM7 数据集

本例使用QM7数据集[2][3],这是一个分子数据集,由7165个分子组成,最多由23个原子组成。也就是说,原子数最多的分子有23个原子。总的来说,数据集由5个独特的原子组成:碳(C)、氢(H)、氮(N)、氧(O)和硫(S)。

dataURL = 'http://quantum-machine.org/data/qm7.mat';outputFolder = fullfile(tempdir,'qm7Data');dataFile = fullfile(outputFolder,'qm7.mat');if ~exist(dataFile, 'file') mkdir(outputFolder); fprintf('Downloading file ''%s'' ...\n', dataFile); websave(dataFile, dataURL);enddata = load(dataFile)

(2)数据预处理

得到训练数据集,验证数据集。每个数据集由三部分组成:特征值、邻接矩阵、节点标签。

coulombData = double(permute(data.X, [2 3 1]));atomicNumber = sort(data.Z,2,'descend'); adjacencyData = coloumb2Adjacency(coulombData, atomicNumber);[adjacencyDataSplit, coulombDataSplit, atomicNumberSplit] = splitData(adjacencyData, coulombData, atomicNumber);[adjacency, features, labels] = cellfun(@preprocessData, adjacencyDataSplit, coulombDataSplit, atomicNumberSplit, 'UniformOutput', false);features = normalizeFeatures(features);%训练数据集featureTrain = features{1};adjacencyTrain = adjacency{1};targetTrain = labels{1};%验证数据集featureValidation = features{2};adjacencyValidation = adjacency{2};targetValidation = labels{2};

(3)数据可视化(可以忽略)

idx = [1 5 300 1159];for j = 1:numel(idx) % Remove padded zeros from the data atomicNum = nonzeros(atomicNumber(idx(j),:)); numOfNodes = numel(atomicNum); adj = adjacencyData(1:numOfNodes,1:numOfNodes,idx(j)); % Convert adjacency matrix to graph compound = graph(adj); % Convert atomic numbers to symbols symbols = cell(numOfNodes, 1); for i = 1:numOfNodes if atomicNum(i) == 1 symbols{i} = 'C'; elseif atomicNum(i) == 6 symbols{i} = 'H'; elseif atomicNum(i) == 7 symbols{i} = 'N'; elseif atomicNum(i) == 8 symbols{i} = 'O'; else symbols{i} = 'S'; end end % Plot graph subplot(2,2,j) plot(compound, 'NodeLabel', symbols, 'LineWidth', 0.75, ... 'Layout', 'force') title("Molecule " + idx(j))endlabelsAll = cat(1,labels{:});classes = categories(labelsAll);figurehistogram(labelsAll)xlabel('Category')ylabel('Frequency')title('Label Counts')

(4)GCN模型参数

numInputFeatures = size(featureTrain,2)%输入层神经元个数numHiddenFeatureMaps = 32;%隐藏层神经元个数numOutputFeatures = numel(classes);%输出层神经元个数sz = [numInputFeatures numHiddenFeatureMaps];numOut = numHiddenFeatureMaps;numIn = numInputFeatures;parameters.W1 = initializeGlorot(sz,numOut,numIn,'double');%第一层权重sz = [numHiddenFeatureMaps numHiddenFeatureMaps];numOut = numHiddenFeatureMaps;numIn = numHiddenFeatureMaps;parameters.W2 = initializeGlorot(sz,numOut,numIn,'double');%第二层权重sz = [numHiddenFeatureMaps numOutputFeatures];numOut = numOutputFeatures;numIn = numHiddenFeatureMaps;parameters.W3 = initializeGlorot(sz,numOut,numIn,'double');%第三层权重numEpochs = 1500;%迭代次数learnRate = 0.01;%学习速率validationFrequency = 300;%每训练300次验证1次plots = "training-progress";executionEnvironment = "auto";

(5)训练与验证过程

if plots == "training-progress" figure % Accuracy. subplot(2,1,1) lineAccuracyTrain = animatedline('Color',[0 0.447 0.741]); lineAccuracyValidation = animatedline( ... 'LineStyle','--', ... 'Marker','o', ... 'MarkerFaceColor','black'); ylim([0 1]) xlabel("Epoch") ylabel("Accuracy") grid on % Loss. subplot(2,1,2) lineLossTrain = animatedline('Color',[0.85 0.325 0.098]); lineLossValidation = animatedline( ... 'LineStyle','--', ... 'Marker','o', ... 'MarkerFaceColor','black'); ylim([0 inf]) xlabel("Epoch") ylabel("Loss") grid onendtrailingAvg = [];trailingAvgSq = [];dlX = dlarray(featureTrain);dlXValidation = dlarray(featureValidation);if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlX = gpuArray(dlX);endT = onehotencode(targetTrain, 2, 'ClassNames', classes);TValidation = onehotencode(targetValidation, 2, 'ClassNames', classes);start = tic;% Loop over epochs.for epoch = 1:numEpochs % Evaluate the model gradients and loss using dlfeval and the % modelGradients function. [gradients, loss, dlYPred] = dlfeval(@modelGradients, dlX, adjacencyTrain, T, parameters); % Update the network parameters using the Adam optimizer. [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ... trailingAvg,trailingAvgSq,epoch,learnRate); % Display the training progress. if plots == "training-progress" subplot(2,1,1) D = duration(0,0,toc(start),'Format','hh:mm:ss'); title("Epoch: " + epoch + ", Elapsed: " + string(D)) % Loss. addpoints(lineLossTrain,epoch,double(gather(extractdata(loss)))) % Accuracy score. score = accuracy(dlYPred, targetTrain, classes); addpoints(lineAccuracyTrain,epoch,double(gather(score))) drawnow % Display validation metrics. if epoch == 1 || mod(epoch,validationFrequency) == 0 % Loss. dlYPredValidation = model(dlXValidation, adjacencyValidation, parameters); lossValidation = crossentropy(dlYPredValidation, TValidation, 'DataFormat', 'BC'); addpoints(lineLossValidation,epoch,double(gather(extractdata(lossValidation)))) % Accuracy score. scoreValidation = accuracy(dlYPredValidation, targetValidation, classes); addpoints(lineAccuracyValidation,epoch,double(gather(scoreValidation))) drawnow end endend

(6)测试过程

%测试数据集featureTest = features{3};adjacencyTest = adjacency{3};targetTest = labels{3};%测试dlXTest = dlarray(featureTest);dlYPredTest = model(dlXTest, adjacencyTest, parameters);%GCN模型函数%测试结果[scoreTest, predTest] = accuracy(dlYPredTest, targetTest, classes);

(7)可视化测试结果-混淆矩阵

numOfSamples = numel(targetTest);classTarget = zeros(numOfSamples, numOutputFeatures);classPred = zeros(numOfSamples, numOutputFeatures);for i = 1:numOutputFeatures classTarget(:,i) = targetTest==categorical(classes(i)); classPred(:,i) = predTest==categorical(classes(i));end% Compute class-wise accuracy scoreclassAccuracy = sum(classPred == classTarget)./numOfSamples;% Visualize class-wise accuracy scorefigure[~,idx] = sort(classAccuracy,'descend');histogram('Categories',classes(idx), ... 'BinCounts',classAccuracy(idx), ... 'Barwidth',0.8)xlabel("Category")ylabel("Accuracy")title("Class Accuracy Score")confusionMatrix, order] = confusionmat(targetTest, predTest);figurecm = confusionchart(confusionMatrix, classes, ... 'ColumnSummary','column-normalized', ... 'RowSummary','row-normalized', ... 'Title', 'GCN QM7 Confusion Chart');

(8)代码中的函数1:Split Data Function。

将数据集按照8:1:1的比例划分,分别作为训练数据、验证数据、测试数据。

function [adjacencyDataSplit, coulombDataSplit, atomicNumberSplit] = splitData(adjacencyData, coulombData, atomicNumber)adjacencyDataSplit = cell(1,3);coulombDataSplit = cell(1,3);atomicNumberSplit = cell(1,3);numMolecules = size(adjacencyData, 3);% Set initial random state for example reproducibility.rng(0);% Get training dataidx = randperm(size(adjacencyData, 3), floor(0.8*numMolecules));adjacencyDataSplit{1} = adjacencyData(:,:,idx);coulombDataSplit{1} = coulombData(:,:,idx);atomicNumberSplit{1} = atomicNumber(idx,:);adjacencyData(:,:,idx) = [];coulombData(:,:,idx) = [];atomicNumber(idx,:) = [];% Get validation dataidx = randperm(size(adjacencyData, 3), floor(0.1*numMolecules));adjacencyDataSplit{2} = adjacencyData(:,:,idx);coulombDataSplit{2} = coulombData(:,:,idx);atomicNumberSplit{2} = atomicNumber(idx,:);adjacencyData(:,:,idx) = [];coulombData(:,:,idx) = [];atomicNumber(idx,:) = [];% Get test dataadjacencyDataSplit{3} = adjacencyData;coulombDataSplit{3} = coulombData;atomicNumberSplit{3} = atomicNumber;end

(9)代码中的函数2:Preprocess Data Function

进行数据预处理,从原始数据集中得到邻接矩阵、特征值矩阵、节点标签矩阵。

function [adjacency, features, labels] = preprocessData(adjacencyData, coulombData, atomicNumber)adjacency = sparse([]);features = [];labels = [];for i = 1:size(adjacencyData, 3) % Remove padded zeros from atomicNumber tmpLabels = nonzeros(atomicNumber(i,:)); labels = [labels; tmpLabels]; % Get the indices of the un-padded data validIdx = 1:numel(tmpLabels); % Use the indices for un-padded data to remove padded zeros % from the adjacency data tmpAdjacency = adjacencyData(validIdx, validIdx, i); % Build the adjacency matrix into a block diagonal matrix adjacency = blkdiag(adjacency, tmpAdjacency); % Remove padded zeros from coulombData and extract the % feature array tmpFeatures = diag(coulombData(validIdx, validIdx, i)); features = [features; tmpFeatures];end% Convert labels to categorical arrayatomicNumbers = unique(labels);atomNames = ["Hydrogen","Carbon","Nitrogen","Oxygen","Sulphur"];labels = categorical(labels, atomicNumbers, atomNames);end

(10)代码中的函数3:Normalize Features Function

对特征值矩阵进行归一化,然后才能输入到GCN中。

function features = normalizeFeatures(features)% Get the mean and variance from the training datameanFeatures = mean(features{1});varFeatures = var(features{1}, 1);% Standardize training, validation and test datafor i = 1:3 features{i} = (features{i} - meanFeatures)./sqrt(varFeatures);endend

(11)代码中的函数4:Model Function

搭建GCN模型。GCN的输入参数是:邻接矩阵、特征值、层间的权重系数。GCN的训练与验证过程:GCN通过邻接矩阵与特征值预测图中节点的类别,将预测的节点类别与节点标签(即实际上的节点类别)进行对比,得出误差,更新权重以使误差更小,如此往复,直到迭代次数用尽。GCN的测试过程:将测试数据(即就是需要预测分类的数据)输入训练好的GCN中,这样就能比较准确地预测节点的分类。

function dlY = model(dlX, A, parameters)% Normalize adjacency matrixL = normalizeAdjacency(A);Z1 = dlX;Z2 = L * Z1 * parameters.W1;Z2 = relu(Z2) + Z1;Z3 = L * Z2 * parameters.W2;Z3 = relu(Z3) + Z2;Z4 = L * Z3 * parameters.W3;dlY = softmax(Z4, 'DataFormat', 'BC');end

(12)代码中的函数5:Normalize Ajacency Function

对邻接矩阵进行数据预处理,以方便GCN的处理。

function normAdjacency = normalizeAdjacency(adjacency)% Add self connections to adjacency matrixadjacency = adjacency + speye(size(adjacency));% Compute degree of nodesdegree = sum(adjacency, 2);% Compute inverse square root of degreedegreeInvSqrt = sparse(sqrt(1./degree));% Normalize adjacency matrixnormAdjacency = diag(degreeInvSqrt) * adjacency * diag(degreeInvSqrt);end

(13)代码中的函数6:Model Gradients Function

用于衡量GCN模型的精度。

function [gradients, loss, dlYPred] = modelGradients(dlX, adjacencyTrain, T, parameters)dlYPred = model(dlX, adjacencyTrain, parameters);loss = crossentropy(dlYPred, T, 'DataFormat', 'BC');gradients = dlgradient(loss, parameters);end

(14)代码中的函数7:Accuracy Function

用于衡量GCN的分类准确率。

function [score, prediction] = accuracy(YPred, target, classes)% Decode probability vectors into class labelsprediction = onehotdecode(YPred, classes, 2);score = sum(prediction == target)/numel(target);end
图卷积神经网络(GCN)-matlab代码 (2024)

References

Top Articles
Latest Posts
Article information

Author: Sen. Emmett Berge

Last Updated:

Views: 5508

Rating: 5 / 5 (60 voted)

Reviews: 83% of readers found this page helpful

Author information

Name: Sen. Emmett Berge

Birthday: 1993-06-17

Address: 787 Elvis Divide, Port Brice, OH 24507-6802

Phone: +9779049645255

Job: Senior Healthcare Specialist

Hobby: Cycling, Model building, Kitesurfing, Origami, Lapidary, Dance, Basketball

Introduction: My name is Sen. Emmett Berge, I am a funny, vast, charming, courageous, enthusiastic, jolly, famous person who loves writing and wants to share my knowledge and understanding with you.