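"""
AMR.py

Auto-Merging Retrieval (AMR) with LlamaIndex: documents are parsed into a
hierarchy of nodes, only the leaf nodes are embedded and indexed, and the
AutoMergingRetriever merges retrieved leaf nodes back into their parents at
query time. The script ends by comparing the AMR query engine against a plain
vector-retriever baseline on the same query.
"""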
from typing import List

from llama_index import (
    SimpleDirectoryReader,
    VectorStoreIndex,
    ServiceContext,
)
from llama_index.node_parser import (
    HierarchicalNodeParser,
    get_leaf_nodes,
    get_root_nodes,
)
from llama_index.llms import Ollama
from llama_index.storage.docstore import SimpleDocumentStore
from llama_index.storage import StorageContext
from llama_index.retrievers.auto_merging_retriever import AutoMergingRetriever
from llama_index.embeddings import HuggingFaceEmbedding
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.schema import BaseNode, NodeRelationship
# Initialize variables
documents_dir = "data/statements_txt_files"
llm_model_name = "llama2"
llm_temp = 0
llm_response_max_tokens = 1024
embed_model_name = "sentence-transformers/all-MiniLM-L6-v2"
top_k = 10
AMR_chunk_sizes = [2048, 512, 128]
AMR_chunk_overlap = 20
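# AMR_chunk_sizes defines the node hierarchy: 2048-token root chunks are split
# into 512-token intermediate chunks, which are split into 128-token leaf chunks.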
# LLM
llm_model = Ollama(
    model=llm_model_name, temperature=llm_temp, max_tokens=llm_response_max_tokens
)
# Embedding
embed_model = HuggingFaceEmbedding(model_name=embed_model_name)
# Read the documents from the directory
reader = SimpleDirectoryReader(input_dir=documents_dir, filename_as_id=True)
docs = reader.load_data()
# Create nodes
node_parser = HierarchicalNodeParser.from_defaults(
    chunk_sizes=AMR_chunk_sizes,
    chunk_overlap=AMR_chunk_overlap,
    include_metadata=True,
)
nodes = node_parser.get_nodes_from_documents(docs)
print(f"Total nodes: {len(nodes)}")
# Helper for fetching "intermediate" nodes within a node list, i.e. nodes that have
# both parent and child relationships. LlamaIndex only provides helpers for leaf and
# root nodes (get_leaf_nodes / get_root_nodes), hence this function.
def get_intermediate_nodes(nodes: List[BaseNode]) -> List[BaseNode]:
    """Get intermediate nodes (nodes with both a parent and children)."""
    intermediate_nodes = []
    for node in nodes:
        if (
            NodeRelationship.PARENT in node.relationships
            and NodeRelationship.CHILD in node.relationships
        ):
            intermediate_nodes.append(node)
    return intermediate_nodes
# Get leaf, intermediate and root nodes
leaf_nodes = get_leaf_nodes(nodes)
intermediate_nodes = get_intermediate_nodes(nodes)
root_nodes = get_root_nodes(nodes)
print(f"Leaf nodes: {len(leaf_nodes)}")
print(f"Intermediate nodes: {len(intermediate_nodes)}")
print(f"Root nodes: {len(root_nodes)}")
# Add metadata to root nodes
for node in root_nodes:
    title = node.metadata["file_name"]
    title = title.replace(".txt", "")
    title = title.replace("_", " ")
    node.metadata["title"] = title
# from llama_index.extractors import KeywordExtractor
# extractor = KeywordExtractor(llm=llm_model, keywords=2)
# metadata_dicts = extractor.extract(nodes[:2])
# Let leaf and intermediate nodes inherit metadata from their parent nodes
for node in intermediate_nodes:
    parent_id = node.parent_node.node_id
    matching_parent_node = [n for n in root_nodes if n.node_id == parent_id][0]
    node.metadata["title"] = matching_parent_node.metadata["title"]
for node in leaf_nodes:
    parent_id = node.parent_node.node_id
    matching_parent_node = [n for n in intermediate_nodes if n.node_id == parent_id][0]
    node.metadata["title"] = matching_parent_node.metadata["title"]
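# With include_metadata=True the title is, by default, part of the text that gets
# embedded, so each leaf chunk should carry the name of the statement it came from.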
# Create docstore
docstore = SimpleDocumentStore()
# Insert nodes into docstore
docstore.add_documents(nodes)
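# All levels of the hierarchy (root, intermediate and leaf nodes) go into the
# docstore so the AutoMergingRetriever can look up parent nodes when merging.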
# Define storage context (will include vector store by default too)
storage_context = StorageContext.from_defaults(docstore=docstore)
# Define service context
service_context = ServiceContext.from_defaults(
    embed_model=embed_model,
    llm=llm_model,
)
# Load the leaf nodes into a vector index (only leaves are embedded)
base_index = VectorStoreIndex(
    leaf_nodes,
    storage_context=storage_context,
    service_context=service_context,
)
# Define retriever
base_retriever = base_index.as_retriever(similarity_top_k=top_k)
retriever = AutoMergingRetriever(base_retriever, storage_context, verbose=True)
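# When enough of a parent's leaf children appear in the top-k results (roughly
# half, by default), AutoMergingRetriever replaces them with the larger parent
# chunk, handing the LLM more surrounding context.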
# Query
query_1 = "What is the revenue for Top Strike?"
query_2 = "What are the revenues for UP Fintech and Top Strike."
query = query_2
# Retrieve nodes with both retrievers (the display code below is commented out)
nodes = retriever.retrieve(query)
base_nodes = base_retriever.retrieve(query)
# len(nodes)
# len(base_nodes)
# for i, node in enumerate(nodes):
# print("NODE " + str(i) + "\n\n" + str(node.score) + "\n\n" + node.text + "\n\n")
# for i, node in enumerate(base_nodes):
# includes_Top_Strike = True #node.metadata["file_name"] == "Top_Strike.txt"
# if includes_Top_Strike:
# print("NODE " + str(i) + "\n\n" + str(node.score) + "\n\n" + node.metadata["file_name"] + "\n\n" + node.text + "\n\n")
# Query engine
query_engine = RetrieverQueryEngine.from_args(retriever, service_context=service_context)
base_query_engine = RetrieverQueryEngine.from_args(base_retriever, service_context=service_context)
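# Both engines share the same service context (LLM + embeddings), so any
# difference between the answers below comes from the retrieval step alone.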
response = query_engine.query(query)
base_response = base_query_engine.query(query)
print("AMR:" + "\n\n" + str(response) + "\n\n")
print("BASE:" + "\n\n" + str(base_response))