Coverage for xerini/script.py: 81%
123 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-20 19:54 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-20 19:54 +0000
1"""
2Module to define a SQL script
3"""
5import subprocess
6from typing import Optional, Self
7from functools import cached_property
8from pathlib import Path
9import networkx as nx
10import sqlparse
11from pydantic import BaseModel, validate_call
12from xerini.statement import Statement
13from xerini.enums import QueryType, TableWriteType, StatementType
14from xerini.utilities import parallel_decomposition
17class Script(BaseModel):
18 """SQL Script class, basically a container for Statements"""
20 raw_code: Optional[str] = None
22 @classmethod
23 @validate_call
24 def from_string(cls, text: str) -> Self:
25 """This is just a helper function"""
26 return cls(raw_code=text)
28 @classmethod
29 @validate_call
30 def from_file(cls, pth: Path) -> Self:
31 """Constructs the Script object from the sql in a file"""
32 if pth.suffix.lower() != ".sql":
33 raise FileNotFoundError("I am only allowed to read *.sql files!")
34 with pth.open("r", encoding="utf-8") as fin:
35 text = fin.read()
36 return cls(raw_code=text)
38 @classmethod
39 @validate_call
40 def from_directory(cls, directory: Path) -> Self:
41 """Constructs the Script object from a directory of sql files"""
42 if not directory.is_dir():
43 raise NotADirectoryError(f"{dir=} is not a directory!")
44 texts = []
45 for pth in directory.iterdir():
46 if pth.is_file() and pth.suffix.lower() == ".sql": 46 ↛ 45line 46 didn't jump to line 45 because the condition on line 46 was always true
47 with pth.open("r", encoding="utf-8") as fin:
48 texts.append(fin.read())
49 return cls(raw_code="\n".join(texts))
51 @property
52 def formatted_text(self) -> str:
53 """Render string"""
54 return "\n\n".join(stmt.formatted_text for stmt in self.statements)
56 def keys(self) -> set[str | None]:
57 """The keys to the dictionary are the names of the affected table of the code"""
58 return {
59 stmt.affected_table
60 for stmt in self.statements
61 if stmt.write_type != TableWriteType.NO_WRITING
62 }
64 def __getitem__(self, item: str) -> list[Statement]:
65 """Dictionary form to retrieve statements pertinent to an affected table"""
66 if item not in self.keys():
67 raise KeyError(f"{item=} is not an affected table of the script!")
68 return [
69 stmt
70 for stmt in self.statements
71 if (
72 (stmt.write_type != TableWriteType.NO_WRITING)
73 & (stmt.affected_table == item)
74 )
75 ]
77 @cached_property
78 def statements(self) -> list[Statement]:
79 """The list of statements that make up the script"""
80 return [Statement(text=sub) for sub in sqlparse.split(self.raw_code)]
82 @property
83 def statement_types(self) -> list[StatementType]:
84 """The statement types"""
85 return [_s.statement_type for _s in self.statements]
87 @cached_property
88 def statement_query_types(self) -> list[QueryType | None]:
89 """The statement query types"""
90 return [_s.query_type for _s in self.statements]
92 @property
93 def statement_write_types(self) -> list[TableWriteType]:
94 """The statement write types"""
95 return [_s.write_type for _s in self.statements]
97 @property
98 def is_valid(self) -> bool:
99 """Are all the statements in the script valid ones"""
100 return all(_qt for _qt in self.statement_query_types)
102 @validate_call
103 def write(self, output: Path) -> None:
104 """Formatted output to the file at the path"""
105 if not self.statements:
106 raise ValueError("Nothing to write, as the script is empty!")
107 with output.open("w") as fout:
108 for stmt in self.statements:
109 fout.write(stmt.formatted_text)
110 fout.write("\n")
112 @cached_property
113 def digraph(self) -> nx.DiGraph:
114 """Returns a directed graph with the table dependency structure of the sql code"""
115 dsg = nx.DiGraph()
116 for stmt in self.statements:
117 if stmt.write_type != TableWriteType.NO_WRITING and stmt.affected_table:
118 for source in stmt.source_tables:
119 dsg.add_edge(source, stmt.affected_table)
120 dsg.edges[source, stmt.affected_table][
121 "tooltip"
122 ] = stmt.formatted_text
123 for _n in dsg.nodes():
124 dsg.nodes[_n]["shape"] = "box"
125 dsg.nodes[_n]["color"] = "blue"
126 return dsg
128 @validate_call
129 def write_dot(self, file: Path, name: str = "dotted_script") -> Path:
130 """Writes the dotfile for the graph representation"""
131 dsg = self.digraph
132 dsg.graph = {"name": name, "splines": "ortho", "rankdir": "LR"}
133 _adg = nx.nx_agraph.to_agraph(dsg)
134 _adg.write(file)
135 return file
137 @staticmethod
138 @validate_call
139 def write_svg(dot_file: Path) -> Path:
140 """Produce the SVG from the dot_file"""
141 subprocess.call(["dot", "-Tsvg", dot_file, "-O"])
142 svg_file: Path = dot_file.parent / (dot_file.name + ".svg")
143 return svg_file
145 def stage_decomposition(self) -> list[set[str]]:
146 """Use parallel decomposition of the graph representation
147 to produce a parallelized orchestration of the table builds
148 """
149 if not self.statements:
150 raise ValueError("I can't decompose an empty script!")
151 pcg = list(parallel_decomposition(self.digraph))
152 ell = []
153 for stg in pcg[1:]:
154 tables = set()
155 for tbl in stg:
156 if tbl in self.keys(): 156 ↛ 155line 156 didn't jump to line 155 because the condition on line 156 was always true
157 tables.add(tbl)
158 ell.append(tables)
160 return ell
162 @validate_call
163 def write_orchestration(self, staged_directory: Path):
164 """Write the script to disk"""
165 if not staged_directory.is_absolute():
166 staged_directory = staged_directory.expanduser().resolve()
168 if staged_directory.exists() and staged_directory.is_dir():
169 raise IsADirectoryError(
170 f"{staged_directory=} is an existing directory, I should not overwrite it!"
171 )
172 if staged_directory.exists() and staged_directory.is_file():
173 raise FileExistsError(
174 f"{staged_directory=} is an existing file, I should definitely not overwrite it!"
175 )
176 staged_directory.mkdir(parents=True, exist_ok=True)
177 decomposed = self.stage_decomposition()
178 for idx, stg in enumerate(decomposed):
179 stage = staged_directory / f"stage_{str(idx+1).zfill(2)}"
180 stage.mkdir()
181 for tbl in stg:
182 pth = stage / f"{tbl}.sql"
183 with pth.open("w", encoding="utf-8") as fout:
184 for stmt in self[tbl]:
185 fout.write(str(stmt))
186 fout.write("\n")