type
status
date
slug
summary
tags
category
icon
password
😀
本文主要用于梳理 EMK-KEN 项目代码的工作,便于理解各代码的工作流程。
 

📝 主旨内容

1、项目概述

随着学术文献的爆炸式增长,有效评估文学的知识价值变得至关重要。然而,现有的引文评估方法在处理大规模引文网络时面临着效率低、鲁棒性不足的挑战。传统方法主要依赖于引用频率,这可能会导致忽视许多贡献率很高的论文。此外,尽管基于图神经网络(GNN)的方法在结构信息获取方面表现出色,但通常需要对整个引文网络进行建模,从而导致计算成本高和训练时间长。为了应对这些挑战,我们引入了 EMK-KEN,这是一种结合了 Mamba 和 KAN 架构的新型评估方法。这种方法能够有效地捕获引文网络结构信息,并精确评估知识价值。EMK-KEN 使用 Mamba 处理论文元数据和文本嵌入,提高了模型效率,并利用 KAN 捕获引文网络的拓扑模式,有效提高了鲁棒性和泛化能力。实验结果表明,EMK-KEN 在计算效率、鲁棒性和泛化能力方面优于现有方法。它还显示了在科学计量分析和学术推荐系统中的实际应用潜力,帮助研究人员更准确地识别高影响力文章。
notion image

2、项目结构

EMK-KEN 的项目结构如下:

3、项目代码梳理

3.1 数据检查和清洗

data_process.py 代码
该代码包含两个函数,其中,
(1)函数 is_english(text: str) 用于判断给定的文本是否为英文。它通过以下步骤进行判断:
① 文本标准化:将所有字符转换为小写并去除首尾空格。
② 分词:使用 word_tokenize 将文本分割成单词。
③ 检查分词后的长度是否至少有 3 个词。
④ 使用 langdetect 库检测语言,若检测结果为英文则返回 True。
(2)函数 preprocess_abstracts(input_file: str, output_file: str) 负责对摘要进行一系列预处理操作,包括:读取CSV文件并检查是否存在 abstract 列;创建一个列表来存储处理后的摘要,以及一个计数器来记录非英文摘要的数量;使用 stopwords.words('english') 获取英文停用词;对英文摘要进行标准化处理(转小写、替换换行符),并使用 word_tokenize 对标准化后的文本进行分词,同时过滤掉非字母字符和停用词,最后过滤后的词语重新组合成字符串并添加到处理后的摘要列表,并保存为新的 csv 文件。

3.2 生成文本嵌入

generate_embedding.py 代码
此代码使用 AllenAI 提供的 SciBERT 模型 (allenai/scibert_scivocab_uncased),从 csv 文件中读取摘要数据,并生成文本嵌入。代码定义了两种嵌入生成策略:(1)CLS Token Embedding:使用 SciBERT 输出中的 [CLS] 标记作为整个句子的向量表示;(2)Mean Pooling:对所有 token 的隐藏状态取平均值,生成固定维度的向量表示。具体步骤:
① 遍历每一条数据,获取 paper_id 和 abstract_text。
② 使用 tokenizer 对摘要进行编码。
③ 将输入传递给 SciBERT 模型,禁用梯度计算以提高效率。
④ 使用 CLS Token 和 Mean Pooling 方法生成嵌入向量。

3.3 元数据标准化

non_data_standardScaler.py 代码
这段代码实现了对多个 CSV 文件中的指定数值特征进行标准化处理的功能。它使用 StandardScaler 对数据进行 Z-score 标准化(均值为 0,标准差为 1)。具体来说,该代码依次读取每个文件,查是否包含所有目标列,提取这些列并合并成一个 DataFrame,同时记录每份数据的行数以便后续还原顺序。然后使用 StandardScaler 对合并后的数据进行拟合(fit)和转换(transform),将结果保存为新的 DataFrame。最后按照原始数据的行数分割标准化后的数据,替换原数据中的相应列,将修改后的数据写回原文件,保留其余未标准化列不变。

3.4 构建引文网络

citation_network.py 代码
此代码主要用于构建和分析论文引用网络,并计算相关的网络属性。它从 CSV 文件中读取论文及其参考文献信息,构建有向图(networkx.DiGraph),并为每篇论文的引用关系生成子图,最后统计和保存每个子图以及合并后的全局引用网络的拓扑属性。具体来说:
函数 build_citation_mapping(df: pd.DataFrame) 从 DataFrame 中提取论文 ID 到其参考文献的映射字典。该函数使用 ast.literal_eval 解析字符串格式的引用列表,跳过无效或者空的引用字段,并构建映射字典。
函数 analyze_subgraph(subgraph: nx.DiGraph) 用于分析一个子图的网络拓扑属性,包括:平均度(avg_degree)、最大/最小入度、出度(max_in_degree, min_in_degree, max_out_degree, min_out_degree)、直径(diameter)、密度(density)、聚类系数(clustering_coefficient)、入度/出度中心性平均值(in_degree_centrality_avg, out_degree_centrality_avg)。
主函数用于构建引用网络,具体操作如下,对于每篇论文:
  • 获取其直接引用的文献;
  • 构建子图:
  • 添加一级引用边(论文 → 引用文献);
  • 添加二级引用边(被引用文献 → 它们的引用);
  • 所有边也添加到 merged_graph 中;
  • 使用 analyze_subgraph 分析子图属性;
  • 将子图保存为 .pkl 文件;
  • 将属性记录到 properties 列表中。

3.5 合并引文网络和元数据及文本嵌入

merge_embed_and_network.py 代码
此代码的主要目的是将论文的引用网络(citation network)与它们的嵌入向量(embeddings)和元数据(metadata)进行融合,从而增强图结构的信息,便于后续用于机器学习或知识图谱任务。
函数 load_and_preprocess_data() 负责从文件中加载引用关系表(paper_id 和 references_id 列)、论文嵌入数据(第一列是 paper_id,其余列是嵌入向量)、元数据(包含 Id 列和其他属性如标题、摘要等),并返回一个包含这三个 DataFrame 的元组。
函数 create_lookup_dictionaries(embeddings: pd.DataFrame, metadata: pd.DataFrame) 接收嵌入数据和元数据两个 DataFrame,用于构造两个字典,一个用于通过论文 ID 快速查找其嵌入向量;另一个用于通过论文 ID 查找其元数据(除 ID 外的其他字段)。
函数 enhance_citation_graphs(citation_data: pd.DataFrame, embeddings: dict, metadata: dict, output_dir: str) 是整个程序的核心函数,它的任务包括:
  • 遍历所有有嵌入向量的论文 ID。
  • 对每个论文 ID,尝试加载其引用网络图(.pkl 文件)。
  • 如果图存在,则为中心节点(当前论文)添加嵌入和元数据。
  • 解析该论文引用了哪些文献,并为这些被引文献节点也添加嵌入和元数据。
  • 清理图结构,移除没有嵌入或元数据的节点,以及自环边。
  • 保存增强后的图到指定目录。
  • 收集所有成功处理的论文 ID 并返回。
函数 parse_reference_list(reference_series: pd.Series) 用于将可能以字符串形式存储(例如 "[1, 2, 3]")的引用列表,解析为 Python 列表格式,确保后续处理可以正确识别每个被引论文 ID。
函数 clean_graph(graph: nx.Graph) 用于清理可能缺少嵌入或元数据的小部分节点,并去除自环边,保证最终图的完整性与规范性。
函数 filter_csv_by_paper_ids(input_path: str, output_path: str, valid_ids: list) 用于过滤原始的元数据 CSV 文件,只保留那些在enhance_citation_graphs() 函数中成功处理的论文 ID 所对应的行,输出一个新的精简版 CSV 文件,便于后续使用。

3.6 构建多个PGY对象的图数据集

data.py 代码
这段代码的主要功能是:将合并后的引用网络图、元数据、嵌入向量和标签整合为 PyTorch Geometric 的 Data 对象,以便用于图神经网络模型的训练和推理。各个函数功能如下表所示:
函数名
作用描述
输入参数
返回值
clean_and_convert_embedding(embedding_str)
清洗并转换字符串形式的嵌入向量为浮点数列表。
embedding_str: 字符串格式的嵌入向量
清洗后的浮点数列表,若解析失败则返回 None
load_graph_data_with_metadata_embeddings_and_labels(root_dir, citation_dir, folders, date_dict)
从 .pkl 文件加载图数据,并结合元数据、嵌入信息和标签生成 Data 对象。
root_dir: 标签数据目录, citation_dir: 图数据根目录, folders: 子目录列表, date_dict: 论文 ID 到发表年份的映射
包含图结构与特征的 Data 对象列表
calculate_relative_time_encoding(time_stamps, center_index)
根据中心节点的发表日期计算相对时间编码,并归一化到 [0,1] 范围。
time_stamps: 所有节点的发表日期,center_index: 中心节点索引
归一化后的相对时间差(torch.Tensor)
build_date_dict(csv_files)
构建论文 ID 到标准化发表年份的字典,用于时间编码。
csv_files: CSV 文件路径列表
映射论文 ID 到相对于最早年份的年份值的字典
主函数
执行整个数据处理流水线,包括构建时间字典、加载图数据、保存处理结果。
无显式输入参数(通过变量配置路径等)
将处理后的图数据保存为 data_list.pt 文件

3.7 MamST Module

mamst.py 代码
该模型 MamST 是一个结合了 图神经网络 (GNN) 和 Mamba 架构 的混合模型,用于处理带有元数据和嵌入信息的图结构数据。它融合了空间(图结构)与时间(日期编码)特征,并通过 Mamba 模块捕获序列化特征中的长期依赖关系。MamST 的核心功能如下:
如果存在元数据(num_node_features > 0),通过全连接层(nn.Linear)和 ReLU 激活函数对其进行处理。然后将时间编码(date_encoding)与元数据结合,增强模型对时间信息的敏感性。
然后使用两个 Mamba 模块(MC1 和 MC2)分别处理元数据和嵌入。同时使用图卷积网络(GCNConv)处理嵌入,捕捉图结构中的拓扑信息。针对文本嵌入进行特征融合,将 GCN 和 Mamba 处理后的嵌入特征进行拼接,并通过全连接层调整维度。
最后,如果存在元数据,返回一个元组,包含处理后的元数据和嵌入。如果没有元数据,仅返回处理后的嵌入。

3.8 KNU Module

KNU 代码
KANLinear 代码
KAN 代码
针对 KNU 类:
函数名
功能描述
__init__
初始化 KNU 模块,包括 MamST、元数据和嵌入的 KAN 处理器、预测层和正则化层。
forward
定义前向传播过程,包括特征提取、中心节点选择、特征变换、特征融合和分类预测。
针对 KANLinear 类:
函数名
功能描述
__init__
初始化 KANLinear 层,包括基础权重、样条权重和相关参数。
reset_parameters
重置模型参数,包括基础权重和样条权重的初始化。
b_splines
计算输入张量的 B 样条基值。
curve2coeff
通过最小二乘拟合计算样条系数。
scaled_spline_weight
获取应用缩放后的样条权重。
forward
定义前向传播过程,包括基础组件和样条组件的计算。
update_grid
根据输入数据分布动态调整样条网格。
regularization_loss
计算正则化损失,包括 L1 正则化和熵正则化。
针对 KAN 类:
函数名
功能描述
__init__
初始化 KAN 网络,包括多个 KANLinear 层。
forward
顺序执行所有 KANLinear 层的前向传播,并可选地更新网格。
regularization_loss
聚合所有 KANLinear 层的正则化损失。
KNU 的工作流程如下:
  1. 首先,KNU 的 forward 函数调用 MamST 的 forward 函数进行特征提取。MamST 的 forward 函数处理输入数据,输出元数据和嵌入。
  1. 然后,KNU 的 forward 函数根据是否存在元数据,选择性地处理元数据和嵌入。
  1. 对于元数据和嵌入,分别调用 KAN 的 forward 函数进行特征变换。
  1. KAN 的 forward 函数调用 KANLinear 的 forward 函数进行逐层特征处理。
  1. KANLinear 的 forward 函数计算基础组件和样条组件的输出,并返回结果。
  1. KNU 的 forward 函数融合变换后的特征,进行最终分类预测。同时,KNU 的 forward 函数调用 KAN 的 regularization_loss 函数计算正则化损失。

3.9 EMK-KEN Module

emkken.py 代码
该代码定义了 EMKKEN 类,用于构建一个知识评价网络模型。这个类在初始化时会创建两个子模块:mamst 和 knu。 对于 MamST 模块,其传入的参数包括节点元数据维度、隐藏层大小、Mamba 层的状态维度、卷积核大小、Dropout 率等。对于 KNU 模块,它接收 MamST 实例本身作为参数,还包含输出类别数、中间层维度、Dropout 率等配置。这两个模块分别负责图特征提取和分类决策。

🤗 总结归纳

本文主要讲解了 EMK-KEN 的项目结构,便于理解模型各组成部分,及其工作原理。

📎 参考文章

💡
有关项目的安装或者使用上的问题,欢迎您在底部评论区留言,一起交流~
 
Mamba知识点OpenCV-Python学习笔记
Loading...