Coverage for xerini/statement.py: 100%

96 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-20 19:54 +0000

1""" 

2Module to define a SQL Statement 

3""" 

4 

5from typing import Optional, Self, Set 

6 

7import sqlparse 

8from pydantic import BaseModel, model_validator 

9from sql_metadata import Parser, QueryType 

10 

11from xerini.enums import StatementType, TableWriteType 

12from xerini.utilities import ( 

13 meaningful_strings_count, 

14 successive_pairs, 

15 following_pairs_second_item, 

16) 

17 

18# pylint: disable=unused-import 

19from xerini.supported_query_types import SUPPORTED_QUERY_TYPES 

20 

21 

class Statement(BaseModel):
    """A single SQL statement plus metadata derived from parsing its text.

    ``text`` holds the raw SQL.  After field validation the text is formatted
    with ``sqlparse`` and wrapped in a ``sql_metadata.Parser``; the parser and
    the query type it reports are kept in private attributes and exposed
    through read-only properties.
    """

    text: str
    _parsed_text: Optional[Parser] = None
    _query_type: Optional[QueryType] = None

    @model_validator(mode="before")
    @classmethod
    def only_one_statement(cls, data: dict):
        """Ensure ``text`` contains exactly one SQL statement.

        Raises:
            ValueError: if ``meaningful_strings_count`` of the text is not
                strictly between 0 and 2.
        """
        # A "before" validator receives the raw, unvalidated input, which is
        # not guaranteed to be a dict carrying "text".  Hand such input back
        # untouched so pydantic's own field validation reports the problem
        # instead of this hook failing with a KeyError/TypeError.
        if not isinstance(data, dict) or "text" not in data:
            return data
        if not 0 < meaningful_strings_count(data["text"]) < 2:
            raise ValueError(
                f"Expected a string containing a single SQL statement, but {data['text']=} !"
            )
        return data

    @property
    def formatted_text(self) -> str:
        """Return ``text`` run through ``sqlparse.format``.

        The formatting re-indents with tabs, upper-cases keywords, strips
        comments and puts commas first.
        """
        return sqlparse.format(
            self.text,
            reindent=True,
            keyword_case="upper",
            strip_comments=True,
            comma_first=True,
            indent_tabs=True,
        )

    @model_validator(mode="after")
    def set_parsed_text(self) -> Self:
        """Build the ``Parser`` from the formatted text and store it."""
        self._parsed_text = Parser(self.formatted_text)
        return self

    @model_validator(mode="after")
    def set_query_type(self) -> Self:
        """Store the parser's query type, or ``None`` if it cannot be determined.

        Defined after ``set_parsed_text`` because it reads ``parsed_text``.
        """
        self._query_type = None
        try:
            self._query_type = self.parsed_text.query_type
        except ValueError:
            # Deliberate best effort: a ValueError from the parser leaves the
            # query type as None, which the other properties treat as
            # "nothing known about this statement".
            pass
        return self

    @property
    def query_type(self) -> QueryType | None:
        """Query type recorded during validation (``None`` when unknown)."""
        return self._query_type

    @property
    def parsed_text(self) -> Parser:
        """The ``Parser`` built from the formatted text during validation."""
        return self._parsed_text

    def __str__(self) -> str:
        """Return the raw statement text."""
        return self.text

    @property
    def statement_type(self) -> StatementType:
        """Statement category, as given by ``StatementType.from_query_type``."""
        return StatementType.from_query_type(self.query_type)

    @property
    def write_type(self) -> TableWriteType:
        """Write category, as given by ``TableWriteType.from_statement_query``."""
        return TableWriteType.from_statement_query(self.statement_type, self.query_type)

    @property
    def tables(self) -> Set[Optional[str]]:
        """Tables reported by the parser; empty when the query type is unknown."""
        if self.query_type is None:
            return set()
        return set(self.parsed_text.tables)

    @property
    def non_empty_tokens(self) -> list[sqlparse.sql.Token]:
        """The parser's ``non_empty_tokens``; empty when the query type is unknown."""
        if self.query_type is None:
            return []
        return self.parsed_text.non_empty_tokens

    @property
    def normalized_successive_token_pairs(self) -> list[tuple[str, str]]:
        """Pairs of normalized token values produced by ``successive_pairs``.

        Used by ``affected_table`` to locate the name that follows a keyword
        pair such as ``INSERT INTO``.
        """
        return [
            (first.normalized, second.normalized)
            for first, second in successive_pairs(self.parsed_text.tokens)
        ]

    @property
    def affected_table(self) -> str | None:
        """Lower-cased name of the table the statement writes to, if any.

        Returns:
            The table name in lower case, or ``None`` when no target table
            can be identified.
        """
        # Keyword pairs that introduce the target table, in the order in
        # which they take precedence.
        keyword_markers = (
            ("DELETE", "FROM"),
            ("CREATE", "TABLE"),
            ("TEMP", "TABLE"),
            ("TEMPORARY", "TABLE"),
            ("INSERT", "INTO"),
        )
        current_type = self.query_type
        target = None
        if current_type in {QueryType.TRUNCATE, QueryType.DROP}:
            # Use the parser's ordered list instead of the ``tables`` set so
            # the chosen table does not depend on set iteration order, and
            # fall back to None rather than raising when nothing was found.
            target = next(iter(self.parsed_text.tables), None)
        else:
            # Compute the pair list once: the property rebuilds it on every
            # access.
            pairs = self.normalized_successive_token_pairs
            if current_type == QueryType.UPDATE:
                target = next(
                    (
                        tbl
                        for tbl in self.parsed_text.tables
                        if tbl and ("UPDATE", tbl.upper()) in pairs
                    ),
                    None,
                )
            else:
                for marker in keyword_markers:
                    if marker in pairs:
                        target = following_pairs_second_item(pairs, marker)
                        break
        return target.lower() if target is not None else None

    def what_are_we_truncating_dropping(self) -> str | None:
        """Return the table targeted by a TRUNCATE or DROP statement.

        Raises:
            TypeError: if the statement is neither a TRUNCATE nor a DROP.
        """
        if self.query_type not in {QueryType.TRUNCATE, QueryType.DROP}:
            raise TypeError(f"This method is unsupported for {self.query_type}!")
        return self.affected_table

    @property
    def source_tables(self) -> Set[Optional[str]]:
        """Tables the statement reads from.

        For statement types other than UPDATE and VIEW, and for DELETE
        queries, this is every table except the affected one.  TRUNCATE and
        DROP queries have no sources.  Otherwise all tables are sources.
        """
        if self.statement_type not in {StatementType.UPDATE, StatementType.VIEW}:
            return self.tables - {self.affected_table}
        if self.query_type in {QueryType.TRUNCATE, QueryType.DROP}:
            return set()
        if self.query_type == QueryType.DELETE:
            return self.tables - {self.affected_table}
        return self.tables