Coverage for xerini/utilities.py: 100%

53 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-20 19:54 +0000

1""" 

2## Module with useful functions 

3 

4- successive_pairs 

5""" 

6 

7from typing import Iterable, TypeVar, Tuple, Optional, Collection, List, Iterator 

8from itertools import tee 

9from string import ascii_lowercase, digits 

10from random import choice, choices, shuffle 

11import sqlparse 

12import networkx as nx 

13 

14Things = TypeVar("Things") 

15 

16 

def successive_pairs(iterable: Iterable[Things]) -> Iterable[Tuple[Things, Things]]:
    """
    Yield every overlapping (current, next) pair drawn from the input iterable,
    ```python
    >>> list(successive_pairs('ABCDEFG'))
    [('A', 'B'), ('B', 'C'), ('C', 'D'), ('D', 'E'), ('E', 'F'), ('F', 'G')]
    ```
    """
    iterator = iter(iterable)
    try:
        previous = next(iterator)
    except StopIteration:
        # Fewer than two elements: nothing to pair up.
        return
    for current in iterator:
        yield previous, current
        previous = current

29 

30 

def following_pairs_second_item(
    ell: List[Tuple[Things, Things]], search_pair: Tuple[Things, Things]
) -> Things:
    """
    Return the second item of the pair that immediately follows *search_pair*
    in *ell*. Used to get affected table names.

    :param ell: the list of pairs to search through
    :param search_pair: the pair whose successor's second item is wanted
    :return: the second item of the pair following *search_pair*
    :raises ValueError: if *search_pair* does not occur in *ell*
    :raises IndexError: if *search_pair* is the last element of *ell*
    """
    # EAFP: a single .index() scan replaces the original
    # membership test followed by a second .index() scan.
    try:
        idx: int = ell.index(search_pair)
    except ValueError:
        raise ValueError(f"The {search_pair=} is not in {ell=}!") from None
    if (idx + 1) == len(ell):
        raise IndexError(f"{search_pair} was the last element of {ell=}!")
    return ell[idx + 1][1]

41 

42 

def meaningful_strings_count(text: Optional[str] = None) -> int:
    """
    Count the meaningful SQL statements in *text*.

    We need this function to determine whether the text used to initialize a
    SQL Statement only has a single statement.

    :param text: an arbitrary text, possibly None or empty
    :return: the number of meaningful strings in the original text
    """
    # None and "" both count as zero statements; the original also had a
    # dead `if text is None: pass` branch which is removed here.
    if not text:
        return 0
    # A chunk that is only semicolons (empty after stripping them) is not
    # a meaningful statement.
    return sum(bool(sub.replace(";", "")) for sub in sqlparse.split(text))

57 

58 

def random_id8() -> str:
    """
    Returns an ascii string of length 8 which starts with a lower case character
    :return:
    """
    pool = ascii_lowercase + digits
    # First character must be a letter; the remaining seven may also be digits.
    tail = [choice(pool) for _ in range(7)]
    return choice(ascii_lowercase) + "".join(tail)

68 

69 

def random_partitions(
    number_of_partitions: int, items: Collection[Things]
) -> Tuple[List[Things], ...]:
    """
    Partitions the elements of items into a tuple of random lists of roughly equal size.

    :param number_of_partitions: how many buckets to produce; must be at least 2
    :param items: the collection whose elements are distributed
    :return: a tuple of lists whose sizes differ by at most one
    :raises ValueError: if there are fewer items than partitions, or fewer than
        2 partitions are requested
    """
    if len(items) < number_of_partitions:
        raise ValueError(
            f"The provided collection has less elements, {len(items)=}, "
            f"than the number of partitions, {number_of_partitions=}!"
        )
    if number_of_partitions < 2:
        # Message fixed: the check accepts 2, so the bound is "at least 2",
        # not "greater than 2" as the original message claimed.
        raise ValueError(
            f"The partition size, {number_of_partitions=}, should be at least 2!"
        )
    partitions: Tuple[List[Things], ...] = tuple(
        [] for _ in range(number_of_partitions)
    )
    # Shuffle a copy so the caller's collection is left untouched.
    shuffled = list(items)
    shuffle(shuffled)
    # Deal elements round-robin, which keeps the bucket sizes within one of
    # each other.
    for index, element in enumerate(shuffled):
        partitions[index % number_of_partitions].append(element)
    return partitions

89 

90 

def minimal_nodes(directed: nx.DiGraph) -> set:
    """
    Returns the set of nodes in the network that have no incoming edges
    :param directed: the graph from which we should extract the minimal nodes
    :return: a set of mininal nodes
    :raises TypeError: when *directed* is not an ``nx.DiGraph``
    """
    if not isinstance(directed, nx.DiGraph):
        raise TypeError(f"unable to compute minimal nodes for {type(directed)=}.")
    sources = set()
    for node in directed.nodes():
        if directed.in_degree(node) == 0:
            sources.add(node)
    return sources

100 

101 

def parallel_decomposition(directed: nx.DiGraph) -> Iterator[set]:
    """
    Decompose the DAG into stages: each yielded set holds the nodes with no
    remaining predecessors, which are then removed before the next stage.

    :param directed: the directed acyclic graph to decompose
    :return: an iterator over the successive stages
    :raises TypeError: when *directed* is not a directed acyclic graph
    """
    if not nx.is_directed_acyclic_graph(directed):
        raise TypeError("parallel decomposition requires a directed acyclic graph.")
    # Work on a copy so the caller's graph is not consumed.
    remaining = directed.copy()
    while True:
        if not remaining.nodes():
            return
        stage = minimal_nodes(remaining)
        remaining.remove_nodes_from(stage)
        yield stage