-
-
Notifications
You must be signed in to change notification settings - Fork 2k
Expand file tree
/
Copy pathconditional_node.py
More file actions
112 lines (89 loc) · 3.8 KB
/
conditional_node.py
File metadata and controls
112 lines (89 loc) · 3.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""
Module for implementing the conditional node
"""
from typing import List, Optional
from simpleeval import EvalWithCompoundTypes, simple_eval
from .base_node import BaseNode
class ConditionalNode(BaseNode):
    """
    A node that chooses the next step of the graph's execution flow.

    By default the decision is based on the value stored under ``key_name`` in
    the graph's state: the "true" branch is taken when that value is present
    and is neither ``None`` nor the empty string ``""``; otherwise the "false"
    branch is taken. Other falsy values such as ``0``, ``False``, ``[]`` or
    ``{}`` are treated as present and select the "true" branch.

    If ``node_config`` also supplies a ``condition`` string, that expression is
    evaluated against the state instead (see ``_evaluate_condition``) and its
    truthiness selects the branch.

    This node is meant to have exactly two outgoing edges. The target node
    names live in ``true_node_name`` and ``false_node_name``; both are ``None``
    after construction and must be assigned externally before ``execute`` is
    called.

    Attributes:
        key_name (str): The state key inspected when no ``condition`` is given.
        true_node_name (str | None): Node returned when the check succeeds.
        false_node_name (str | None): Node returned when the check fails.
        condition (str | None): Optional expression evaluated against the state.
        eval_instance (EvalWithCompoundTypes): Holds the functions and operators
            made available to ``condition`` expressions.

    Args:
        input (str): Forwarded unchanged to ``BaseNode``.
        output (List[str]): Forwarded unchanged to ``BaseNode``.
        node_config (dict, optional): Must contain ``key_name``; may contain
            ``condition``.
        node_name (str, optional): Name of the node. Defaults to "Cond".
    """

    def __init__(
        self,
        input: str,
        output: List[str],
        node_config: Optional[dict] = None,
        node_name: str = "Cond",
    ):
        """
        Create a conditional node from its configuration.

        Raises:
            NotImplementedError: If the node config is missing (indexing it
                raises ``TypeError``) or does not contain ``key_name``.
        """
        super().__init__(node_name, "conditional_node", input, output, 2, node_config)

        try:
            self.key_name = self.node_config["key_name"]
        except (KeyError, TypeError) as e:
            raise NotImplementedError(
                "You need to provide key_name inside the node config"
            ) from e

        # Branch targets are assigned externally; execute() refuses to run
        # while true_node_name is still None.
        self.true_node_name = None
        self.false_node_name = None

        self.condition = self.node_config.get("condition", None)

        # ``len`` is the only function callable from condition expressions.
        self.eval_instance = EvalWithCompoundTypes()
        self.eval_instance.functions = {"len": len}

    def execute(self, state: dict) -> Optional[str]:
        """
        Decide which of the two successor nodes should run next.

        Args:
            state (dict): The current state of the graph.

        Returns:
            Optional[str]: ``true_node_name`` when the check succeeds, otherwise
            ``false_node_name`` (which is ``None`` if it was never assigned).

        Raises:
            ValueError: If ``true_node_name`` has not been set, or if the
                configured ``condition`` cannot be evaluated.
        """
        if self.true_node_name is None:
            raise ValueError("ConditionalNode's next nodes are not set properly.")

        if self.condition:
            condition_result = self._evaluate_condition(state, self.condition)
        else:
            value = state.get(self.key_name)
            # Only a missing key, None, or "" count as empty here; other falsy
            # values (0, False, [], {}) still select the true branch.
            condition_result = value is not None and value != ""

        return self.true_node_name if condition_result else self.false_node_name

    def _evaluate_condition(self, state: dict, condition: str) -> bool:
        """
        Safely evaluate ``condition`` against the graph state.

        Every state key is exposed as a variable name (a state key shadows a
        function of the same name in the name table), ``len`` is callable, and
        compound literals such as lists, tuples, dicts and sets are supported.

        Args:
            state (dict): The current state of the graph.
            condition (str): The condition expression to evaluate.

        Returns:
            bool: The truthiness of the evaluated expression.

        Raises:
            ValueError: If the expression cannot be parsed or evaluated; the
                underlying error is chained as ``__cause__``.
        """
        # Name table: allowed functions first, then the state so state keys win.
        eval_names = self.eval_instance.functions.copy()
        eval_names.update(state)

        # Evaluate with EvalWithCompoundTypes rather than the module-level
        # ``simple_eval`` helper (which builds a plain SimpleEval and rejects
        # list/dict/tuple/set literals such as ``status in ["ok", "done"]``).
        # A fresh evaluator per call keeps ``eval_instance`` free of per-call state.
        evaluator = EvalWithCompoundTypes(
            operators=self.eval_instance.operators,
            functions=self.eval_instance.functions,
            names=eval_names,
        )

        try:
            return bool(evaluator.eval(condition))
        except Exception as e:
            raise ValueError(
                f"Error evaluating condition '{condition}' in {self.node_name}: {e}"
            ) from e