sqlglot.expression_core
1from __future__ import annotations 2 3import sys 4import typing as t 5from collections import deque 6from copy import deepcopy 7 8from sqlglot.helper import mypyc_attr, to_bool 9from sqlglot.tokenizer_core import Token 10 11 12EC = t.TypeVar("EC", bound="ExpressionCore") 13 14 15POSITION_META_KEYS: t.Tuple[str, ...] = ("line", "col", "start", "end") 16SQLGLOT_META: str = "sqlglot.meta" 17UNITTEST: bool = "unittest" in sys.modules or "pytest" in sys.modules 18 19 20@mypyc_attr(allow_interpreted_subclasses=True) 21class ExpressionCore: 22 __slots__ = ( 23 "args", 24 "parent", 25 "arg_key", 26 "index", 27 "comments", 28 "_type", 29 "_meta", 30 "_hash", 31 ) 32 33 key: t.ClassVar[str] 34 arg_types: t.ClassVar[t.Dict[str, bool]] = {} 35 required_args: t.ClassVar[t.Set[str]] = set() 36 is_var_len_args: t.ClassVar[bool] = False 37 is_func: t.ClassVar[bool] = False 38 _hash_raw_args: t.ClassVar[bool] = False 39 40 def __init__(self, **args: object) -> None: 41 self.args: t.Dict[str, t.Any] = args 42 self.parent: t.Optional[ExpressionCore] = None 43 self.arg_key: t.Optional[str] = None 44 self.index: t.Optional[int] = None 45 self.comments: t.Optional[t.List[str]] = None 46 self._type: t.Optional[ExpressionCore] = None 47 self._meta: t.Optional[t.Dict[str, t.Any]] = None 48 self._hash: t.Optional[int] = None 49 50 for arg_key, value in self.args.items(): 51 self._set_parent(arg_key, value) 52 53 def _set_parent(self, arg_key: str, value: object, index: t.Optional[int] = None) -> None: 54 if isinstance(value, ExpressionCore): 55 value.parent = self 56 value.arg_key = arg_key 57 value.index = index 58 elif isinstance(value, list): 59 for i, v in enumerate(value): 60 if isinstance(v, ExpressionCore): 61 v.parent = self 62 v.arg_key = arg_key 63 v.index = i 64 65 def iter_expressions(self: EC, reverse: bool = False) -> t.Iterator[EC]: 66 for vs in reversed(self.args.values()) if reverse else self.args.values(): 67 if isinstance(vs, list): 68 for v in reversed(vs) if reverse 
else vs: 69 if isinstance(v, ExpressionCore): 70 yield t.cast(EC, v) 71 elif isinstance(vs, ExpressionCore): 72 yield t.cast(EC, vs) 73 74 def bfs(self: EC, prune: t.Optional[t.Callable[[EC], bool]] = None) -> t.Iterator[EC]: 75 queue: t.Deque[EC] = deque() 76 queue.append(self) 77 while queue: 78 node = queue.popleft() 79 yield node 80 if prune and prune(node): 81 continue 82 for v in node.iter_expressions(): 83 queue.append(v) 84 85 def dfs(self: EC, prune: t.Optional[t.Callable[[EC], bool]] = None) -> t.Iterator[EC]: 86 stack: t.List[EC] = [self] 87 while stack: 88 node = stack.pop() 89 yield node 90 if prune and prune(node): 91 continue 92 for v in node.iter_expressions(reverse=True): 93 stack.append(v) 94 95 @property 96 def meta(self) -> t.Dict[str, t.Any]: 97 if self._meta is None: 98 self._meta = {} 99 return self._meta 100 101 @property 102 def this(self) -> t.Any: 103 return self.args.get("this") 104 105 @property 106 def expression(self) -> t.Any: 107 return self.args.get("expression") 108 109 @property 110 def expressions(self) -> t.List[t.Any]: 111 return self.args.get("expressions") or [] 112 113 def pop_comments(self) -> t.List[str]: 114 comments = self.comments or [] 115 self.comments = None 116 return comments 117 118 def append(self, arg_key: str, value: t.Any) -> None: 119 if type(self.args.get(arg_key)) is not list: 120 self.args[arg_key] = [] 121 self._set_parent(arg_key, value) 122 values = self.args[arg_key] 123 if hasattr(value, "parent"): 124 value.index = len(values) 125 values.append(value) 126 127 @property 128 def depth(self) -> int: 129 if self.parent: 130 return self.parent.depth + 1 131 return 0 132 133 def find_ancestor(self, *expression_types: t.Type[EC]) -> t.Optional[EC]: 134 ancestor = self.parent 135 while ancestor and not isinstance(ancestor, expression_types): 136 ancestor = ancestor.parent 137 return ancestor # type: ignore[return-value] 138 139 @property 140 def same_parent(self) -> bool: 141 return type(self.parent) is 
self.__class__ 142 143 def root(self) -> ExpressionCore: 144 expression = self 145 while expression.parent: 146 expression = expression.parent 147 return expression 148 149 def __eq__(self, other: object) -> bool: 150 return self is other or (type(self) is type(other) and hash(self) == hash(other)) 151 152 def __hash__(self) -> int: 153 if self._hash is None: 154 nodes: t.List[ExpressionCore] = [] 155 queue: t.Deque[ExpressionCore] = deque() 156 queue.append(self) 157 158 while queue: 159 node = queue.popleft() 160 nodes.append(node) 161 162 for child in node.iter_expressions(): 163 if child._hash is None: 164 queue.append(child) 165 166 for node in reversed(nodes): 167 hash_ = hash(node.key) 168 169 if node._hash_raw_args: 170 for k, v in sorted(node.args.items()): 171 if v: 172 hash_ = hash((hash_, k, v)) 173 else: 174 for k, v in sorted(node.args.items()): 175 vt = type(v) 176 177 if vt is list: 178 for x in v: 179 if x is not None and x is not False: 180 hash_ = hash((hash_, k, x.lower() if type(x) is str else x)) 181 else: 182 hash_ = hash((hash_, k)) 183 elif v is not None and v is not False: 184 hash_ = hash((hash_, k, v.lower() if vt is str else v)) 185 186 node._hash = hash_ 187 assert self._hash 188 return self._hash 189 190 def error_messages(self, args: t.Optional[t.Sequence] = None) -> t.List[str]: 191 errors: t.List[str] = [] 192 193 if UNITTEST: 194 for k in self.args: 195 if k not in self.arg_types: 196 raise TypeError(f"Unexpected keyword: '{k}' for {self.__class__}") 197 198 for k in self.required_args: 199 v = self.args.get(k) 200 if v is None or (isinstance(v, list) and not v): 201 errors.append(f"Required keyword: '{k}' missing for {self.__class__}") 202 203 if args and self.is_func and len(args) > len(self.arg_types) and not self.is_var_len_args: 204 errors.append( 205 f"The number of provided arguments ({len(args)}) is greater than " 206 f"the maximum number of supported arguments ({len(self.arg_types)})" 207 ) 208 209 return errors 210 211 
def update_positions( 212 self: EC, 213 other: t.Optional[ExpressionCore | Token] = None, 214 line: t.Optional[int] = None, 215 col: t.Optional[int] = None, 216 start: t.Optional[int] = None, 217 end: t.Optional[int] = None, 218 ) -> EC: 219 if other is None: 220 self.meta["line"] = line 221 self.meta["col"] = col 222 self.meta["start"] = start 223 self.meta["end"] = end 224 elif isinstance(other, ExpressionCore): 225 for k in POSITION_META_KEYS: 226 if k in other.meta: 227 self.meta[k] = other.meta[k] 228 else: 229 # Token: has .line, .col, .start, .end attributes 230 self.meta["line"] = other.line 231 self.meta["col"] = other.col 232 self.meta["start"] = other.start 233 self.meta["end"] = other.end 234 return self 235 236 def to_py(self) -> t.Any: 237 raise ValueError(f"{self} cannot be converted to a Python object.") 238 239 def text(self, key: str) -> str: 240 field = self.args.get(key) 241 if isinstance(field, str): 242 return field 243 return "" 244 245 @property 246 def name(self) -> str: 247 return self.text("this") 248 249 @property 250 def alias(self) -> str: 251 alias = self.args.get("alias") 252 if isinstance(alias, ExpressionCore): 253 return alias.name 254 return self.text("alias") 255 256 @property 257 def alias_column_names(self) -> t.List[str]: 258 table_alias = self.args.get("alias") 259 if not table_alias: 260 return [] 261 return [c.name for c in table_alias.args.get("columns") or []] 262 263 @property 264 def alias_or_name(self) -> str: 265 return self.alias or self.name 266 267 @property 268 def output_name(self) -> str: 269 return "" 270 271 def is_leaf(self) -> bool: 272 return not any( 273 (isinstance(v, ExpressionCore) or type(v) is list) and v for v in self.args.values() 274 ) 275 276 def __deepcopy__(self, memo: t.Any) -> ExpressionCore: 277 root = self.__class__() 278 stack: t.List[t.Tuple[ExpressionCore, ExpressionCore]] = [(self, root)] 279 280 while stack: 281 node, copy = stack.pop() 282 283 if node.comments is not None: 284 
copy.comments = deepcopy(node.comments) 285 if node._type is not None: 286 copy._type = deepcopy(node._type) 287 if node._meta is not None: 288 copy._meta = deepcopy(node._meta) 289 if node._hash is not None: 290 copy._hash = node._hash 291 292 for k, vs in node.args.items(): 293 if isinstance(vs, ExpressionCore): 294 stack.append((vs, vs.__class__())) 295 copy.set(k, stack[-1][-1]) 296 elif type(vs) is list: 297 copy.args[k] = [] 298 299 for v in vs: 300 if isinstance(v, ExpressionCore): 301 stack.append((v, v.__class__())) 302 copy.append(k, stack[-1][-1]) 303 else: 304 copy.append(k, v) 305 else: 306 copy.args[k] = vs 307 308 return root 309 310 def copy(self: EC) -> EC: 311 return deepcopy(self) 312 313 def add_comments(self, comments: t.Optional[t.List[str]] = None, prepend: bool = False) -> None: 314 if self.comments is None: 315 self.comments = [] 316 317 if comments: 318 for comment in comments: 319 _, *meta = comment.split(SQLGLOT_META) 320 if meta: 321 for kv in "".join(meta).split(","): 322 k, *v = kv.split("=") 323 self.meta[k.strip()] = to_bool(v[0].strip() if v else True) 324 325 if not prepend: 326 self.comments.append(comment) 327 328 if prepend: 329 self.comments = comments + self.comments 330 331 def set( 332 self, 333 arg_key: str, 334 value: object, 335 index: t.Optional[int] = None, 336 overwrite: bool = True, 337 ) -> None: 338 node: t.Optional[ExpressionCore] = self 339 340 while node and node._hash is not None: 341 node._hash = None 342 node = node.parent 343 344 if index is not None: 345 expressions = self.args.get(arg_key) or [] 346 347 try: 348 if expressions[index] is None: 349 return 350 except IndexError: 351 return 352 353 if value is None: 354 expressions.pop(index) 355 for v in expressions[index:]: 356 v.index = v.index - 1 357 return 358 359 if isinstance(value, list): 360 expressions.pop(index) 361 expressions[index:index] = value 362 elif overwrite: 363 expressions[index] = value 364 else: 365 expressions.insert(index, value) 366 
367 value = expressions 368 elif value is None: 369 self.args.pop(arg_key, None) 370 return 371 372 self.args[arg_key] = value 373 self._set_parent(arg_key, value, index) 374 375 def find(self, *expression_types: t.Type[EC], bfs: bool = True) -> t.Optional[EC]: 376 return next(self.find_all(*expression_types, bfs=bfs), None) 377 378 def find_all(self, *expression_types: t.Type[EC], bfs: bool = True) -> t.Iterator[EC]: 379 for expression in self.walk(bfs=bfs): 380 if isinstance(expression, expression_types): 381 yield expression 382 383 def walk( 384 self: EC, 385 bfs: bool = True, 386 prune: t.Optional[t.Callable[[EC], bool]] = None, 387 ) -> t.Iterator[EC]: 388 if bfs: 389 yield from self.bfs(prune=prune) 390 else: 391 yield from self.dfs(prune=prune) 392 393 def replace(self, expression: t.Any) -> t.Any: 394 parent = self.parent 395 396 if not parent or parent is expression: 397 return expression 398 399 key = self.arg_key 400 if not key: 401 return expression 402 403 value = parent.args.get(key) 404 405 if type(expression) is list and isinstance(value, ExpressionCore): 406 if value.parent: 407 value.parent.replace(expression) 408 else: 409 parent.set(key, expression, self.index) 410 411 if expression is not self: 412 self.parent = None 413 self.arg_key = None 414 self.index = None 415 416 return expression 417 418 def pop(self: EC) -> EC: 419 self.replace(None) 420 return self 421 422 def assert_is(self, type_: t.Type[EC]) -> EC: 423 if not isinstance(self, type_): 424 raise AssertionError(f"{self} is not {type_}.") 425 return self 426 427 def transform( 428 self, fun: t.Callable, *args: object, copy: bool = True, **kwargs: object 429 ) -> t.Any: 430 root: t.Any = None 431 new_node: t.Any = None 432 433 for node in (self.copy() if copy else self).dfs(prune=lambda n: n is not new_node): 434 parent, arg_key, index = node.parent, node.arg_key, node.index 435 new_node = fun(node, *args, **kwargs) 436 437 if not root: 438 root = new_node 439 elif parent and arg_key 
and new_node is not node: 440 parent.set(arg_key, new_node, index) 441 442 assert root 443 return root
POSITION_META_KEYS: Tuple[str, ...] =
('line', 'col', 'start', 'end')
SQLGLOT_META: str =
'sqlglot.meta'
UNITTEST: bool =
True
@mypyc_attr(allow_interpreted_subclasses=True)
class
ExpressionCore:
@mypyc_attr(allow_interpreted_subclasses=True)
class ExpressionCore:
    """Core node of the expression tree.

    Children live in ``args``: each value is a child node, a list (which may mix
    nodes and plain values), or a plain value. A child records its position in
    its parent via ``parent``, ``arg_key`` and, for list members, ``index``.

    The structural hash is cached in ``_hash``. ``set`` and ``append`` clear it on
    the mutated node and on every ancestor still holding one, so equality (which
    compares hashes) keeps up with in-place edits made through those methods.
    """

    __slots__ = (
        "args",
        "parent",
        "arg_key",
        "index",
        "comments",
        "_type",
        "_meta",
        "_hash",
    )

    # ``key`` has no default here but ``__hash__`` reads it, so concrete node
    # classes must define it.
    key: t.ClassVar[str]
    arg_types: t.ClassVar[t.Dict[str, bool]] = {}
    required_args: t.ClassVar[t.Set[str]] = set()
    is_var_len_args: t.ClassVar[bool] = False
    is_func: t.ClassVar[bool] = False
    # When True, ``__hash__`` hashes truthy arg values verbatim instead of
    # expanding lists and lower-casing strings.
    _hash_raw_args: t.ClassVar[bool] = False

    def __init__(self, **args: object) -> None:
        """Create a node from keyword children and adopt every child node among them."""
        self.args: t.Dict[str, t.Any] = args
        self.parent: t.Optional[ExpressionCore] = None
        self.arg_key: t.Optional[str] = None
        self.index: t.Optional[int] = None
        self.comments: t.Optional[t.List[str]] = None
        self._type: t.Optional[ExpressionCore] = None
        self._meta: t.Optional[t.Dict[str, t.Any]] = None
        self._hash: t.Optional[int] = None

        for arg_key, value in self.args.items():
            self._set_parent(arg_key, value)

    def _set_parent(self, arg_key: str, value: object, index: t.Optional[int] = None) -> None:
        """Make this node the parent of ``value`` (or of every node inside a list ``value``).

        A single node receives ``index`` as given; list members receive their list
        position. Non-node values are ignored.
        """
        if isinstance(value, ExpressionCore):
            value.parent = self
            value.arg_key = arg_key
            value.index = index
        elif isinstance(value, list):
            for i, v in enumerate(value):
                if isinstance(v, ExpressionCore):
                    v.parent = self
                    v.arg_key = arg_key
                    v.index = i

    def _invalidate_hash(self) -> None:
        """Drop the cached hash of this node and of each ancestor that still caches one.

        The walk stops at the first node without a cached hash.
        """
        node: t.Optional[ExpressionCore] = self
        while node is not None and node._hash is not None:
            node._hash = None
            node = node.parent

    def iter_expressions(self: EC, reverse: bool = False) -> t.Iterator[EC]:
        """Yield the direct child nodes, including nodes inside list-valued args.

        Args:
            reverse: visit args, and the members of each list, back to front.
        """
        for vs in reversed(self.args.values()) if reverse else self.args.values():
            if isinstance(vs, list):
                for v in reversed(vs) if reverse else vs:
                    if isinstance(v, ExpressionCore):
                        yield t.cast(EC, v)
            elif isinstance(vs, ExpressionCore):
                yield t.cast(EC, vs)

    def bfs(self: EC, prune: t.Optional[t.Callable[[EC], bool]] = None) -> t.Iterator[EC]:
        """Yield this subtree breadth-first, starting with ``self``.

        Args:
            prune: when it returns True for a node, that node is still yielded but
                its children are not visited.
        """
        queue: t.Deque[EC] = deque()
        queue.append(self)
        while queue:
            node = queue.popleft()
            yield node
            if prune and prune(node):
                continue
            for v in node.iter_expressions():
                queue.append(v)

    def dfs(self: EC, prune: t.Optional[t.Callable[[EC], bool]] = None) -> t.Iterator[EC]:
        """Yield this subtree depth-first in pre-order, starting with ``self``.

        Args:
            prune: when it returns True for a node, that node is still yielded but
                its children are not visited.
        """
        stack: t.List[EC] = [self]
        while stack:
            node = stack.pop()
            yield node
            if prune and prune(node):
                continue
            # Push children in reverse so the first child is popped (visited) first.
            for v in node.iter_expressions(reverse=True):
                stack.append(v)

    @property
    def meta(self) -> t.Dict[str, t.Any]:
        """Per-node metadata dict, created on first access."""
        if self._meta is None:
            self._meta = {}
        return self._meta

    @property
    def this(self) -> t.Any:
        """Shortcut for ``args.get("this")``."""
        return self.args.get("this")

    @property
    def expression(self) -> t.Any:
        """Shortcut for ``args.get("expression")``."""
        return self.args.get("expression")

    @property
    def expressions(self) -> t.List[t.Any]:
        """``args["expressions"]``; when unset or empty, a new empty list not stored on the node."""
        return self.args.get("expressions") or []

    def pop_comments(self) -> t.List[str]:
        """Return this node's comments (``[]`` if none) and clear them."""
        comments = self.comments or []
        self.comments = None
        return comments

    def append(self, arg_key: str, value: t.Any) -> None:
        """Append ``value`` to the list under ``arg_key`` and adopt it.

        A missing or non-list value under ``arg_key`` is replaced by a new list first.
        """
        # The node changes shape, so cached hashes up the parent chain are stale.
        self._invalidate_hash()
        if type(self.args.get(arg_key)) is not list:
            self.args[arg_key] = []
        self._set_parent(arg_key, value)
        values = self.args[arg_key]
        if hasattr(value, "parent"):
            value.index = len(values)
        values.append(value)

    @property
    def depth(self) -> int:
        """Number of ancestors above this node (0 for a root)."""
        # Iterative, like the other traversals, so deep trees cannot hit the recursion limit.
        depth = 0
        node = self.parent
        while node:
            depth += 1
            node = node.parent
        return depth

    def find_ancestor(self, *expression_types: t.Type[EC]) -> t.Optional[EC]:
        """Return the nearest ancestor that is an instance of any of ``expression_types``, else None."""
        ancestor = self.parent
        while ancestor and not isinstance(ancestor, expression_types):
            ancestor = ancestor.parent
        return ancestor  # type: ignore[return-value]

    @property
    def same_parent(self) -> bool:
        """True when the parent's class is exactly this node's class."""
        return type(self.parent) is self.__class__

    def root(self) -> ExpressionCore:
        """Return the top-most ancestor (``self`` when the node has no parent)."""
        expression = self
        while expression.parent:
            expression = expression.parent
        return expression

    def __eq__(self, other: object) -> bool:
        """Identity, or same exact type and equal structural hash."""
        return self is other or (type(self) is type(other) and hash(self) == hash(other))

    def __hash__(self) -> int:
        """Return the structural hash of this subtree, computing and caching it on demand.

        Descendants without a cached hash are collected breadth-first and hashed in
        reverse order, so children are hashed before their parents without recursion.
        Unless the class sets ``_hash_raw_args``, string values (also inside lists)
        are lower-cased and ``None``/``False`` values are skipped.
        """
        if self._hash is None:
            nodes: t.List[ExpressionCore] = []
            queue: t.Deque[ExpressionCore] = deque()
            queue.append(self)

            while queue:
                node = queue.popleft()
                nodes.append(node)

                for child in node.iter_expressions():
                    if child._hash is None:
                        queue.append(child)

            for node in reversed(nodes):
                hash_ = hash(node.key)

                if node._hash_raw_args:
                    for k, v in sorted(node.args.items()):
                        if v:
                            hash_ = hash((hash_, k, v))
                else:
                    for k, v in sorted(node.args.items()):
                        vt = type(v)

                        if vt is list:
                            for x in v:
                                if x is not None and x is not False:
                                    hash_ = hash((hash_, k, x.lower() if type(x) is str else x))
                                else:
                                    hash_ = hash((hash_, k))
                        elif v is not None and v is not False:
                            hash_ = hash((hash_, k, v.lower() if vt is str else v))

                node._hash = hash_

        # Compare against None rather than truthiness: 0 is a legitimate hash value.
        assert self._hash is not None
        return self._hash

    def error_messages(self, args: t.Optional[t.Sequence] = None) -> t.List[str]:
        """Return validation messages for this node's arguments (empty when valid).

        Args:
            args: optional argument sequence; only its length is inspected, and only
                for function nodes (``is_func``) without variable-length args.

        Raises:
            TypeError: when ``UNITTEST`` is set and ``args`` of the node contain a key
                not declared in ``arg_types``.
        """
        errors: t.List[str] = []

        if UNITTEST:
            for k in self.args:
                if k not in self.arg_types:
                    raise TypeError(f"Unexpected keyword: '{k}' for {self.__class__}")

        for k in self.required_args:
            v = self.args.get(k)
            if v is None or (isinstance(v, list) and not v):
                errors.append(f"Required keyword: '{k}' missing for {self.__class__}")

        if args and self.is_func and len(args) > len(self.arg_types) and not self.is_var_len_args:
            errors.append(
                f"The number of provided arguments ({len(args)}) is greater than "
                f"the maximum number of supported arguments ({len(self.arg_types)})"
            )

        return errors

    def update_positions(
        self: EC,
        other: t.Optional[ExpressionCore | Token] = None,
        line: t.Optional[int] = None,
        col: t.Optional[int] = None,
        start: t.Optional[int] = None,
        end: t.Optional[int] = None,
    ) -> EC:
        """Copy source-position metadata onto this node and return it.

        With ``other`` None, the explicit ``line``/``col``/``start``/``end`` values are
        stored (None included). With a node, whichever ``POSITION_META_KEYS`` it carries
        are copied. Anything else is treated as a token and its ``line``/``col``/
        ``start``/``end`` attributes are copied.
        """
        if other is None:
            self.meta["line"] = line
            self.meta["col"] = col
            self.meta["start"] = start
            self.meta["end"] = end
        elif isinstance(other, ExpressionCore):
            # Read ``_meta`` directly: the ``meta`` property would allocate an empty
            # dict on ``other`` as a side effect when it carries no metadata.
            source_meta = other._meta
            if source_meta:
                for k in POSITION_META_KEYS:
                    if k in source_meta:
                        self.meta[k] = source_meta[k]
        else:
            # Token: has .line, .col, .start, .end attributes
            self.meta["line"] = other.line
            self.meta["col"] = other.col
            self.meta["start"] = other.start
            self.meta["end"] = other.end
        return self

    def to_py(self) -> t.Any:
        """Convert to a Python value; the base class always raises ValueError."""
        raise ValueError(f"{self} cannot be converted to a Python object.")

    def text(self, key: str) -> str:
        """Return ``args[key]`` when it is a string, otherwise ``""``."""
        field = self.args.get(key)
        if isinstance(field, str):
            return field
        return ""

    @property
    def name(self) -> str:
        """The ``this`` arg when it is a string, otherwise ``""``."""
        return self.text("this")

    @property
    def alias(self) -> str:
        """Name of the alias node, or the alias string, or ``""``."""
        alias = self.args.get("alias")
        if isinstance(alias, ExpressionCore):
            return alias.name
        return self.text("alias")

    @property
    def alias_column_names(self) -> t.List[str]:
        """Names of the ``columns`` on the alias node; empty unless the alias is a node."""
        table_alias = self.args.get("alias")
        # A plain-string alias has no ``args`` (see ``alias``), so it carries no columns.
        if not isinstance(table_alias, ExpressionCore):
            return []
        return [c.name for c in table_alias.args.get("columns") or []]

    @property
    def alias_or_name(self) -> str:
        """``alias`` if non-empty, otherwise ``name``."""
        return self.alias or self.name

    @property
    def output_name(self) -> str:
        """Output column name; the base class has none and returns ``""``."""
        return ""

    def is_leaf(self) -> bool:
        """True when no arg holds a child node or a non-empty list."""
        return not any(
            (isinstance(v, ExpressionCore) or type(v) is list) and v for v in self.args.values()
        )

    def __deepcopy__(self, memo: t.Any) -> ExpressionCore:
        """Copy the subtree iteratively (no recursion), including comments, type and meta."""
        root = self.__class__()
        stack: t.List[t.Tuple[ExpressionCore, ExpressionCore]] = [(self, root)]

        while stack:
            node, clone = stack.pop()

            if node.comments is not None:
                clone.comments = deepcopy(node.comments)
            if node._type is not None:
                clone._type = deepcopy(node._type)
            if node._meta is not None:
                clone._meta = deepcopy(node._meta)

            for k, vs in node.args.items():
                if isinstance(vs, ExpressionCore):
                    child = vs.__class__()
                    stack.append((vs, child))
                    clone.set(k, child)
                elif type(vs) is list:
                    clone.args[k] = []

                    for v in vs:
                        if isinstance(v, ExpressionCore):
                            child = v.__class__()
                            stack.append((v, child))
                            clone.append(k, child)
                        else:
                            clone.append(k, v)
                else:
                    clone.args[k] = vs

            # Copy the cached hash only after the args are rebuilt: set()/append()
            # clear it, and the clone is structurally identical to ``node``.
            if node._hash is not None:
                clone._hash = node._hash

        return root

    def copy(self: EC) -> EC:
        """Return a deep copy of this subtree."""
        return deepcopy(self)

    def add_comments(self, comments: t.Optional[t.List[str]] = None, prepend: bool = False) -> None:
        """Attach ``comments``, harvesting ``sqlglot.meta`` settings into ``meta``.

        Text after the ``SQLGLOT_META`` marker is read as comma-separated ``key=value``
        (or bare ``key``, meaning True) pairs whose values are passed through ``to_bool``.

        Args:
            comments: comments to add; ``None``/empty only ensures ``comments`` is a list.
            prepend: put the new comments before the existing ones.
        """
        if self.comments is None:
            self.comments = []

        if not comments:
            # Nothing to add; this also avoids ``None + list`` when prepending.
            return

        for comment in comments:
            _, *meta = comment.split(SQLGLOT_META)
            if meta:
                for kv in "".join(meta).split(","):
                    k, *v = kv.split("=")
                    self.meta[k.strip()] = to_bool(v[0].strip() if v else True)

            if not prepend:
                self.comments.append(comment)

        if prepend:
            self.comments = comments + self.comments

    def set(
        self,
        arg_key: str,
        value: object,
        index: t.Optional[int] = None,
        overwrite: bool = True,
    ) -> None:
        """Set, replace, insert or remove a child argument.

        Args:
            arg_key: name of the argument in ``args``.
            value: new value. ``None`` removes the argument (or the list item at
                ``index``); a list given with ``index`` is spliced in place of that item.
            index: position inside the list stored under ``arg_key``; negative values
                count from the end. If the position does not exist or holds ``None``,
                ``args`` is left unchanged.
            overwrite: with ``index``, replace the item (True) or insert before it (False).
        """
        self._invalidate_hash()

        if index is not None:
            expressions = self.args.get(arg_key) or []

            try:
                if expressions[index] is None:
                    return
            except IndexError:
                return

            if index < 0:
                # Normalise so the pop/splice/renumbering below work on absolute positions.
                index += len(expressions)

            if value is None:
                expressions.pop(index)
                # Items after the removed slot shifted left; renumber the nodes
                # (plain values in the list carry no ``index`` attribute).
                for i in range(index, len(expressions)):
                    item = expressions[i]
                    if isinstance(item, ExpressionCore):
                        item.index = i
                return

            if isinstance(value, list):
                expressions.pop(index)
                expressions[index:index] = value
            elif overwrite:
                expressions[index] = value
            else:
                expressions.insert(index, value)

            value = expressions
        elif value is None:
            self.args.pop(arg_key, None)
            return

        self.args[arg_key] = value
        self._set_parent(arg_key, value, index)

    def find(self, *expression_types: t.Type[EC], bfs: bool = True) -> t.Optional[EC]:
        """Return the first node in the subtree matching ``expression_types``, or None."""
        return next(self.find_all(*expression_types, bfs=bfs), None)

    def find_all(self, *expression_types: t.Type[EC], bfs: bool = True) -> t.Iterator[EC]:
        """Yield every node in the subtree that is an instance of ``expression_types``."""
        for expression in self.walk(bfs=bfs):
            if isinstance(expression, expression_types):
                yield expression

    def walk(
        self: EC,
        bfs: bool = True,
        prune: t.Optional[t.Callable[[EC], bool]] = None,
    ) -> t.Iterator[EC]:
        """Traverse the subtree breadth-first (default) or depth-first."""
        if bfs:
            yield from self.bfs(prune=prune)
        else:
            yield from self.dfs(prune=prune)

    def replace(self, expression: t.Any) -> t.Any:
        """Put ``expression`` where this node sits in its parent and return ``expression``.

        Nothing happens for a node without parent or ``arg_key``. When ``expression``
        is a list but the slot holds a single node, the node owning that slot is itself
        replaced by the list. This node is then detached unless it is ``expression``.
        """
        parent = self.parent

        if not parent or parent is expression:
            return expression

        key = self.arg_key
        if not key:
            return expression

        value = parent.args.get(key)

        if type(expression) is list and isinstance(value, ExpressionCore):
            if value.parent:
                value.parent.replace(expression)
        else:
            parent.set(key, expression, self.index)

        if expression is not self:
            self.parent = None
            self.arg_key = None
            self.index = None

        return expression

    def pop(self: EC) -> EC:
        """Detach this node from its parent and return it."""
        self.replace(None)
        return self

    def assert_is(self, type_: t.Type[EC]) -> EC:
        """Return ``self``, raising AssertionError if it is not an instance of ``type_``."""
        if not isinstance(self, type_):
            raise AssertionError(f"{self} is not {type_}.")
        return self

    def transform(
        self, fun: t.Callable, *args: object, copy: bool = True, **kwargs: object
    ) -> t.Any:
        """Apply ``fun`` to every node in pre-order and return the resulting root.

        When ``fun`` returns a different object for a node, it is set into the parent
        in that node's place and the original node's children are not visited.
        With ``copy`` (default) the tree is deep-copied first.
        """
        root: t.Any = None
        new_node: t.Any = None

        for node in (self.copy() if copy else self).dfs(prune=lambda n: n is not new_node):
            parent, arg_key, index = node.parent, node.arg_key, node.index
            new_node = fun(node, *args, **kwargs)

            if not root:
                root = new_node
            elif parent and arg_key and new_node is not node:
                parent.set(arg_key, new_node, index)

        assert root
        return root
ExpressionCore(**args: object)
def __init__(self, **args: object) -> None:
    """Initialise an expression node from its keyword children.

    Every ``ExpressionCore`` among the values, whether passed directly or inside
    a list, is attached to this node through ``_set_parent``.
    """
    # Caches and auxiliary data start out empty.
    self._hash: t.Optional[int] = None
    self._meta: t.Optional[t.Dict[str, t.Any]] = None
    self._type: t.Optional[ExpressionCore] = None
    self.comments: t.Optional[t.List[str]] = None

    # Location inside a parent; stays unset until a parent adopts this node.
    self.index: t.Optional[int] = None
    self.arg_key: t.Optional[str] = None
    self.parent: t.Optional[ExpressionCore] = None

    self.args: t.Dict[str, t.Any] = args
    for name, child in args.items():
        self._set_parent(name, child)
parent: Optional[ExpressionCore]
def
iter_expressions(self: ~EC, reverse: bool = False) -> Iterator[~EC]:
def iter_expressions(self: EC, reverse: bool = False) -> t.Iterator[EC]:
    """Yield the direct child nodes stored in ``self.args``.

    Values that are nodes are yielded as-is; list values contribute each member
    that is a node; anything else is skipped. With ``reverse=True`` both the args
    and the members of every list are visited back to front.
    """
    arg_values = self.args.values()
    ordered_values = reversed(arg_values) if reverse else arg_values
    for arg_value in ordered_values:
        if not isinstance(arg_value, list):
            if isinstance(arg_value, ExpressionCore):
                yield t.cast(EC, arg_value)
            continue
        members = reversed(arg_value) if reverse else arg_value
        for member in members:
            if isinstance(member, ExpressionCore):
                yield t.cast(EC, member)
def
error_messages(self, args: Optional[Sequence] = None) -> List[str]:
def error_messages(self, args: t.Optional[t.Sequence] = None) -> t.List[str]:
    """Collect validation problems for this node's arguments.

    Args:
        args: optional argument sequence; only its length is inspected, and only for
            function nodes (``is_func``) that do not accept variable-length args.

    Returns:
        One message per missing required argument, plus one if ``args`` is longer than
        ``arg_types`` allows; an empty list when nothing is wrong.

    Raises:
        TypeError: when ``UNITTEST`` is set and ``self.args`` holds an undeclared key.
    """
    if UNITTEST:
        for name in self.args:
            if name not in self.arg_types:
                raise TypeError(f"Unexpected keyword: '{name}' for {self.__class__}")

    problems: t.List[str] = []
    for name in self.required_args:
        supplied = self.args.get(name)
        is_empty_list = isinstance(supplied, list) and not supplied
        if supplied is None or is_empty_list:
            problems.append(f"Required keyword: '{name}' missing for {self.__class__}")

    limit = len(self.arg_types)
    if args and self.is_func and not self.is_var_len_args and len(args) > limit:
        problems.append(
            f"The number of provided arguments ({len(args)}) is greater than "
            f"the maximum number of supported arguments ({limit})"
        )

    return problems
def
update_positions( self: ~EC, other: Union[ExpressionCore, sqlglot.tokenizer_core.Token, NoneType] = None, line: Optional[int] = None, col: Optional[int] = None, start: Optional[int] = None, end: Optional[int] = None) -> ~EC:
def update_positions(
    self: EC,
    other: t.Optional[ExpressionCore | Token] = None,
    line: t.Optional[int] = None,
    col: t.Optional[int] = None,
    start: t.Optional[int] = None,
    end: t.Optional[int] = None,
) -> EC:
    """Copy source-position metadata onto this node and return it.

    Args:
        other: the position source. ``None`` stores the explicit ``line``/``col``/
            ``start``/``end`` values (None included); a node contributes whichever
            ``POSITION_META_KEYS`` its metadata carries; anything else is treated as
            a token and its ``line``/``col``/``start``/``end`` attributes are copied.
        line, col, start, end: explicit positions, used only when ``other`` is None.
    """
    if other is None:
        self.meta["line"] = line
        self.meta["col"] = col
        self.meta["start"] = start
        self.meta["end"] = end
    elif isinstance(other, ExpressionCore):
        # Read ``_meta`` directly: the ``meta`` property would allocate an empty
        # dict on ``other`` as a side effect when it carries no metadata.
        source_meta = other._meta
        if source_meta:
            for k in POSITION_META_KEYS:
                if k in source_meta:
                    self.meta[k] = source_meta[k]
    else:
        # Token: has .line, .col, .start, .end attributes
        self.meta["line"] = other.line
        self.meta["col"] = other.col
        self.meta["start"] = other.start
        self.meta["end"] = other.end
    return self
def
add_comments( self, comments: Optional[List[str]] = None, prepend: bool = False) -> None:
def add_comments(self, comments: t.Optional[t.List[str]] = None, prepend: bool = False) -> None:
    """Attach ``comments`` to this node, harvesting ``sqlglot.meta`` settings into ``meta``.

    Text after the ``SQLGLOT_META`` marker is read as comma-separated ``key=value``
    (or bare ``key``, meaning True) pairs whose values are passed through ``to_bool``.

    Args:
        comments: comments to add; ``None`` or empty only ensures ``self.comments`` is a list.
        prepend: put the new comments before the existing ones instead of after them.
    """
    if self.comments is None:
        self.comments = []

    if not comments:
        # Nothing to add; this also avoids ``None + list`` when ``prepend`` is True.
        return

    for comment in comments:
        _, *meta = comment.split(SQLGLOT_META)
        if meta:
            for kv in "".join(meta).split(","):
                k, *v = kv.split("=")
                self.meta[k.strip()] = to_bool(v[0].strip() if v else True)

        if not prepend:
            self.comments.append(comment)

    if prepend:
        self.comments = comments + self.comments
def
set( self, arg_key: str, value: object, index: Optional[int] = None, overwrite: bool = True) -> None:
def set(
    self,
    arg_key: str,
    value: object,
    index: t.Optional[int] = None,
    overwrite: bool = True,
) -> None:
    """Set, replace, insert or remove a child argument.

    Args:
        arg_key: name of the argument in ``self.args``.
        value: new value. ``None`` removes the argument (or the list item at ``index``);
            a list given together with ``index`` is spliced in place of that item.
        index: position inside the list stored under ``arg_key``; negative values count
            from the end. If the position does not exist or holds ``None``, ``args`` is
            left unchanged.
        overwrite: with ``index``, replace the item (True) or insert before it (False).
    """
    # The node is about to change: clear cached hashes up the parent chain.
    node: t.Optional[ExpressionCore] = self
    while node is not None and node._hash is not None:
        node._hash = None
        node = node.parent

    if index is not None:
        expressions = self.args.get(arg_key) or []

        try:
            if expressions[index] is None:
                return
        except IndexError:
            return

        if index < 0:
            # Normalise so the pop/splice/renumbering below work on absolute positions.
            index += len(expressions)

        if value is None:
            expressions.pop(index)
            # Items after the removed slot shifted left; renumber the nodes
            # (plain values in the list carry no ``index`` attribute).
            for i in range(index, len(expressions)):
                item = expressions[i]
                if isinstance(item, ExpressionCore):
                    item.index = i
            return

        if isinstance(value, list):
            expressions.pop(index)
            expressions[index:index] = value
        elif overwrite:
            expressions[index] = value
        else:
            expressions.insert(index, value)

        value = expressions
    elif value is None:
        self.args.pop(arg_key, None)
        return

    self.args[arg_key] = value
    self._set_parent(arg_key, value, index)
def
walk( self: ~EC, bfs: bool = True, prune: Optional[Callable[[~EC], bool]] = None) -> Iterator[~EC]:
def
replace(self, expression: Any) -> Any:
def replace(self, expression: t.Any) -> t.Any:
    """Swap this node out of its parent for ``expression`` and return ``expression``.

    A node lacking a parent or an ``arg_key`` is left untouched. When ``expression``
    is a list while the slot holds a single node, the node owning that slot is replaced
    by the list instead. Afterwards this node is detached from the tree, unless it
    is ``expression`` itself.
    """
    parent = self.parent
    slot = self.arg_key
    if not parent or parent is expression or not slot:
        return expression

    current = parent.args.get(slot)
    list_into_single_slot = type(expression) is list and isinstance(current, ExpressionCore)

    if not list_into_single_slot:
        parent.set(slot, expression, self.index)
    elif current.parent:
        # A list cannot occupy a single-node slot, so its owner is replaced instead.
        current.parent.replace(expression)

    if expression is not self:
        self.parent = None
        self.arg_key = None
        self.index = None

    return expression
def
transform( self, fun: Callable, *args: object, copy: bool = True, **kwargs: object) -> Any:
def transform(
    self, fun: t.Callable, *args: object, copy: bool = True, **kwargs: object
) -> t.Any:
    """Rewrite the tree by calling ``fun`` on each node in pre-order; return the new root.

    Args:
        fun: called as ``fun(node, *args, **kwargs)``. A result other than ``node`` is
            set into the parent in its place, and the old node's children are skipped.
        copy: when True (default), work on a deep copy instead of mutating ``self``.
    """
    target = self.copy() if copy else self
    root: t.Any = None
    replacement: t.Any = None

    def _skip_children(candidate: t.Any) -> bool:
        # Only descend into nodes that ``fun`` handed back unchanged.
        return candidate is not replacement

    for node in target.dfs(prune=_skip_children):
        parent, arg_key, index = node.parent, node.arg_key, node.index
        replacement = fun(node, *args, **kwargs)

        if not root:
            root = replacement
            continue

        if parent and arg_key and replacement is not node:
            parent.set(arg_key, replacement, index)

    assert root
    return root