-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Expand file tree
/
Copy pathgraph_builder.py
More file actions
181 lines (149 loc) · 6.66 KB
/
graph_builder.py
File metadata and controls
181 lines (149 loc) · 6.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
"""
GraphBuilder Module
"""
from langchain_classic.chains import create_extraction_chain
from langchain_community.chat_models import ErnieBotChat
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from ..helpers import graph_schema, nodes_metadata
class GraphBuilder:
    """
    Build a web scraping graph configuration from a natural-language prompt.

    The builder instantiates a chat model from the ``"llm"`` section of the
    configuration, renders a description of every node listed in
    ``nodes_metadata`` and wires both into an extraction chain constrained by
    ``graph_schema``.

    Attributes:
        prompt (str): The user's natural language prompt for the scraping task.
        config (dict): The configuration dictionary passed to the constructor.
        llm: The chat model instance returned by ``_create_llm``
            (ChatOpenAI, ChatGoogleGenerativeAI or ErnieBotChat).
        nodes_description (str): A string description of all available nodes
            and their arguments.
        chain: The extraction chain responsible for processing the prompt
            and creating the graph.

    Methods:
        build_graph(): Executes the graph creation process based on the user
            prompt and returns the graph configuration.
        convert_json_to_graphviz(json_data): Converts a JSON graph
            configuration to a Graphviz object for visualization.

    Args:
        prompt (str): The user's natural language prompt describing the
            desired scraping operation.
        config (dict): Builder configuration. Its ``"llm"`` entry is a dict of
            language model parameters in which ``'api_key'`` and ``'model'``
            are mandatory; ``'temperature'`` (default 0) and ``'streaming'``
            (default True) can optionally be overridden.

    Raises:
        KeyError: If ``config`` has no ``"llm"`` entry.
        ValueError: If ``'api_key'`` or ``'model'`` is missing from the LLM
            configuration, or if the model is not supported.
    """

    def __init__(self, prompt: str, config: dict):
        """
        Initializes the GraphBuilder with a user prompt and a configuration
        whose ``"llm"`` entry describes the language model to use.
        """
        self.prompt = prompt
        self.config = config
        self.llm = self._create_llm(config["llm"])
        self.nodes_description = self._generate_nodes_description()
        self.chain = self._create_extraction_chain()

    def _create_llm(self, llm_config: dict):
        """
        Creates the chat model selected by the ``'model'`` entry of llm_config.

        The model class is chosen by substring match on the model name:
        ``"gpt-"`` selects ChatOpenAI, ``"gemini"`` selects
        ChatGoogleGenerativeAI (imported lazily) and ``"ernie"`` selects
        ErnieBotChat.

        Args:
            llm_config (dict): Language model parameters. They are merged over
                the defaults ``temperature=0`` and ``streaming=True``.

        Returns:
            The instantiated chat model.

        Raises:
            ValueError: If 'api_key' or 'model' is not provided in llm_config,
                or if the model name matches none of the supported families.
            ImportError: If a Gemini model is requested but
                langchain_google_genai is not installed.
        """
        llm_defaults = {"temperature": 0, "streaming": True}
        llm_params = {**llm_defaults, **llm_config}

        if "api_key" not in llm_params:
            raise ValueError("LLM configuration must include an 'api_key'.")

        # Fail with a clear message instead of a bare KeyError/TypeError when
        # the model name is missing or is not a string.
        model = llm_params.get("model")
        if not isinstance(model, str):
            raise ValueError("LLM configuration must include a 'model'.")

        # The parameters are unpacked into keyword arguments: the chat model
        # constructors accept named fields, not one positional dictionary.
        if "gpt-" in model:
            return ChatOpenAI(**llm_params)

        if "gemini" in model:
            try:
                from langchain_google_genai import ChatGoogleGenerativeAI
            except ImportError as err:
                raise ImportError(
                    "langchain_google_genai is not installed. "
                    "Please install it using 'pip install langchain-google-genai'."
                ) from err
            return ChatGoogleGenerativeAI(**llm_params)

        if "ernie" in model:
            return ErnieBotChat(**llm_params)

        raise ValueError("Model not supported")

    def _generate_nodes_description(self) -> str:
        """
        Generates a string description of all available nodes and their arguments.

        Returns:
            str: One entry per node in ``nodes_metadata``, joined by newlines.
            Each entry lists the node's description, type and argument names.
        """
        return "\n".join(
            f"- {node}: {data['description']} (Type: {data['type']},\n"
            f"Args: {', '.join(data['args'].keys())})"
            for node, data in nodes_metadata.items()
        )

    def _create_extraction_chain(self):
        """
        Creates an extraction chain for processing the user prompt and
        generating the graph configuration.

        Returns:
            The chain produced by ``create_extraction_chain`` for
            ``graph_schema``, this builder's LLM and the graph-design prompt.
        """
        # The rendered text is parsed again by ChatPromptTemplate, so literal
        # braces coming from the node metadata are escaped; otherwise they
        # would be interpreted as template variables.
        safe_description = self.nodes_description.replace("{", "{{").replace(
            "}", "}}"
        )

        create_graph_prompt_template = (
            "\n"
            "You are an AI that designs direct graphs for web scraping tasks.\n"
            "Your goal is to create a web scraping pipeline that is efficient "
            "and tailored to the user's requirements.\n"
            "You have access to a set of default nodes, each with specific "
            "capabilities:\n"
            "{nodes_description}\n"
            "Based on the user's input: \"{input}\", identify the essential "
            "nodes required for the task and suggest a graph configuration "
            "that outlines the flow between the chosen nodes.\n"
        ).format(
            # "{input}" is re-inserted verbatim so it stays a template
            # variable for the chat prompt.
            nodes_description=safe_description,
            input="{input}",
        )

        extraction_prompt = ChatPromptTemplate.from_template(
            create_graph_prompt_template
        )
        return create_extraction_chain(
            prompt=extraction_prompt, schema=graph_schema, llm=self.llm
        )

    def build_graph(self):
        """
        Executes the graph creation process based on the user prompt and
        returns the graph configuration.

        Returns:
            The result of invoking the extraction chain with the prompt.
            NOTE(review): ``convert_json_to_graphviz`` reads ``["text"][0]``
            from its input, so this is presumably a dict with a "text" list;
            confirm against the chain's output format.
        """
        return self.chain.invoke(self.prompt)

    @staticmethod
    def convert_json_to_graphviz(json_data, format: str = "pdf"):
        """
        Converts a JSON graph configuration to a Graphviz object for visualization.

        Args:
            json_data (dict): A mapping whose ``"text"`` entry is a sequence;
                its first element is the graph configuration, with optional
                keys ``"nodes"`` (dicts carrying ``"node_name"``), ``"edges"``
                (dicts carrying ``"from"`` and ``"to"``, where ``"to"`` is a
                single name or a list of names) and ``"entry_point"``.
            format (str): Output format handed to ``graphviz.Digraph``.
                Defaults to "pdf".

        Returns:
            graphviz.Digraph: A Graphviz object representing the graph
            configuration. The entry point node is drawn as a double circle.

        Raises:
            ImportError: If the 'graphviz' library is not installed.
        """
        try:
            import graphviz
        except ImportError as err:
            raise ImportError(
                "The 'graphviz' library is required for this functionality. "
                "Please install it from 'https://graphviz.org/download/'."
            ) from err

        graph = graphviz.Digraph(
            comment="ScrapeGraphAI Generated Graph",
            format=format,
            node_attr={"color": "lightblue2", "style": "filled"},
        )

        graph_config = json_data["text"][0]

        # Retrieve nodes, edges, and the entry point from the JSON data
        nodes = graph_config.get("nodes", [])
        edges = graph_config.get("edges", [])
        entry_point = graph_config.get("entry_point")

        for node in nodes:
            node_name = node["node_name"]
            if node_name == entry_point:
                # Highlight where the graph execution starts.
                graph.node(node_name, shape="doublecircle")
            else:
                graph.node(node_name)

        for edge in edges:
            # An edge may fan out to several targets or point to a single one.
            targets = edge["to"] if isinstance(edge["to"], list) else [edge["to"]]
            for to_node in targets:
                graph.edge(edge["from"], to_node)

        return graph