-
Notifications
You must be signed in to change notification settings - Fork 71
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Model] add dhn #207
base: main
Are you sure you want to change the base?
[Model] add dhn #207
Conversation
examples/dhn/dhn_trainer.py
Outdated
tra_auc_cul = tf.keras.metrics.AUC() | ||
val_auc_cul = tf.keras.metrics.AUC() | ||
test_auc_cul = tf.keras.metrics.AUC() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在GammaGL中,不能出现 import tensorflow
,计算auc可以尝试使用 sklearn
库提供的 roc_auc_score
gammagl/models/dhn.py
Outdated
self.lin1 = tlx.nn.Linear(in_features=36, out_features=64, act=tlx.nn.ELU(), W_init="xavier_uniform") | ||
self.lin2 = tlx.nn.Linear(in_features=82, out_features=64, act=tlx.nn.ELU(), W_init="xavier_uniform") | ||
self.lin3 = tlx.nn.Linear(in_features=64, out_features=64, act=tlx.nn.ELU(), W_init="xavier_uniform") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里在定义 in_features
和 out_features
时,建议将数值作为参数传入,而不是使用固定值
gammagl/models/dhn.py
Outdated
class DHN(tlx.nn.Module): | ||
def __init__(self): | ||
super().__init__() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DHN
在 DHNModel
中定义使用,并且定义了多个 DHN
,建议将 DHN
改为 DHNLayer
表示单个 DHN
层,并把这个类移动到 gammagl/layers
中
gammagl/models/dhn.py
Outdated
def forward(self, fea): | ||
node = tlx.convert_to_tensor(fea[:, :NUM_FEA]) | ||
|
||
# 提取neigh1和neigh2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议把注释修改为英文的,下面也是
gammagl/models/dhn.py
Outdated
self.lin1 = tlx.nn.Linear(in_features=128, out_features=32, act=tlx.nn.ELU(), W_init="xavier_uniform") | ||
self.lin2 = tlx.nn.Linear(in_features=32, out_features=1, act=tlx.nn.ELU(), W_init="xavier_uniform") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in_features
和 out_features
的值同样应该通过参数传入的方式进行赋值
examples/dhn/dhn_trainer.py
Outdated
for epoch in range(EPOCH): | ||
print("-----Epoch {}/{}-----".format(epoch + 1, EPOCH)) | ||
|
||
# train | ||
m.set_train() | ||
tra_batch_A_fea, tra_batch_B_fea, tra_batch_y = batch_data(G_train, BATCH_SIZE).__next__() | ||
tra_out = m(tra_batch_A_fea, tra_batch_B_fea, tra_batch_y) | ||
|
||
data = { | ||
"n1": tra_batch_A_fea, | ||
"n2": tra_batch_B_fea, | ||
"label": tra_batch_y | ||
} | ||
|
||
tra_loss = net_with_train(data, tra_batch_y) | ||
tra_auc_cul.update_state(y_true=tra_batch_y, y_pred=tlx.sigmoid(tra_out).detach().numpy()) | ||
tra_auc = tra_auc_cul.result().numpy() | ||
print('train: ', tra_loss, tra_auc) | ||
|
||
# val | ||
m.set_eval() | ||
val_batch_A_fea, val_batch_B_fea, val_batch_y = batch_data(G_val, BATCH_SIZE).__next__() | ||
val_out = m(val_batch_A_fea, val_batch_B_fea, val_batch_y) | ||
|
||
val_loss = tlx.losses.sigmoid_cross_entropy(output=val_out, target=tlx.convert_to_tensor(val_batch_y)) | ||
val_auc_cul.update_state(y_true=val_batch_y, y_pred=tlx.sigmoid(val_out).detach().numpy()) | ||
val_auc = val_auc_cul.result().numpy() | ||
print("val: ", val_loss.item(), val_auc) | ||
|
||
# test | ||
test_batch_A_fea, test_batch_B_fea, test_batch_y = batch_data(G_test, BATCH_SIZE).__next__() | ||
test_out = m(test_batch_A_fea, test_batch_B_fea, test_batch_y) | ||
|
||
test_loss = tlx.losses.sigmoid_cross_entropy(output=test_out, target=tlx.convert_to_tensor(test_batch_y)) | ||
test_auc_cul.update_state(y_true=test_batch_y, y_pred=tlx.sigmoid(test_out).detach().numpy()) | ||
test_auc = test_auc_cul.result().numpy() | ||
print("test: ", test_loss.item(), test_auc) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
训练流程需要你参照 examples/gcn/gcn_trainer.py
进行编写,主要包括超参数定义、main
函数编写,损失函数编写等内容
gammagl/layers/conv/dhn_conv.py
Outdated
class DHNModel(tlx.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.dhn1 = DHNConv() | ||
self.dhn2 = DHNConv() | ||
self.lin1 = tlx.nn.Linear(in_features=4*BATCH_SIZE, out_features=BATCH_SIZE, act=tlx.nn.ELU(), W_init="xavier_uniform") | ||
self.lin2 = tlx.nn.Linear(in_features=BATCH_SIZE, out_features=1, act=tlx.nn.ELU(), W_init="xavier_uniform") | ||
|
||
def forward(self, n1, n2, label): | ||
n1_emb = self.dhn1(n1) | ||
n2_emb = self.dhn2(n2) | ||
|
||
pred = self.lin1(tlx.concat([n1_emb, n2_emb], axis=1)) | ||
pred = self.lin2(pred) | ||
|
||
return pred |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这部分代码可以放到gammagl/models下面
gammagl/layers/conv/dhn_conv.py
Outdated
type2idx = { | ||
'M': 0, | ||
'A': 1, | ||
# 'C': 2, | ||
# 'T': 3 | ||
} | ||
|
||
NODE_TYPE = len(type2idx) | ||
K_HOP = 2 | ||
|
||
NUM_FEA = (K_HOP + 2) * 4 + NODE_TYPE | ||
NUM_NEIGHBOR = 5 | ||
BATCH_SIZE=32 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这部分变量可以作为__init__
和forward
函数的输入值传入,不要在文件中写死
examples/dhn/dhn_trainer.py
Outdated
def load_ACM(test_ratio=0.2): | ||
edge_index_M = [] | ||
edge_index_A = [] | ||
|
||
with open(args.dataset_path, 'r') as f: | ||
for line in f.readlines(): | ||
src, dst = line.strip().split() | ||
src_type, src_id = src[0], src[1:] # Resolves the source node type and ID | ||
dst_type, dst_id = dst[0], dst[1:] # Resolve the target node type and ID | ||
|
||
# Convert the node ID to an integer index and place it in a list | ||
if src[0] == 'M': | ||
edge_index_M.append(int(src_id)) | ||
elif src[0] == 'A': | ||
edge_index_A.append(-int(src_id) - 1) | ||
|
||
if dst[0] == 'M': | ||
edge_index_M.append(int(dst_id)) | ||
elif dst[0] == 'A': | ||
edge_index_A.append(-int(dst_id) - 1) | ||
|
||
edge_index = tlx.convert_to_tensor([edge_index_M, edge_index_A]) | ||
G['M', 'MA', 'A'].edge_index = edge_index | ||
|
||
# Computed split point | ||
sp = 1 - test_ratio * 2 | ||
num_edge = len(edge_index_M) | ||
sp1 = int(num_edge * sp) | ||
sp2 = int(num_edge * test_ratio) | ||
|
||
G_train = HeteroGraph() | ||
G_val = HeteroGraph() | ||
G_test = HeteroGraph() | ||
|
||
# Divide the training set, the verification set, and the test set | ||
G_train['M', 'MA', 'A'].edge_index = tlx.convert_to_tensor([edge_index_M[:sp1], edge_index_A[:sp1]]) | ||
G_val['M', 'MA', 'A'].edge_index = tlx.convert_to_tensor([edge_index_M[sp1:sp1 + sp2], edge_index_A[sp1:sp1 + sp2]]) | ||
G_test['M', 'MA', 'A'].edge_index = tlx.convert_to_tensor([edge_index_M[sp1 + sp2:], edge_index_A[sp1 + sp2:]]) | ||
|
||
print( | ||
f"all edge: {len(G['M', 'MA', 'A'].edge_index[0])}, train edge: {len(G_train['M', 'MA', 'A'].edge_index[0])}, val edge: {len(G_val['M', 'MA', 'A'].edge_index[0])}, test edge: {len(G_test['M', 'MA', 'A'].edge_index[0])}") | ||
|
||
return G_train, G_val, G_test |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
需要你在gammagl/datasets
文件夹下,写一个ACM
数据集的文件,从而能够在其他文件中调用ACM
的接口,即可实现数据集的下载和处理,而不是直接读取txt文件,可以参考其他数据集的加载方式。
examples/dhn/dhn_trainer.py
Outdated
|
||
warnings.filterwarnings("ignore", category=UserWarning) | ||
|
||
random.seed(0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以使用tlx.set_seed()
接口进行随机种子的设置
examples/dhn/dhn_trainer.py
Outdated
def find_all_simple_paths(edge_index, src, dest, max_length): | ||
# Converts edge_index to an adjacency list representation | ||
num_nodes = max(edge_index[0].max().item(), | ||
edge_index[1].max().item(), | ||
-edge_index[0].min().item(), | ||
-edge_index[1].min().item(), | ||
abs(src.item())) + 1 | ||
adj_list = [[] for _ in range(num_nodes)] | ||
for u, v in zip(edge_index[0].tolist(), edge_index[1].tolist()): | ||
adj_list[u].append(v) | ||
|
||
src = src.item() | ||
|
||
paths = [] | ||
visited = set() | ||
stack = [(src, [src])] | ||
|
||
while stack: | ||
(node, path) = stack.pop() | ||
|
||
if node == dest: | ||
paths.append(path) | ||
elif len(path) < max_length: | ||
for neighbor in adj_list[node]: | ||
if neighbor not in path: | ||
visited.add((node, neighbor)) | ||
stack.append((neighbor, path + [neighbor])) | ||
for neighbor in adj_list[node]: | ||
if (node, neighbor) in visited: | ||
visited.remove((node, neighbor)) | ||
|
||
return paths |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个函数的作用好像是要找到从源节点src
到目标节点dest
的简单路径,且长度不超过max_length
,可以把代码整理一下放到gammagl/utils
路径下,并用rst文档,描述一下该工具类的描述和用法。同时,需要给出test文件
gammagl/datasets/acm4dhn.py
Outdated
class ACM4DHN(InMemoryDataset): | ||
url = 'https://raw.githubusercontent.com/BUPT-GAMMA/HDE/main/ds/imdb' | ||
test_ratio = 0.3 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
test_ratio
是否可以作为一个参数输入呢?而不是作为一个固定值。另外,需要你在 gammagl/tests/datasets
路径下写一个测试文件,内容可以参考该路径下的其他文件怎么写的
examples/dhn/dhn_trainer.py
Outdated
type2idx = { | ||
'M': 0, | ||
'A': 1, | ||
# 'C': 2, | ||
# 'T': 3 | ||
} | ||
|
||
NODE_TYPE = len(type2idx) | ||
K_HOP = 2 | ||
|
||
NUM_FEA = (K_HOP + 2) * 4 + NODE_TYPE | ||
NUM_NEIGHBOR = 5 | ||
BATCH_SIZE=32 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这些值可以放到 parser
中进行设置,作为一个超参数
@@ -0,0 +1,324 @@ | |||
import argparse |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
需要写一个 readme.md
文件,内容参考 examples/
路径下的其他文件下下的 readme.md
的内容
Description
Checklist
Please feel free to remove inapplicable items for your PR.
or have been fixed to be compatible with this change
Changes