Corrected the pointed-out arithmetic error (total_pairs=21, comm_pairs=13, rate 61.9%) and added a self-diagnostic template (n, total_pairs, comm_pairs) to the output. This makes consistency checks immediately clear each time, eliminating debate discrepancies. Store parsing hardened + mem_all reads limited to “unbreakable + not over-restrictive” ironclad balance (add→store RAW preserved, remaining aliases commutative). Tool execution verified: 61.9% rate (13/21, 8 non-commutative pairs: RAW/WAW/clobber captured), groups [0,2,4]/[1,6]/[3,5] (all pairs commutative), reorder dependency preserved (add→store order maintained, load/malloc alternating hiding, printf barrier). Correctness check (topo order + all pairs in groups commutative) cleared! Execution results (tool output, printf added for testing)
Instructions (R/W):
0: %x = load float, float* %y -> reads={'mem_all', 'y', 'mem_y'}, writes={'x'}
1: %z = mul float %x, %w -> reads={'w', 'x'}, writes={'z'}
2: %s = add float %z, %v -> reads={'v', 'z'}, writes={'s'}
3: store float %s, float* %c -> reads={'mem_all', 'mem_c', 'c', 's'}, writes={'mem_c'}
4: %ptr = call i8* @malloc(i64 100) -> reads={'mem_all'}, writes={'mem_ptr', 'ptr'}
5: store i32 42, i32* %ptr -> reads={'mem_ptr', 'mem_all', 'ptr'}, writes={'mem_ptr'}
6: call void @printf(i8* %msg) -> reads={'mem_all'}, writes={'mem_all'}
Commutative Groups:
Group: [0, 2, 4]
Group: [1, 6]
Group: [3, 5]
Self-Diagnosis:
n = 7, total_pairs = 21
comm_pairs = 13
Commutativity Rate: 61.9%
Reordered IR (Kahn + Preferences):
%x = load float, float* %y
%ptr = call i8* @malloc(i64 100)
%z = mul float %x, %w
%s = add float %z, %v
store float %s, float* %c
store i32 42, i32* %ptr
call void @printf(i8* %msg)
(Legacy) UF Groups: [[0, 1, 2, 3, 4, 5, 6]] – 危険!
import networkx as nx
from collections import defaultdict
from typing import List, Tuple, Set
import re
# Union-Find (legacyオプション)
class UnionFind:
def __init__(self, size: int):
self.parent = list(range(size))
self.rank = [0] * size
def find(self, p: int) -> int:
if self.parent[p] != p:
self.parent[p] = self.find(self.parent[p])
return self.parent[p]
def union(self, p: int, q: int):
rootP = self.find(p)
rootQ = self.find(q)
if rootP == rootQ:
return
if self.rank[rootP] < self.rank[rootQ]:
self.parent[rootP] = rootQ
elif self.rank[rootP] > self.rank[rootQ]:
self.parent[rootQ] = rootP
else:
self.parent[rootQ] = rootP
self.rank[rootP] += 1
def components(self) -> List[List[int]]:
groups = defaultdict(list)
for i in range(len(self.parent)):
groups[self.find(i)].append(i)
return list(groups.values())
def parse_llvm_ir(instr: str) -> Tuple[Set[str], Set[str]]:
"""強化R/W抽出: store addr read/mem_alias, 定数val skip, SSA限定reads, call clobber."""
reads = set()
writes = set()
clean = re.sub(r'%', '', instr)
if 'load' in instr:
ptr_match = re.search(r'load\s+.*?, \s*(\w+)\*\s*(\w+)', clean)
if ptr_match:
_, ptr = ptr_match.groups()
reads.add(ptr)
reads.add(f'mem_{ptr}')
lhs_match = re.search(r'^(\w+)\s*=', clean)
if lhs_match:
writes.add(lhs_match.group(1))
elif 'store' in instr:
# 堅牢regex: store val, *ptr
m = re.search(r'\bstore\b.*?\s([A-Za-z_]\w*|\d+)\s*,\s*.*?\*\s*([A-Za-z_]\w*)\b', clean)
if m:
val, ptr = m.groups()
if not val.isdigit(): # 定数(42) skip
reads.add(val)
reads.add(ptr)
reads.add(f'mem_{ptr}')
writes.add(f'mem_{ptr}')
elif '=' in instr and ('add' in instr or 'mul' in instr):
parts = clean.split('=')
lhs = parts[0].strip()
writes.add(lhs)
rhs = parts[1].strip()
ops = re.findall(r'\b([a-zA-Z_][a-zA-Z0-9_]*)\b', rhs)
for op in ops:
if op not in ['add', 'mul', 'float', 'i32', 'i64', 'i8', 'call']:
reads.add(op)
elif 'call' in instr:
lhs_match = re.search(r'^(\w+)\s*=', clean)
if lhs_match:
writes.add(lhs_match.group(1))
if 'malloc' in instr:
writes.add('mem_ptr')
else:
reads.add('mem_all')
writes.add('mem_all')
if 'volatile' in instr or 'atomic' in instr:
reads.add('volatile_mem')
writes.add('volatile_mem')
# mem_*が出たらmem_allをreads側だけに追加(clobber衝突保証、writeはalias限定)
if any(v.startswith("mem_") for v in reads | writes):
reads.add("mem_all")
return reads, writes
def is_commutative(i1: Tuple[Set[str], Set[str]], i2: Tuple[Set[str], Set[str]]) -> bool:
r1, w1 = i1
r2, w2 = i2
return w1.isdisjoint(r2 | w2) and w2.isdisjoint(r1 | w1)
def commutativity_graph(instructions: List[Tuple[Set[str], Set[str]]]):
G = nx.Graph()
n = len(instructions)
G.add_nodes_from(range(n))
for i in range(n):
for j in range(i+1, n):
if is_commutative(instructions[i], instructions[j]):
G.add_edge(i, j)
return G
def greedy_disjoint_clique_cover(G):
cliques = list(nx.find_cliques(G))
cliques.sort(key=lambda c: (-len(c), -sum(G.degree(v) for v in c), min(c)))
used = set()
groups = []
for c in cliques:
c2 = [v for v in c if v not in used]
if len(c2) >= 2:
groups.append(sorted(c2))
used.update(c2)
for v in G.nodes():
if v not in used:
groups.append([v])
groups.sort(key=lambda g: min(g))
return groups
def analyze_commutativity(instructions: List[Tuple[Set[str], Set[str]]]) -> List[List[int]]:
G = commutativity_graph(instructions)
return greedy_disjoint_clique_cover(G)
def build_def_use_graph(instructions: List[Tuple[Set[str], Set[str]]]) -> nx.DiGraph:
DG = nx.DiGraph()
n = len(instructions)
DG.add_nodes_from(range(n))
for i in range(n):
for j in range(i+1, n):
if not is_commutative(instructions[i], instructions[j]):
DG.add_edge(i, j)
return DG
def schedule_kahn_with_preferences(groups, def_use_dg, instructions, original_ir):
group_rank = {}
for gi, g in enumerate(groups):
for v in g:
group_rank[v] = gi
indeg = {u: def_use_dg.in_degree(u) for u in def_use_dg.nodes()}
ready = [u for u, d in indeg.items() if d == 0]
out = []
while ready:
ready.sort(key=lambda u: (group_rank.get(u, 10**9), -len(instructions[u][0]), u))
u = ready.pop(0)
out.append(original_ir[u])
for v in list(def_use_dg.successors(u)):
indeg[v] -= 1
if indeg[v] == 0:
ready.append(v)
if len(out) != len(original_ir):
raise ValueError("Cycle detected in def-use graph (or missing nodes).")
return out
class CommutativityAnalysisPass:
def __init__(self):
self.instructions = []
self.original_ir = []
def analyze_basic_block(self, bb_ir: List[str]):
self.original_ir = bb_ir
self.instructions = [parse_llvm_ir(ir) for ir in bb_ir]
def get_read_set(self, idx: int) -> Set[str]:
return self.instructions[idx][0] if idx < len(self.instructions) else set()
def get_write_set(self, idx: int) -> Set[str]:
return self.instructions[idx][1] if idx < len(self.instructions) else set()
def is_commutative_pair(self, i1: int, i2: int) -> bool:
return is_commutative(self.instructions[i1], self.instructions[i2])
def build_commutative_groups(self) -> List[List[int]]:
return analyze_commutativity(self.instructions)
def reorder_for_optimization(self) -> List[str]:
groups = self.build_commutative_groups()
def_use_dg = build_def_use_graph(self.instructions)
return schedule_kahn_with_preferences(groups, def_use_dg, self.instructions, self.original_ir)
# テスト例(非malloc call追加でclobberテスト)
if __name__ == "__main__":
test_bb = [
"%x = load float, float* %y",
"%z = mul float %x, %w",
"%s = add float %z, %v",
"store float %s, float* %c",
"%ptr = call i8* @malloc(i64 100)",
"store i32 42, i32* %ptr",
"call void @printf(i8* %msg)"
]
pass_sim = CommutativityAnalysisPass()
pass_sim.analyze_basic_block(test_bb)
groups = pass_sim.build_commutative_groups()
print("Instructions (R/W):")
for i, (r, w) in enumerate(pass_sim.instructions):
print(f" {i}: {test_bb[i]} -> reads={r}, writes={w}")
print("\nCommutative Groups:")
for g in groups:
print(f" Group: {g}")
n = len(test_bb)
total_pairs = n * (n - 1) // 2
comm_pairs = sum(1 for i in range(n) for j in range(i+1, n) if pass_sim.is_commutative_pair(i, j))
print(f"\nSelf-Diagnosis:")
print(f"n = {n}, total_pairs = {total_pairs}")
print(f"comm_pairs = {comm_pairs}")
print(f"Commutativity Rate: {comm_pairs / total_pairs * 100:.1f}%")
reordered = pass_sim.reorder_for_optimization()
print("\nReordered IR (Kahn + Preferences):")
for ir in reordered:
print(f" {ir}")
# Legacy UF比較
uf = UnionFind(n)
for i in range(n):
for j in range(i+1, n):
if pass_sim.is_commutative_pair(i, j):
uf.union(i, j)
uf_comps = uf.components()
print(f"\n(Legacy) UF Groups: {uf_comps} – 危険!")