def get_mini_batch(sampler, root_nodes, ts, num_hops, extra_neg_samples):  # neg_samples is not used
    """
    Call fetch_subgraph() for every root node.
    Return: the subgraph of each node.
    """
    all_graphs = []
    # root_nodes holds (extra_neg_samples + 2) equal-sized groups, so this is the batch size
    train_ptr = len(root_nodes) // (extra_neg_samples + 2)
    for i, z in enumerate(zip(root_nodes, ts)):
        if i == train_ptr:
            # proposed fix: reset the sampler once the src nodes are done, before the dst nodes
            sampler.reset()
        root_node, root_time = z
        all_graphs.append(fetch_subgraph(sampler, root_node, root_time, num_hops))
    return all_graphs
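Hello, and thank you for your excellent work!
While looking at how neighbors are sampled for root_nodes, I found that the sampling of dst nodes may be incorrect. Below is a summary of the problem and a proposed fix:
At each sampling step, root_nodes is passed to the sampler with batch_size * (2 + sample_num) entries, and the sampler returns the historical neighbors of each root_node. However, it appears that the first batch_size entries of root_nodes are always the src nodes of the batch, and entries batch_size through 2*batch_size are always the dst nodes.
As a result, while the sampler iterates over the first batch_size root_nodes (the src nodes), its internal pointers advance with time. When the dst nodes are then sampled, the last `neighbor` historical neighbors behind those pointers most likely no longer satisfy the sampling requirement, so no neighbors are sampled at all.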
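To make the effect concrete, here is a minimal sketch of how a forward-only pointer could produce empty samples. It assumes a hypothetical `ToySampler` with one pointer per node and a strict "earlier than the query time" requirement. This is an illustration only, not the repository's sampler, and the node names and timestamps are made up.

```python
class ToySampler:
    """Hypothetical recent-neighbor sampler with forward-only pointers."""

    def __init__(self, events, k):
        # events: node -> list of (time, neighbor), sorted by time (assumed layout)
        self.events = events
        self.k = k
        self.ptr = {node: 0 for node in events}

    def reset(self):
        # Move every pointer back to the start of its event list.
        for node in self.ptr:
            self.ptr[node] = 0

    def sample(self, node, t):
        ev = self.events.get(node, [])
        p = self.ptr.get(node, 0)
        # The pointer only ever moves forward in time.
        while p < len(ev) and ev[p][0] < t:
            p += 1
        self.ptr[node] = p
        # Take the k events just behind the pointer, then enforce time < t.
        window = ev[max(0, p - self.k):p]
        return [nbr for (time, nbr) in window if time < t]


events = {"A": [(1, "u"), (2, "v"), (3, "w"), (4, "x"), (5, "y")]}
sampler = ToySampler(events, k=2)

late = sampler.sample("A", 6)   # "A" appears as a src node late in the batch
stale = sampler.sample("A", 2)  # "A" appears as a dst node with an earlier time

sampler.reset()
fresh = sampler.sample("A", 2)  # same dst query after resetting the pointers

print(late, stale, fresh)
```

In this toy setting the dst query without a reset only sees events behind a pointer that has already moved past its timestamp, so all of them are filtered out. After `reset()` the same query finds the earlier neighbor.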
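I would like to ask whether sampler.reset() should be called once before the dst nodes are sampled. The example code is the get_mini_batch function in construct_graph.py.
Looking forward to your reply~ @CongWeilin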