Edit on GitHub

sqlglot.expression_core

  1from __future__ import annotations
  2
  3import sys
  4import typing as t
  5from collections import deque
  6from copy import deepcopy
  7
  8from sqlglot.helper import mypyc_attr, to_bool
  9from sqlglot.tokenizer_core import Token
 10
 11
# Type variable bound to ExpressionCore; lets methods annotate "the caller's own node type".
EC = t.TypeVar("EC", bound="ExpressionCore")


# Keys under which update_positions stores source-position values in a node's meta dict.
POSITION_META_KEYS: t.Tuple[str, ...] = ("line", "col", "start", "end")
# Marker that add_comments searches for inside a comment to extract key=value metadata.
SQLGLOT_META: str = "sqlglot.meta"
# True when unittest or pytest is already imported; turns on the strict keyword check
# in error_messages.
UNITTEST: bool = "unittest" in sys.modules or "pytest" in sys.modules
 18
 19
 20@mypyc_attr(allow_interpreted_subclasses=True)
 21class ExpressionCore:
 22    __slots__ = (
 23        "args",
 24        "parent",
 25        "arg_key",
 26        "index",
 27        "comments",
 28        "_type",
 29        "_meta",
 30        "_hash",
 31    )
 32
 33    key: t.ClassVar[str]
 34    arg_types: t.ClassVar[t.Dict[str, bool]] = {}
 35    required_args: t.ClassVar[t.Set[str]] = set()
 36    is_var_len_args: t.ClassVar[bool] = False
 37    is_func: t.ClassVar[bool] = False
 38    _hash_raw_args: t.ClassVar[bool] = False
 39
 40    def __init__(self, **args: object) -> None:
 41        self.args: t.Dict[str, t.Any] = args
 42        self.parent: t.Optional[ExpressionCore] = None
 43        self.arg_key: t.Optional[str] = None
 44        self.index: t.Optional[int] = None
 45        self.comments: t.Optional[t.List[str]] = None
 46        self._type: t.Optional[ExpressionCore] = None
 47        self._meta: t.Optional[t.Dict[str, t.Any]] = None
 48        self._hash: t.Optional[int] = None
 49
 50        for arg_key, value in self.args.items():
 51            self._set_parent(arg_key, value)
 52
 53    def _set_parent(self, arg_key: str, value: object, index: t.Optional[int] = None) -> None:
 54        if isinstance(value, ExpressionCore):
 55            value.parent = self
 56            value.arg_key = arg_key
 57            value.index = index
 58        elif isinstance(value, list):
 59            for i, v in enumerate(value):
 60                if isinstance(v, ExpressionCore):
 61                    v.parent = self
 62                    v.arg_key = arg_key
 63                    v.index = i
 64
 65    def iter_expressions(self: EC, reverse: bool = False) -> t.Iterator[EC]:
 66        for vs in reversed(self.args.values()) if reverse else self.args.values():
 67            if isinstance(vs, list):
 68                for v in reversed(vs) if reverse else vs:
 69                    if isinstance(v, ExpressionCore):
 70                        yield t.cast(EC, v)
 71            elif isinstance(vs, ExpressionCore):
 72                yield t.cast(EC, vs)
 73
 74    def bfs(self: EC, prune: t.Optional[t.Callable[[EC], bool]] = None) -> t.Iterator[EC]:
 75        queue: t.Deque[EC] = deque()
 76        queue.append(self)
 77        while queue:
 78            node = queue.popleft()
 79            yield node
 80            if prune and prune(node):
 81                continue
 82            for v in node.iter_expressions():
 83                queue.append(v)
 84
 85    def dfs(self: EC, prune: t.Optional[t.Callable[[EC], bool]] = None) -> t.Iterator[EC]:
 86        stack: t.List[EC] = [self]
 87        while stack:
 88            node = stack.pop()
 89            yield node
 90            if prune and prune(node):
 91                continue
 92            for v in node.iter_expressions(reverse=True):
 93                stack.append(v)
 94
 95    @property
 96    def meta(self) -> t.Dict[str, t.Any]:
 97        if self._meta is None:
 98            self._meta = {}
 99        return self._meta
100
101    @property
102    def this(self) -> t.Any:
103        return self.args.get("this")
104
105    @property
106    def expression(self) -> t.Any:
107        return self.args.get("expression")
108
109    @property
110    def expressions(self) -> t.List[t.Any]:
111        return self.args.get("expressions") or []
112
113    def pop_comments(self) -> t.List[str]:
114        comments = self.comments or []
115        self.comments = None
116        return comments
117
118    def append(self, arg_key: str, value: t.Any) -> None:
119        if type(self.args.get(arg_key)) is not list:
120            self.args[arg_key] = []
121        self._set_parent(arg_key, value)
122        values = self.args[arg_key]
123        if hasattr(value, "parent"):
124            value.index = len(values)
125        values.append(value)
126
127    @property
128    def depth(self) -> int:
129        if self.parent:
130            return self.parent.depth + 1
131        return 0
132
133    def find_ancestor(self, *expression_types: t.Type[EC]) -> t.Optional[EC]:
134        ancestor = self.parent
135        while ancestor and not isinstance(ancestor, expression_types):
136            ancestor = ancestor.parent
137        return ancestor  # type: ignore[return-value]
138
139    @property
140    def same_parent(self) -> bool:
141        return type(self.parent) is self.__class__
142
143    def root(self) -> ExpressionCore:
144        expression = self
145        while expression.parent:
146            expression = expression.parent
147        return expression
148
149    def __eq__(self, other: object) -> bool:
150        return self is other or (type(self) is type(other) and hash(self) == hash(other))
151
152    def __hash__(self) -> int:
153        if self._hash is None:
154            nodes: t.List[ExpressionCore] = []
155            queue: t.Deque[ExpressionCore] = deque()
156            queue.append(self)
157
158            while queue:
159                node = queue.popleft()
160                nodes.append(node)
161
162                for child in node.iter_expressions():
163                    if child._hash is None:
164                        queue.append(child)
165
166            for node in reversed(nodes):
167                hash_ = hash(node.key)
168
169                if node._hash_raw_args:
170                    for k, v in sorted(node.args.items()):
171                        if v:
172                            hash_ = hash((hash_, k, v))
173                else:
174                    for k, v in sorted(node.args.items()):
175                        vt = type(v)
176
177                        if vt is list:
178                            for x in v:
179                                if x is not None and x is not False:
180                                    hash_ = hash((hash_, k, x.lower() if type(x) is str else x))
181                                else:
182                                    hash_ = hash((hash_, k))
183                        elif v is not None and v is not False:
184                            hash_ = hash((hash_, k, v.lower() if vt is str else v))
185
186                node._hash = hash_
187        assert self._hash
188        return self._hash
189
190    def error_messages(self, args: t.Optional[t.Sequence] = None) -> t.List[str]:
191        errors: t.List[str] = []
192
193        if UNITTEST:
194            for k in self.args:
195                if k not in self.arg_types:
196                    raise TypeError(f"Unexpected keyword: '{k}' for {self.__class__}")
197
198        for k in self.required_args:
199            v = self.args.get(k)
200            if v is None or (isinstance(v, list) and not v):
201                errors.append(f"Required keyword: '{k}' missing for {self.__class__}")
202
203        if args and self.is_func and len(args) > len(self.arg_types) and not self.is_var_len_args:
204            errors.append(
205                f"The number of provided arguments ({len(args)}) is greater than "
206                f"the maximum number of supported arguments ({len(self.arg_types)})"
207            )
208
209        return errors
210
211    def update_positions(
212        self: EC,
213        other: t.Optional[ExpressionCore | Token] = None,
214        line: t.Optional[int] = None,
215        col: t.Optional[int] = None,
216        start: t.Optional[int] = None,
217        end: t.Optional[int] = None,
218    ) -> EC:
219        if other is None:
220            self.meta["line"] = line
221            self.meta["col"] = col
222            self.meta["start"] = start
223            self.meta["end"] = end
224        elif isinstance(other, ExpressionCore):
225            for k in POSITION_META_KEYS:
226                if k in other.meta:
227                    self.meta[k] = other.meta[k]
228        else:
229            # Token: has .line, .col, .start, .end attributes
230            self.meta["line"] = other.line
231            self.meta["col"] = other.col
232            self.meta["start"] = other.start
233            self.meta["end"] = other.end
234        return self
235
236    def to_py(self) -> t.Any:
237        raise ValueError(f"{self} cannot be converted to a Python object.")
238
239    def text(self, key: str) -> str:
240        field = self.args.get(key)
241        if isinstance(field, str):
242            return field
243        return ""
244
245    @property
246    def name(self) -> str:
247        return self.text("this")
248
249    @property
250    def alias(self) -> str:
251        alias = self.args.get("alias")
252        if isinstance(alias, ExpressionCore):
253            return alias.name
254        return self.text("alias")
255
256    @property
257    def alias_column_names(self) -> t.List[str]:
258        table_alias = self.args.get("alias")
259        if not table_alias:
260            return []
261        return [c.name for c in table_alias.args.get("columns") or []]
262
263    @property
264    def alias_or_name(self) -> str:
265        return self.alias or self.name
266
267    @property
268    def output_name(self) -> str:
269        return ""
270
271    def is_leaf(self) -> bool:
272        return not any(
273            (isinstance(v, ExpressionCore) or type(v) is list) and v for v in self.args.values()
274        )
275
276    def __deepcopy__(self, memo: t.Any) -> ExpressionCore:
277        root = self.__class__()
278        stack: t.List[t.Tuple[ExpressionCore, ExpressionCore]] = [(self, root)]
279
280        while stack:
281            node, copy = stack.pop()
282
283            if node.comments is not None:
284                copy.comments = deepcopy(node.comments)
285            if node._type is not None:
286                copy._type = deepcopy(node._type)
287            if node._meta is not None:
288                copy._meta = deepcopy(node._meta)
289            if node._hash is not None:
290                copy._hash = node._hash
291
292            for k, vs in node.args.items():
293                if isinstance(vs, ExpressionCore):
294                    stack.append((vs, vs.__class__()))
295                    copy.set(k, stack[-1][-1])
296                elif type(vs) is list:
297                    copy.args[k] = []
298
299                    for v in vs:
300                        if isinstance(v, ExpressionCore):
301                            stack.append((v, v.__class__()))
302                            copy.append(k, stack[-1][-1])
303                        else:
304                            copy.append(k, v)
305                else:
306                    copy.args[k] = vs
307
308        return root
309
310    def copy(self: EC) -> EC:
311        return deepcopy(self)
312
313    def add_comments(self, comments: t.Optional[t.List[str]] = None, prepend: bool = False) -> None:
314        if self.comments is None:
315            self.comments = []
316
317        if comments:
318            for comment in comments:
319                _, *meta = comment.split(SQLGLOT_META)
320                if meta:
321                    for kv in "".join(meta).split(","):
322                        k, *v = kv.split("=")
323                        self.meta[k.strip()] = to_bool(v[0].strip() if v else True)
324
325                if not prepend:
326                    self.comments.append(comment)
327
328            if prepend:
329                self.comments = comments + self.comments
330
331    def set(
332        self,
333        arg_key: str,
334        value: object,
335        index: t.Optional[int] = None,
336        overwrite: bool = True,
337    ) -> None:
338        node: t.Optional[ExpressionCore] = self
339
340        while node and node._hash is not None:
341            node._hash = None
342            node = node.parent
343
344        if index is not None:
345            expressions = self.args.get(arg_key) or []
346
347            try:
348                if expressions[index] is None:
349                    return
350            except IndexError:
351                return
352
353            if value is None:
354                expressions.pop(index)
355                for v in expressions[index:]:
356                    v.index = v.index - 1
357                return
358
359            if isinstance(value, list):
360                expressions.pop(index)
361                expressions[index:index] = value
362            elif overwrite:
363                expressions[index] = value
364            else:
365                expressions.insert(index, value)
366
367            value = expressions
368        elif value is None:
369            self.args.pop(arg_key, None)
370            return
371
372        self.args[arg_key] = value
373        self._set_parent(arg_key, value, index)
374
375    def find(self, *expression_types: t.Type[EC], bfs: bool = True) -> t.Optional[EC]:
376        return next(self.find_all(*expression_types, bfs=bfs), None)
377
378    def find_all(self, *expression_types: t.Type[EC], bfs: bool = True) -> t.Iterator[EC]:
379        for expression in self.walk(bfs=bfs):
380            if isinstance(expression, expression_types):
381                yield expression
382
383    def walk(
384        self: EC,
385        bfs: bool = True,
386        prune: t.Optional[t.Callable[[EC], bool]] = None,
387    ) -> t.Iterator[EC]:
388        if bfs:
389            yield from self.bfs(prune=prune)
390        else:
391            yield from self.dfs(prune=prune)
392
393    def replace(self, expression: t.Any) -> t.Any:
394        parent = self.parent
395
396        if not parent or parent is expression:
397            return expression
398
399        key = self.arg_key
400        if not key:
401            return expression
402
403        value = parent.args.get(key)
404
405        if type(expression) is list and isinstance(value, ExpressionCore):
406            if value.parent:
407                value.parent.replace(expression)
408        else:
409            parent.set(key, expression, self.index)
410
411        if expression is not self:
412            self.parent = None
413            self.arg_key = None
414            self.index = None
415
416        return expression
417
418    def pop(self: EC) -> EC:
419        self.replace(None)
420        return self
421
422    def assert_is(self, type_: t.Type[EC]) -> EC:
423        if not isinstance(self, type_):
424            raise AssertionError(f"{self} is not {type_}.")
425        return self
426
427    def transform(
428        self, fun: t.Callable, *args: object, copy: bool = True, **kwargs: object
429    ) -> t.Any:
430        root: t.Any = None
431        new_node: t.Any = None
432
433        for node in (self.copy() if copy else self).dfs(prune=lambda n: n is not new_node):
434            parent, arg_key, index = node.parent, node.arg_key, node.index
435            new_node = fun(node, *args, **kwargs)
436
437            if not root:
438                root = new_node
439            elif parent and arg_key and new_node is not node:
440                parent.set(arg_key, new_node, index)
441
442        assert root
443        return root
POSITION_META_KEYS: Tuple[str, ...] = ('line', 'col', 'start', 'end')
SQLGLOT_META: str = 'sqlglot.meta'
UNITTEST: bool = True
@mypyc_attr(allow_interpreted_subclasses=True)
class ExpressionCore:
 21@mypyc_attr(allow_interpreted_subclasses=True)
 22class ExpressionCore:
 23    __slots__ = (
 24        "args",
 25        "parent",
 26        "arg_key",
 27        "index",
 28        "comments",
 29        "_type",
 30        "_meta",
 31        "_hash",
 32    )
 33
 34    key: t.ClassVar[str]
 35    arg_types: t.ClassVar[t.Dict[str, bool]] = {}
 36    required_args: t.ClassVar[t.Set[str]] = set()
 37    is_var_len_args: t.ClassVar[bool] = False
 38    is_func: t.ClassVar[bool] = False
 39    _hash_raw_args: t.ClassVar[bool] = False
 40
 41    def __init__(self, **args: object) -> None:
 42        self.args: t.Dict[str, t.Any] = args
 43        self.parent: t.Optional[ExpressionCore] = None
 44        self.arg_key: t.Optional[str] = None
 45        self.index: t.Optional[int] = None
 46        self.comments: t.Optional[t.List[str]] = None
 47        self._type: t.Optional[ExpressionCore] = None
 48        self._meta: t.Optional[t.Dict[str, t.Any]] = None
 49        self._hash: t.Optional[int] = None
 50
 51        for arg_key, value in self.args.items():
 52            self._set_parent(arg_key, value)
 53
 54    def _set_parent(self, arg_key: str, value: object, index: t.Optional[int] = None) -> None:
 55        if isinstance(value, ExpressionCore):
 56            value.parent = self
 57            value.arg_key = arg_key
 58            value.index = index
 59        elif isinstance(value, list):
 60            for i, v in enumerate(value):
 61                if isinstance(v, ExpressionCore):
 62                    v.parent = self
 63                    v.arg_key = arg_key
 64                    v.index = i
 65
 66    def iter_expressions(self: EC, reverse: bool = False) -> t.Iterator[EC]:
 67        for vs in reversed(self.args.values()) if reverse else self.args.values():
 68            if isinstance(vs, list):
 69                for v in reversed(vs) if reverse else vs:
 70                    if isinstance(v, ExpressionCore):
 71                        yield t.cast(EC, v)
 72            elif isinstance(vs, ExpressionCore):
 73                yield t.cast(EC, vs)
 74
 75    def bfs(self: EC, prune: t.Optional[t.Callable[[EC], bool]] = None) -> t.Iterator[EC]:
 76        queue: t.Deque[EC] = deque()
 77        queue.append(self)
 78        while queue:
 79            node = queue.popleft()
 80            yield node
 81            if prune and prune(node):
 82                continue
 83            for v in node.iter_expressions():
 84                queue.append(v)
 85
 86    def dfs(self: EC, prune: t.Optional[t.Callable[[EC], bool]] = None) -> t.Iterator[EC]:
 87        stack: t.List[EC] = [self]
 88        while stack:
 89            node = stack.pop()
 90            yield node
 91            if prune and prune(node):
 92                continue
 93            for v in node.iter_expressions(reverse=True):
 94                stack.append(v)
 95
 96    @property
 97    def meta(self) -> t.Dict[str, t.Any]:
 98        if self._meta is None:
 99            self._meta = {}
100        return self._meta
101
102    @property
103    def this(self) -> t.Any:
104        return self.args.get("this")
105
106    @property
107    def expression(self) -> t.Any:
108        return self.args.get("expression")
109
110    @property
111    def expressions(self) -> t.List[t.Any]:
112        return self.args.get("expressions") or []
113
114    def pop_comments(self) -> t.List[str]:
115        comments = self.comments or []
116        self.comments = None
117        return comments
118
119    def append(self, arg_key: str, value: t.Any) -> None:
120        if type(self.args.get(arg_key)) is not list:
121            self.args[arg_key] = []
122        self._set_parent(arg_key, value)
123        values = self.args[arg_key]
124        if hasattr(value, "parent"):
125            value.index = len(values)
126        values.append(value)
127
128    @property
129    def depth(self) -> int:
130        if self.parent:
131            return self.parent.depth + 1
132        return 0
133
134    def find_ancestor(self, *expression_types: t.Type[EC]) -> t.Optional[EC]:
135        ancestor = self.parent
136        while ancestor and not isinstance(ancestor, expression_types):
137            ancestor = ancestor.parent
138        return ancestor  # type: ignore[return-value]
139
140    @property
141    def same_parent(self) -> bool:
142        return type(self.parent) is self.__class__
143
144    def root(self) -> ExpressionCore:
145        expression = self
146        while expression.parent:
147            expression = expression.parent
148        return expression
149
150    def __eq__(self, other: object) -> bool:
151        return self is other or (type(self) is type(other) and hash(self) == hash(other))
152
153    def __hash__(self) -> int:
154        if self._hash is None:
155            nodes: t.List[ExpressionCore] = []
156            queue: t.Deque[ExpressionCore] = deque()
157            queue.append(self)
158
159            while queue:
160                node = queue.popleft()
161                nodes.append(node)
162
163                for child in node.iter_expressions():
164                    if child._hash is None:
165                        queue.append(child)
166
167            for node in reversed(nodes):
168                hash_ = hash(node.key)
169
170                if node._hash_raw_args:
171                    for k, v in sorted(node.args.items()):
172                        if v:
173                            hash_ = hash((hash_, k, v))
174                else:
175                    for k, v in sorted(node.args.items()):
176                        vt = type(v)
177
178                        if vt is list:
179                            for x in v:
180                                if x is not None and x is not False:
181                                    hash_ = hash((hash_, k, x.lower() if type(x) is str else x))
182                                else:
183                                    hash_ = hash((hash_, k))
184                        elif v is not None and v is not False:
185                            hash_ = hash((hash_, k, v.lower() if vt is str else v))
186
187                node._hash = hash_
188        assert self._hash
189        return self._hash
190
191    def error_messages(self, args: t.Optional[t.Sequence] = None) -> t.List[str]:
192        errors: t.List[str] = []
193
194        if UNITTEST:
195            for k in self.args:
196                if k not in self.arg_types:
197                    raise TypeError(f"Unexpected keyword: '{k}' for {self.__class__}")
198
199        for k in self.required_args:
200            v = self.args.get(k)
201            if v is None or (isinstance(v, list) and not v):
202                errors.append(f"Required keyword: '{k}' missing for {self.__class__}")
203
204        if args and self.is_func and len(args) > len(self.arg_types) and not self.is_var_len_args:
205            errors.append(
206                f"The number of provided arguments ({len(args)}) is greater than "
207                f"the maximum number of supported arguments ({len(self.arg_types)})"
208            )
209
210        return errors
211
212    def update_positions(
213        self: EC,
214        other: t.Optional[ExpressionCore | Token] = None,
215        line: t.Optional[int] = None,
216        col: t.Optional[int] = None,
217        start: t.Optional[int] = None,
218        end: t.Optional[int] = None,
219    ) -> EC:
220        if other is None:
221            self.meta["line"] = line
222            self.meta["col"] = col
223            self.meta["start"] = start
224            self.meta["end"] = end
225        elif isinstance(other, ExpressionCore):
226            for k in POSITION_META_KEYS:
227                if k in other.meta:
228                    self.meta[k] = other.meta[k]
229        else:
230            # Token: has .line, .col, .start, .end attributes
231            self.meta["line"] = other.line
232            self.meta["col"] = other.col
233            self.meta["start"] = other.start
234            self.meta["end"] = other.end
235        return self
236
237    def to_py(self) -> t.Any:
238        raise ValueError(f"{self} cannot be converted to a Python object.")
239
240    def text(self, key: str) -> str:
241        field = self.args.get(key)
242        if isinstance(field, str):
243            return field
244        return ""
245
246    @property
247    def name(self) -> str:
248        return self.text("this")
249
250    @property
251    def alias(self) -> str:
252        alias = self.args.get("alias")
253        if isinstance(alias, ExpressionCore):
254            return alias.name
255        return self.text("alias")
256
257    @property
258    def alias_column_names(self) -> t.List[str]:
259        table_alias = self.args.get("alias")
260        if not table_alias:
261            return []
262        return [c.name for c in table_alias.args.get("columns") or []]
263
264    @property
265    def alias_or_name(self) -> str:
266        return self.alias or self.name
267
268    @property
269    def output_name(self) -> str:
270        return ""
271
272    def is_leaf(self) -> bool:
273        return not any(
274            (isinstance(v, ExpressionCore) or type(v) is list) and v for v in self.args.values()
275        )
276
277    def __deepcopy__(self, memo: t.Any) -> ExpressionCore:
278        root = self.__class__()
279        stack: t.List[t.Tuple[ExpressionCore, ExpressionCore]] = [(self, root)]
280
281        while stack:
282            node, copy = stack.pop()
283
284            if node.comments is not None:
285                copy.comments = deepcopy(node.comments)
286            if node._type is not None:
287                copy._type = deepcopy(node._type)
288            if node._meta is not None:
289                copy._meta = deepcopy(node._meta)
290            if node._hash is not None:
291                copy._hash = node._hash
292
293            for k, vs in node.args.items():
294                if isinstance(vs, ExpressionCore):
295                    stack.append((vs, vs.__class__()))
296                    copy.set(k, stack[-1][-1])
297                elif type(vs) is list:
298                    copy.args[k] = []
299
300                    for v in vs:
301                        if isinstance(v, ExpressionCore):
302                            stack.append((v, v.__class__()))
303                            copy.append(k, stack[-1][-1])
304                        else:
305                            copy.append(k, v)
306                else:
307                    copy.args[k] = vs
308
309        return root
310
311    def copy(self: EC) -> EC:
312        return deepcopy(self)
313
314    def add_comments(self, comments: t.Optional[t.List[str]] = None, prepend: bool = False) -> None:
315        if self.comments is None:
316            self.comments = []
317
318        if comments:
319            for comment in comments:
320                _, *meta = comment.split(SQLGLOT_META)
321                if meta:
322                    for kv in "".join(meta).split(","):
323                        k, *v = kv.split("=")
324                        self.meta[k.strip()] = to_bool(v[0].strip() if v else True)
325
326                if not prepend:
327                    self.comments.append(comment)
328
329            if prepend:
330                self.comments = comments + self.comments
331
332    def set(
333        self,
334        arg_key: str,
335        value: object,
336        index: t.Optional[int] = None,
337        overwrite: bool = True,
338    ) -> None:
339        node: t.Optional[ExpressionCore] = self
340
341        while node and node._hash is not None:
342            node._hash = None
343            node = node.parent
344
345        if index is not None:
346            expressions = self.args.get(arg_key) or []
347
348            try:
349                if expressions[index] is None:
350                    return
351            except IndexError:
352                return
353
354            if value is None:
355                expressions.pop(index)
356                for v in expressions[index:]:
357                    v.index = v.index - 1
358                return
359
360            if isinstance(value, list):
361                expressions.pop(index)
362                expressions[index:index] = value
363            elif overwrite:
364                expressions[index] = value
365            else:
366                expressions.insert(index, value)
367
368            value = expressions
369        elif value is None:
370            self.args.pop(arg_key, None)
371            return
372
373        self.args[arg_key] = value
374        self._set_parent(arg_key, value, index)
375
376    def find(self, *expression_types: t.Type[EC], bfs: bool = True) -> t.Optional[EC]:
377        return next(self.find_all(*expression_types, bfs=bfs), None)
378
379    def find_all(self, *expression_types: t.Type[EC], bfs: bool = True) -> t.Iterator[EC]:
380        for expression in self.walk(bfs=bfs):
381            if isinstance(expression, expression_types):
382                yield expression
383
384    def walk(
385        self: EC,
386        bfs: bool = True,
387        prune: t.Optional[t.Callable[[EC], bool]] = None,
388    ) -> t.Iterator[EC]:
389        if bfs:
390            yield from self.bfs(prune=prune)
391        else:
392            yield from self.dfs(prune=prune)
393
394    def replace(self, expression: t.Any) -> t.Any:
395        parent = self.parent
396
397        if not parent or parent is expression:
398            return expression
399
400        key = self.arg_key
401        if not key:
402            return expression
403
404        value = parent.args.get(key)
405
406        if type(expression) is list and isinstance(value, ExpressionCore):
407            if value.parent:
408                value.parent.replace(expression)
409        else:
410            parent.set(key, expression, self.index)
411
412        if expression is not self:
413            self.parent = None
414            self.arg_key = None
415            self.index = None
416
417        return expression
418
419    def pop(self: EC) -> EC:
420        self.replace(None)
421        return self
422
423    def assert_is(self, type_: t.Type[EC]) -> EC:
424        if not isinstance(self, type_):
425            raise AssertionError(f"{self} is not {type_}.")
426        return self
427
428    def transform(
429        self, fun: t.Callable, *args: object, copy: bool = True, **kwargs: object
430    ) -> t.Any:
431        root: t.Any = None
432        new_node: t.Any = None
433
434        for node in (self.copy() if copy else self).dfs(prune=lambda n: n is not new_node):
435            parent, arg_key, index = node.parent, node.arg_key, node.index
436            new_node = fun(node, *args, **kwargs)
437
438            if not root:
439                root = new_node
440            elif parent and arg_key and new_node is not node:
441                parent.set(arg_key, new_node, index)
442
443        assert root
444        return root
ExpressionCore(**args: object)
41    def __init__(self, **args: object) -> None:
42        self.args: t.Dict[str, t.Any] = args
43        self.parent: t.Optional[ExpressionCore] = None
44        self.arg_key: t.Optional[str] = None
45        self.index: t.Optional[int] = None
46        self.comments: t.Optional[t.List[str]] = None
47        self._type: t.Optional[ExpressionCore] = None
48        self._meta: t.Optional[t.Dict[str, t.Any]] = None
49        self._hash: t.Optional[int] = None
50
51        for arg_key, value in self.args.items():
52            self._set_parent(arg_key, value)
key: ClassVar[str]
arg_types: ClassVar[Dict[str, bool]] = {}
required_args: ClassVar[Set[str]] = set()
is_var_len_args: ClassVar[bool] = False
is_func: ClassVar[bool] = False
args: Dict[str, Any]
parent: Optional[ExpressionCore]
arg_key: Optional[str]
index: Optional[int]
comments: Optional[List[str]]
def iter_expressions(self: ~EC, reverse: bool = False) -> Iterator[~EC]:
66    def iter_expressions(self: EC, reverse: bool = False) -> t.Iterator[EC]:
67        for vs in reversed(self.args.values()) if reverse else self.args.values():
68            if isinstance(vs, list):
69                for v in reversed(vs) if reverse else vs:
70                    if isinstance(v, ExpressionCore):
71                        yield t.cast(EC, v)
72            elif isinstance(vs, ExpressionCore):
73                yield t.cast(EC, vs)
def bfs( self: ~EC, prune: Optional[Callable[[~EC], bool]] = None) -> Iterator[~EC]:
75    def bfs(self: EC, prune: t.Optional[t.Callable[[EC], bool]] = None) -> t.Iterator[EC]:
76        queue: t.Deque[EC] = deque()
77        queue.append(self)
78        while queue:
79            node = queue.popleft()
80            yield node
81            if prune and prune(node):
82                continue
83            for v in node.iter_expressions():
84                queue.append(v)
def dfs( self: ~EC, prune: Optional[Callable[[~EC], bool]] = None) -> Iterator[~EC]:
86    def dfs(self: EC, prune: t.Optional[t.Callable[[EC], bool]] = None) -> t.Iterator[EC]:
87        stack: t.List[EC] = [self]
88        while stack:
89            node = stack.pop()
90            yield node
91            if prune and prune(node):
92                continue
93            for v in node.iter_expressions(reverse=True):
94                stack.append(v)
meta: Dict[str, Any]
 96    @property
 97    def meta(self) -> t.Dict[str, t.Any]:
 98        if self._meta is None:
 99            self._meta = {}
100        return self._meta
this: Any
102    @property
103    def this(self) -> t.Any:
104        return self.args.get("this")
expression: Any
106    @property
107    def expression(self) -> t.Any:
108        return self.args.get("expression")
expressions: List[Any]
110    @property
111    def expressions(self) -> t.List[t.Any]:
112        return self.args.get("expressions") or []
def pop_comments(self) -> List[str]:
114    def pop_comments(self) -> t.List[str]:
115        comments = self.comments or []
116        self.comments = None
117        return comments
def append(self, arg_key: str, value: Any) -> None:
119    def append(self, arg_key: str, value: t.Any) -> None:
120        if type(self.args.get(arg_key)) is not list:
121            self.args[arg_key] = []
122        self._set_parent(arg_key, value)
123        values = self.args[arg_key]
124        if hasattr(value, "parent"):
125            value.index = len(values)
126        values.append(value)
depth: int
128    @property
129    def depth(self) -> int:
130        if self.parent:
131            return self.parent.depth + 1
132        return 0
def find_ancestor(self, *expression_types: Type[~EC]) -> Optional[~EC]:
134    def find_ancestor(self, *expression_types: t.Type[EC]) -> t.Optional[EC]:
135        ancestor = self.parent
136        while ancestor and not isinstance(ancestor, expression_types):
137            ancestor = ancestor.parent
138        return ancestor  # type: ignore[return-value]
same_parent: bool
140    @property
141    def same_parent(self) -> bool:
142        return type(self.parent) is self.__class__
def root(self) -> ExpressionCore:
144    def root(self) -> ExpressionCore:
145        expression = self
146        while expression.parent:
147            expression = expression.parent
148        return expression
def error_messages(self, args: Optional[Sequence] = None) -> List[str]:
191    def error_messages(self, args: t.Optional[t.Sequence] = None) -> t.List[str]:
192        errors: t.List[str] = []
193
194        if UNITTEST:
195            for k in self.args:
196                if k not in self.arg_types:
197                    raise TypeError(f"Unexpected keyword: '{k}' for {self.__class__}")
198
199        for k in self.required_args:
200            v = self.args.get(k)
201            if v is None or (isinstance(v, list) and not v):
202                errors.append(f"Required keyword: '{k}' missing for {self.__class__}")
203
204        if args and self.is_func and len(args) > len(self.arg_types) and not self.is_var_len_args:
205            errors.append(
206                f"The number of provided arguments ({len(args)}) is greater than "
207                f"the maximum number of supported arguments ({len(self.arg_types)})"
208            )
209
210        return errors
def update_positions( self: ~EC, other: Union[ExpressionCore, sqlglot.tokenizer_core.Token, NoneType] = None, line: Optional[int] = None, col: Optional[int] = None, start: Optional[int] = None, end: Optional[int] = None) -> ~EC:
212    def update_positions(
213        self: EC,
214        other: t.Optional[ExpressionCore | Token] = None,
215        line: t.Optional[int] = None,
216        col: t.Optional[int] = None,
217        start: t.Optional[int] = None,
218        end: t.Optional[int] = None,
219    ) -> EC:
220        if other is None:
221            self.meta["line"] = line
222            self.meta["col"] = col
223            self.meta["start"] = start
224            self.meta["end"] = end
225        elif isinstance(other, ExpressionCore):
226            for k in POSITION_META_KEYS:
227                if k in other.meta:
228                    self.meta[k] = other.meta[k]
229        else:
230            # Token: has .line, .col, .start, .end attributes
231            self.meta["line"] = other.line
232            self.meta["col"] = other.col
233            self.meta["start"] = other.start
234            self.meta["end"] = other.end
235        return self
def to_py(self) -> Any:
237    def to_py(self) -> t.Any:
238        raise ValueError(f"{self} cannot be converted to a Python object.")
def text(self, key: str) -> str:
240    def text(self, key: str) -> str:
241        field = self.args.get(key)
242        if isinstance(field, str):
243            return field
244        return ""
name: str
246    @property
247    def name(self) -> str:
248        return self.text("this")
alias: str
250    @property
251    def alias(self) -> str:
252        alias = self.args.get("alias")
253        if isinstance(alias, ExpressionCore):
254            return alias.name
255        return self.text("alias")
alias_column_names: List[str]
257    @property
258    def alias_column_names(self) -> t.List[str]:
259        table_alias = self.args.get("alias")
260        if not table_alias:
261            return []
262        return [c.name for c in table_alias.args.get("columns") or []]
alias_or_name: str
264    @property
265    def alias_or_name(self) -> str:
266        return self.alias or self.name
output_name: str
268    @property
269    def output_name(self) -> str:
270        return ""
def is_leaf(self) -> bool:
272    def is_leaf(self) -> bool:
273        return not any(
274            (isinstance(v, ExpressionCore) or type(v) is list) and v for v in self.args.values()
275        )
def copy(self: ~EC) -> ~EC:
311    def copy(self: EC) -> EC:
312        return deepcopy(self)
def add_comments( self, comments: Optional[List[str]] = None, prepend: bool = False) -> None:
314    def add_comments(self, comments: t.Optional[t.List[str]] = None, prepend: bool = False) -> None:
315        if self.comments is None:
316            self.comments = []
317
318        if comments:
319            for comment in comments:
320                _, *meta = comment.split(SQLGLOT_META)
321                if meta:
322                    for kv in "".join(meta).split(","):
323                        k, *v = kv.split("=")
324                        self.meta[k.strip()] = to_bool(v[0].strip() if v else True)
325
326                if not prepend:
327                    self.comments.append(comment)
328
329            if prepend:
330                self.comments = comments + self.comments
def set( self, arg_key: str, value: object, index: Optional[int] = None, overwrite: bool = True) -> None:
    def set(
        self,
        arg_key: str,
        value: object,
        index: t.Optional[int] = None,
        overwrite: bool = True,
    ) -> None:
        """Set, replace or remove the argument stored under ``arg_key``.

        Args:
            arg_key: name of the argument to change.
            value: new value. ``None`` removes the argument (or, with ``index``, the
                list item at that position).
            index: when given, the change targets that position of the list stored
                under ``arg_key`` instead of the whole argument.
            overwrite: with ``index`` and a non-list ``value``, replace the item at
                ``index`` (default) rather than inserting before it.
        """
        node: t.Optional[ExpressionCore] = self

        # Drop cached hashes on this node and its ancestors, stopping at the first
        # node whose hash is already unset.
        while node and node._hash is not None:
            node._hash = None
            node = node.parent

        if index is not None:
            expressions = self.args.get(arg_key) or []

            # Nothing to do when the slot holds None or the index is out of range.
            try:
                if expressions[index] is None:
                    return
            except IndexError:
                return

            if value is None:
                # Remove the item and shift the stored index of every later item down.
                expressions.pop(index)
                for v in expressions[index:]:
                    v.index = v.index - 1
                return

            if isinstance(value, list):
                # Splice the items of `value` in place of the single item at `index`.
                expressions.pop(index)
                expressions[index:index] = value
            elif overwrite:
                expressions[index] = value
            else:
                expressions.insert(index, value)

            # From here on the whole (updated) list is what gets stored and re-parented.
            value = expressions
        elif value is None:
            self.args.pop(arg_key, None)
            return

        self.args[arg_key] = value
        self._set_parent(arg_key, value, index)
def find(self, *expression_types: Type[~EC], bfs: bool = True) -> Optional[~EC]:
376    def find(self, *expression_types: t.Type[EC], bfs: bool = True) -> t.Optional[EC]:
377        return next(self.find_all(*expression_types, bfs=bfs), None)
def find_all(self, *expression_types: Type[~EC], bfs: bool = True) -> Iterator[~EC]:
379    def find_all(self, *expression_types: t.Type[EC], bfs: bool = True) -> t.Iterator[EC]:
380        for expression in self.walk(bfs=bfs):
381            if isinstance(expression, expression_types):
382                yield expression
def walk( self: ~EC, bfs: bool = True, prune: Optional[Callable[[~EC], bool]] = None) -> Iterator[~EC]:
384    def walk(
385        self: EC,
386        bfs: bool = True,
387        prune: t.Optional[t.Callable[[EC], bool]] = None,
388    ) -> t.Iterator[EC]:
389        if bfs:
390            yield from self.bfs(prune=prune)
391        else:
392            yield from self.dfs(prune=prune)
def replace(self, expression: Any) -> Any:
394    def replace(self, expression: t.Any) -> t.Any:
395        parent = self.parent
396
397        if not parent or parent is expression:
398            return expression
399
400        key = self.arg_key
401        if not key:
402            return expression
403
404        value = parent.args.get(key)
405
406        if type(expression) is list and isinstance(value, ExpressionCore):
407            if value.parent:
408                value.parent.replace(expression)
409        else:
410            parent.set(key, expression, self.index)
411
412        if expression is not self:
413            self.parent = None
414            self.arg_key = None
415            self.index = None
416
417        return expression
def pop(self: ~EC) -> ~EC:
419    def pop(self: EC) -> EC:
420        self.replace(None)
421        return self
def assert_is(self, type_: Type[~EC]) -> ~EC:
423    def assert_is(self, type_: t.Type[EC]) -> EC:
424        if not isinstance(self, type_):
425            raise AssertionError(f"{self} is not {type_}.")
426        return self
def transform( self, fun: Callable, *args: object, copy: bool = True, **kwargs: object) -> Any:
    def transform(
        self, fun: t.Callable, *args: object, copy: bool = True, **kwargs: object
    ) -> t.Any:
        """Apply ``fun`` to the nodes of this tree in depth-first order and return the new root.

        Args:
            fun: called as ``fun(node, *args, **kwargs)`` for every visited node; the
                value it returns takes that node's place.
            *args: extra positional arguments forwarded to ``fun``.
            copy: when true (the default) the traversal runs over ``self.copy()``
                rather than over ``self``.
            **kwargs: extra keyword arguments forwarded to ``fun``.

        Returns:
            The first truthy value returned by ``fun`` (normally the result for the
            root node, which is visited first).

        Raises:
            AssertionError: if ``fun`` never produced a truthy root (the check is a
                plain ``assert``).
        """
        root: t.Any = None
        new_node: t.Any = None

        # `dfs` evaluates `prune` only after the loop body for the yielded node has
        # run, so `new_node` already holds that node's result at that point: the
        # children are descended into only when `fun` returned the very same object.
        for node in (self.copy() if copy else self).dfs(prune=lambda n: n is not new_node):
            # Remember where the node sits before `fun` gets a chance to touch it.
            parent, arg_key, index = node.parent, node.arg_key, node.index
            new_node = fun(node, *args, **kwargs)

            if not root:
                # No truthy root recorded yet: the current result becomes the root.
                root = new_node
            elif parent and arg_key and new_node is not node:
                # Put the replacement into the slot the original node occupied.
                parent.set(arg_key, new_node, index)

        assert root
        return root