Coverage for xerini/statement.py: 100%
96 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-20 19:54 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-20 19:54 +0000
1"""
2Module to define a SQL Statement
3"""
5from typing import Optional, Self, Set
7import sqlparse
8from pydantic import BaseModel, model_validator
9from sql_metadata import Parser, QueryType
11from xerini.enums import StatementType, TableWriteType
12from xerini.utilities import (
13 meaningful_strings_count,
14 successive_pairs,
15 following_pairs_second_item,
16)
18# pylint: disable=unused-import
19from xerini.supported_query_types import SUPPORTED_QUERY_TYPES
class Statement(BaseModel):
    """A single SQL statement plus the metadata parsed out of it.

    On construction the raw ``text`` is checked to contain exactly one
    statement, formatted with ``sqlparse`` and handed to
    ``sql_metadata.Parser``. The properties expose the query type, the
    tables involved, and which table the statement writes to or reads from.
    """

    # Raw SQL exactly as supplied by the caller.
    text: str
    # Parser built from ``formatted_text``; populated by ``set_parsed_text``.
    _parsed_text: Optional[Parser] = None
    # Query type reported by the parser; ``None`` when the parser rejects it.
    _query_type: Optional[QueryType] = None

    @model_validator(mode="before")
    @classmethod
    def only_one_statement(cls, data: dict):
        """Reject input whose ``text`` does not hold exactly one SQL statement.

        Args:
            data: Raw input passed to the model constructor.

        Returns:
            ``data`` unchanged when the check passes or cannot be applied.

        Raises:
            ValueError: If ``meaningful_strings_count`` of ``data["text"]``
                is not exactly one.
        """
        # A "before" validator sees the raw input, which need not be a dict
        # containing "text". Leave such input to pydantic's own field
        # validation instead of failing here with a KeyError/TypeError.
        if not isinstance(data, dict) or "text" not in data:
            return data
        if not 0 < meaningful_strings_count(data["text"]) < 2:
            raise ValueError(
                f"Expected a string containing a single SQL statement, but {data['text']=} !"
            )
        return data

    @property
    def formatted_text(self) -> str:
        """Return ``text`` normalised by ``sqlparse.format``.

        Keywords are upper-cased, comments stripped, and the statement is
        re-indented with tabs using comma-first style.
        """
        return sqlparse.format(
            self.text,
            reindent=True,
            keyword_case="upper",
            strip_comments=True,
            comma_first=True,
            indent_tabs=True,
        )

    @model_validator(mode="after")
    def set_parsed_text(self) -> Self:
        """Build the ``sql_metadata`` parser from the formatted statement."""
        self._parsed_text = Parser(self.formatted_text)
        return self

    @model_validator(mode="after")
    def set_query_type(self) -> Self:
        """Cache the parser's query type, or ``None`` when it raises ``ValueError``.

        Must stay defined after ``set_parsed_text`` because it reads the
        parser that validator creates.
        """
        try:
            self._query_type = self.parsed_text.query_type
        except ValueError:
            # Deliberate best effort: presumably the parser rejects query
            # types it does not support; such statements are kept with an
            # unknown (None) type rather than failing validation.
            self._query_type = None
        return self

    @property
    def query_type(self) -> QueryType | None:
        """Query type reported by the parser, or ``None`` if it was rejected."""
        return self._query_type

    @property
    def parsed_text(self) -> Parser:
        """The ``sql_metadata`` parser built from ``formatted_text``."""
        return self._parsed_text

    def __str__(self) -> str:
        """Return the raw statement text."""
        return self.text

    @property
    def statement_type(self) -> StatementType:
        """Statement classification derived from ``query_type``."""
        return StatementType.from_query_type(self.query_type)

    @property
    def write_type(self) -> TableWriteType:
        """Table write classification derived from statement and query type."""
        return TableWriteType.from_statement_query(self.statement_type, self.query_type)

    @property
    def tables(self) -> Set[Optional[str]]:
        """Tables the parser found in the statement; empty if the type is unknown."""
        if self.query_type is None:
            return set()
        return set(self.parsed_text.tables)

    @property
    def non_empty_tokens(self) -> list[sqlparse.sql.Token]:
        """The parser's non-empty tokens; empty list if the type is unknown."""
        if self.query_type is None:
            return []
        return self.parsed_text.non_empty_tokens

    @property
    def normalized_successive_token_pairs(self) -> list[tuple[str, str]]:
        """Normalized values of each successive pair of parser tokens.

        Used by ``affected_table`` to locate the table named after a keyword
        pair such as ``INSERT INTO``.
        """
        return [
            (first.normalized, second.normalized)
            for first, second in successive_pairs(self.parsed_text.tokens)
        ]

    @property
    def affected_table(self) -> str | None:
        """Lower-cased name of the table the statement writes to.

        Returns:
            The target table, or ``None`` when it cannot be determined.
        """
        target = None
        if self.query_type in {QueryType.TRUNCATE, QueryType.DROP}:
            # Take the first table in statement order; indexing a set would
            # pick an arbitrary one when several tables are named.
            target = next(iter(self.parsed_text.tables), None)
        else:
            # Build the pair list once instead of on every membership test.
            pairs = self.normalized_successive_token_pairs
            if self.query_type == QueryType.UPDATE:
                target = next(
                    (
                        tbl
                        for tbl in self.parsed_text.tables
                        if tbl and ("UPDATE", tbl.upper()) in pairs
                    ),
                    None,
                )
            else:
                # Checked in priority order; the first keyword pair present
                # in the statement decides where the table name is read from.
                keyword_pairs = (
                    ("DELETE", "FROM"),
                    ("CREATE", "TABLE"),
                    ("TEMP", "TABLE"),
                    ("TEMPORARY", "TABLE"),
                    ("INSERT", "INTO"),
                )
                for keyword_pair in keyword_pairs:
                    if keyword_pair in pairs:
                        target = following_pairs_second_item(pairs, keyword_pair)
                        break
        return target.lower() if target is not None else None

    def what_are_we_truncating_dropping(self) -> str | None:
        """Return the table targeted by a TRUNCATE or DROP statement.

        Raises:
            TypeError: If the statement is neither a TRUNCATE nor a DROP.
        """
        if self.query_type not in {QueryType.TRUNCATE, QueryType.DROP}:
            raise TypeError(f"This method is unsupported for {self.query_type}!")
        return self.affected_table

    @property
    def source_tables(self) -> Set[Optional[str]]:
        """Tables the statement reads its data from.

        TRUNCATE and DROP have no sources. DELETE and every statement type
        other than UPDATE/VIEW exclude the affected table. Remaining
        UPDATE/VIEW statements report all of their tables.
        """
        if self.statement_type not in {StatementType.UPDATE, StatementType.VIEW}:
            return self.tables - {self.affected_table}
        if self.query_type in {QueryType.TRUNCATE, QueryType.DROP}:
            return set()
        if self.query_type == QueryType.DELETE:
            return self.tables - {self.affected_table}
        return self.tables