forked from datafold/data-diff
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbound_exprs.py
More file actions
98 lines (66 loc) · 2.42 KB
/
bound_exprs.py
File metadata and controls
98 lines (66 loc) · 2.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
"""Expressions bound to a specific database"""
import inspect
from functools import wraps
from typing import Union, TYPE_CHECKING
from runtype import dataclass
from typing_extensions import Self
from .abcs import AbstractDatabase, AbstractCompiler
from .queries.ast_classes import ExprNode, ITable, TablePath, Compilable
from .queries.api import table
from .schema import create_schema
@dataclass
class BoundNode(ExprNode):
    """An expression node paired with the database it is bound to.

    Attributes that are not found on the ``BoundNode`` itself are looked up
    on the wrapped ``node``. Bound methods obtained that way are wrapped so
    that whatever they return is bound to the same database again.
    """

    # Database used by ``query`` and checked by ``compile``.
    database: AbstractDatabase
    # Wrapped expression; target of the attribute delegation below.
    node: Compilable

    def __getattr__(self, attr):
        """Delegate a failed attribute lookup to the wrapped node.

        Returns the node's attribute as-is, unless it is a bound method, in
        which case a wrapper is returned that calls the method and wraps the
        result in a ``BoundNode`` on the same database.

        Raises:
            AttributeError: if ``attr`` is ``"node"`` (the field this method
                relies on is itself missing) or the node lacks ``attr``.
        """
        # __getattr__ only runs after normal lookup failed. If ``node`` is
        # the missing name (e.g. copy/pickle probing a freshly allocated,
        # not yet populated instance), reading ``self.node`` below would
        # re-enter this method until RecursionError. Fail cleanly instead.
        if attr == "node":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            )

        value = getattr(self.node, attr)
        if not inspect.ismethod(value):
            return value

        @wraps(value)
        def bound_method(*args, **kw):
            # Keep the result attached to the same database.
            return BoundNode(self.database, value(*args, **kw))

        return bound_method

    def query(self, res_type=list):
        """Pass the wrapped node to ``self.database.query`` and return its result.

        ``res_type`` is forwarded unchanged as the ``res_type`` keyword.
        """
        return self.database.query(self.node, res_type=res_type)

    @property
    def type(self):
        """The ``type`` attribute of the wrapped node."""
        return self.node.type

    def compile(self, c: AbstractCompiler) -> str:
        """Compile the wrapped node with compiler ``c``.

        The compiler's database must be the very object this node is bound
        to; this is checked with an ``assert``.
        """
        assert c.database is self.database
        return self.node.compile(c)
def bind_node(node, database):
    """Wrap ``node`` in a ``BoundNode`` attached to ``database``."""
    return BoundNode(database=database, node=node)


# Expose the helper as a method on every expression node: ``expr.bind(database)``.
ExprNode.bind = bind_node
@dataclass
class BoundTable(BoundNode):  # ITable
    """A ``BoundNode`` whose wrapped node is a ``TablePath``."""

    database: AbstractDatabase
    node: TablePath

    def with_schema(self, schema) -> Self:
        """Return ``self.replace(...)`` with the node's schema set to ``schema``."""
        return self.replace(node=self.node.replace(schema=schema))

    def query_schema(self, *, columns=None, where=None, case_sensitive=True) -> Self:
        """Return a table with a schema, querying the database if needed.

        When the wrapped table path already has a truthy schema, ``self`` is
        returned as-is. Otherwise the raw schema is queried from the
        database, processed with ``columns`` and ``where``, turned into a
        schema object via ``create_schema`` (honouring ``case_sensitive``),
        and attached through ``with_schema``.
        """
        path_node = self.node
        if not path_node.schema:
            db = self.database
            fetched = db.query_table_schema(path_node.path)
            processed = db._process_table_schema(path_node.path, fetched, columns, where)
            built = create_schema(db, path_node, processed, case_sensitive)
            return self.with_schema(built)

        # Schema already present: nothing to fetch.
        return self

    @property
    def schema(self):
        """The ``schema`` attribute of the wrapped table path."""
        return self.node.schema
def bound_table(database: AbstractDatabase, table_path: Union[TablePath, str, tuple], **kw):
    """Create a ``BoundTable`` on ``database`` for ``table_path``.

    ``table_path`` and any extra keyword arguments are forwarded to
    ``table`` to build the node that gets wrapped.
    """
    table_node = table(table_path, **kw)
    return BoundTable(database, table_node)
# Database.table = bound_table
# def test():
# from . import connect
# from .queries.api import table
# d = connect("mysql://erez:qweqwe123@localhost/erez")
# t = table(('Rating',))
# b = BoundTable(d, t)
# b2 = b.with_schema()
# breakpoint()
# test()
if TYPE_CHECKING:
    # Only static type checkers evaluate this branch; it never runs.
    class BoundTable(BoundTable, TablePath):
        """Typing-only declaration: ``BoundTable`` that also subclasses ``TablePath``."""