Coverage for xerini/script.py: 81%

123 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-20 19:54 +0000

1""" 

2Module to define a SQL script 

3""" 

4 

5import subprocess 

6from typing import Optional, Self 

7from functools import cached_property 

8from pathlib import Path 

9import networkx as nx 

10import sqlparse 

11from pydantic import BaseModel, validate_call 

12from xerini.statement import Statement 

13from xerini.enums import QueryType, TableWriteType, StatementType 

14from xerini.utilities import parallel_decomposition 

15 

16 

class Script(BaseModel):
    """SQL script class: a container for the ``Statement`` objects parsed from SQL text.

    A script can be built from a string, a single ``*.sql`` file, or a directory
    of ``*.sql`` files. It offers a dictionary-like view (``keys()`` /
    ``script[table]``) from each written table name to the statements that write
    it. It can also export its table-dependency graph as a ``networkx.DiGraph``,
    a Graphviz dot file, an SVG, or a staged directory of per-table SQL files.
    """

    # Raw SQL text of the whole script; ``None`` is treated like an empty script.
    raw_code: Optional[str] = None

    @classmethod
    @validate_call
    def from_string(cls, text: str) -> Self:
        """Build a script directly from SQL text."""
        return cls(raw_code=text)

    @classmethod
    @validate_call
    def from_file(cls, pth: Path) -> Self:
        """Build a script from the UTF-8 contents of a single ``*.sql`` file.

        Raises:
            FileNotFoundError: if ``pth`` does not have a ``.sql`` suffix
                (case-insensitive), or if the file does not exist.
        """
        if pth.suffix.lower() != ".sql":
            raise FileNotFoundError("I am only allowed to read *.sql files!")
        with pth.open("r", encoding="utf-8") as fin:
            text = fin.read()
        return cls(raw_code=text)

    @classmethod
    @validate_call
    def from_directory(cls, directory: Path) -> Self:
        """Build a script from every ``*.sql`` file directly inside ``directory``.

        Files are read as UTF-8 in sorted path order, so the result does not
        depend on the filesystem's listing order. Their contents are joined
        with a newline. Sub-directories are not searched.

        Raises:
            NotADirectoryError: if ``directory`` is not an existing directory.
        """
        if not directory.is_dir():
            raise NotADirectoryError(f"{directory=} is not a directory!")
        texts = []
        for pth in sorted(directory.iterdir()):
            if pth.is_file() and pth.suffix.lower() == ".sql":
                with pth.open("r", encoding="utf-8") as fin:
                    texts.append(fin.read())
        return cls(raw_code="\n".join(texts))

    @property
    def formatted_text(self) -> str:
        """Render the script: each statement's ``formatted_text``, separated by a blank line."""
        return "\n\n".join(stmt.formatted_text for stmt in self.statements)

    def keys(self) -> set[str | None]:
        """Return the names of tables written by at least one statement.

        Statements whose ``write_type`` is ``TableWriteType.NO_WRITING`` are
        ignored. ``None`` can appear if a writing statement has no
        ``affected_table``.
        """
        return {
            stmt.affected_table
            for stmt in self.statements
            if stmt.write_type != TableWriteType.NO_WRITING
        }

    def __getitem__(self, item: str) -> list[Statement]:
        """Return, in script order, the writing statements whose ``affected_table`` is ``item``.

        Raises:
            KeyError: if ``item`` is not in ``keys()``.
        """
        if item not in self.keys():
            raise KeyError(f"{item=} is not an affected table of the script!")
        return [
            stmt
            for stmt in self.statements
            if stmt.write_type != TableWriteType.NO_WRITING
            and stmt.affected_table == item
        ]

    @cached_property
    def statements(self) -> list[Statement]:
        """Split ``raw_code`` with ``sqlparse.split`` and wrap each piece in a ``Statement``.

        The result is cached on first access. A ``None`` ``raw_code`` yields
        no statements instead of failing inside ``sqlparse``.
        """
        return [Statement(text=sub) for sub in sqlparse.split(self.raw_code or "")]

    @property
    def statement_types(self) -> list[StatementType]:
        """The ``statement_type`` of each statement, in script order."""
        return [_s.statement_type for _s in self.statements]

    @cached_property
    def statement_query_types(self) -> list[QueryType | None]:
        """The ``query_type`` of each statement, in script order (cached)."""
        return [_s.query_type for _s in self.statements]

    @property
    def statement_write_types(self) -> list[TableWriteType]:
        """The ``write_type`` of each statement, in script order."""
        return [_s.write_type for _s in self.statements]

    @property
    def is_valid(self) -> bool:
        """True when every statement's ``query_type`` is truthy (a ``None`` counts as invalid)."""
        return all(self.statement_query_types)

    @validate_call
    def write(self, output: Path) -> None:
        """Write each statement's ``formatted_text`` followed by a newline to ``output`` (UTF-8).

        Raises:
            ValueError: if the script has no statements.
        """
        if not self.statements:
            raise ValueError("Nothing to write, as the script is empty!")
        with output.open("w", encoding="utf-8") as fout:
            for stmt in self.statements:
                fout.write(stmt.formatted_text)
                fout.write("\n")

    @cached_property
    def digraph(self) -> nx.DiGraph:
        """Build the table dependency graph of the script (cached).

        For every writing statement with a truthy ``affected_table``, an edge
        ``source -> affected_table`` is added for each of its ``source_tables``.
        The edge's ``tooltip`` attribute is the statement's ``formatted_text``;
        if several statements produce the same edge, the last one wins. Every
        node is styled as a blue box.

        NOTE: a written table with no ``source_tables`` gets no edge, so it
        does not appear in the graph.
        """
        dsg = nx.DiGraph()
        for stmt in self.statements:
            target = stmt.affected_table
            if stmt.write_type == TableWriteType.NO_WRITING or not target:
                continue
            tooltip = stmt.formatted_text  # same for every source; compute once
            for source in stmt.source_tables:
                dsg.add_edge(source, target, tooltip=tooltip)
        for node in dsg.nodes:
            dsg.nodes[node].update(shape="box", color="blue")
        return dsg

    @validate_call
    def write_dot(self, file: Path, name: str = "dotted_script") -> Path:
        """Write the dependency graph to ``file`` in Graphviz dot format.

        The graph attributes (``name``, orthogonal splines, left-to-right
        layout) are set on a copy, so the cached ``digraph`` is not modified.
        ``nx.nx_agraph`` needs ``pygraphviz`` to be installed.

        Returns:
            The path that was written.
        """
        dsg = self.digraph.copy()
        dsg.graph = {"name": name, "splines": "ortho", "rankdir": "LR"}
        _adg = nx.nx_agraph.to_agraph(dsg)
        # pygraphviz documents ``write`` as taking a string path or a file handle.
        _adg.write(str(file))
        return file

    @staticmethod
    @validate_call
    def write_svg(dot_file: Path) -> Path:
        """Render ``dot_file`` to SVG using the Graphviz ``dot`` executable.

        With ``-O``, ``dot`` names its output after the input file with the
        format appended, e.g. ``graph.dot`` -> ``graph.dot.svg``.

        Returns:
            The path of the generated SVG file.

        Raises:
            subprocess.CalledProcessError: if ``dot`` exits with a non-zero status.
            FileNotFoundError: if the ``dot`` executable cannot be found.
        """
        subprocess.run(["dot", "-Tsvg", dot_file, "-O"], check=True)
        svg_file: Path = dot_file.parent / (dot_file.name + ".svg")
        return svg_file

    def stage_decomposition(self) -> list[set[str]]:
        """Group the written tables into stages for a parallelized build.

        Runs ``parallel_decomposition`` on ``digraph`` and drops the first
        group it yields. That group is presumably the pure source tables with
        no incoming edges (TODO confirm against ``parallel_decomposition``).
        Each remaining group is filtered down to the tables in ``keys()``.

        Raises:
            ValueError: if the script has no statements.
        """
        if not self.statements:
            raise ValueError("I can't decompose an empty script!")
        written = self.keys()  # computed once rather than once per table
        pcg = list(parallel_decomposition(self.digraph))
        return [{tbl for tbl in stg if tbl in written} for stg in pcg[1:]]

    @validate_call
    def write_orchestration(self, staged_directory: Path) -> None:
        """Write one ``<table>.sql`` file per written table, grouped into stage folders.

        A relative ``staged_directory`` is user-expanded and resolved. It then
        receives ``stage_01``, ``stage_02``, … sub-directories, one per group
        from ``stage_decomposition``. Each ``<table>.sql`` file holds
        ``str(stmt)`` plus a newline for every statement that writes that table.

        Raises:
            IsADirectoryError: if ``staged_directory`` is an existing directory.
            FileExistsError: if ``staged_directory`` is an existing file.
            ValueError: if the script is empty. This is raised before anything
                is created on disk.
        """
        if not staged_directory.is_absolute():
            staged_directory = staged_directory.expanduser().resolve()

        if staged_directory.exists() and staged_directory.is_dir():
            raise IsADirectoryError(
                f"{staged_directory=} is an existing directory, I should not overwrite it!"
            )
        if staged_directory.exists() and staged_directory.is_file():
            raise FileExistsError(
                f"{staged_directory=} is an existing file, I should definitely not overwrite it!"
            )
        # Decompose first, so a failure does not leave an empty directory behind.
        decomposed = self.stage_decomposition()
        staged_directory.mkdir(parents=True, exist_ok=True)
        for idx, stg in enumerate(decomposed, start=1):
            stage = staged_directory / f"stage_{idx:02d}"
            stage.mkdir()
            for tbl in stg:
                pth = stage / f"{tbl}.sql"
                with pth.open("w", encoding="utf-8") as fout:
                    for stmt in self[tbl]:
                        fout.write(str(stmt))
                        fout.write("\n")