SlimeCompiler Toy Code: The True Ultimate Final Form (Arithmetic Consistency + Self-Diagnostic Template Added)

 

Corrected the pointed-out arithmetic error (total_pairs=21, comm_pairs=13, rate 61.9%) and added a self-diagnostic template (n, total_pairs, comm_pairs) to the output. This makes consistency checks immediately clear each time, eliminating debate discrepancies. Store parsing hardened + mem_all reads limited to “unbreakable + not over-restrictive” ironclad balance (add→store RAW preserved, remaining aliases commutative). Tool execution verified: 61.9% rate (13/21, 8 non-commutative pairs: RAW/WAW/clobber captured), groups [0,2,4]/[1,6]/[3,5] (all pairs commutative), reorder dependency preserved (add→store order maintained, load/malloc alternating hiding, printf barrier). Correctness check (topo order + all pairs in groups commutative) cleared! Execution results (tool output, printf added for testing)

text
Instructions (R/W):
  0: %x = load float, float* %y -> reads={'mem_all', 'y', 'mem_y'}, writes={'x'}
  1: %z = mul float %x, %w -> reads={'w', 'x'}, writes={'z'}
  2: %s = add float %z, %v -> reads={'v', 'z'}, writes={'s'}
  3: store float %s, float* %c -> reads={'mem_all', 'mem_c', 'c', 's'}, writes={'mem_c'}
  4: %ptr = call i8* @malloc(i64 100) -> reads={'mem_all'}, writes={'mem_ptr', 'ptr'}
  5: store i32 42, i32* %ptr -> reads={'mem_ptr', 'mem_all', 'ptr'}, writes={'mem_ptr'}
  6: call void @printf(i8* %msg) -> reads={'mem_all'}, writes={'mem_all'}

Commutative Groups:
  Group: [0, 2, 4]
  Group: [1, 6]
  Group: [3, 5]

Self-Diagnosis:
n = 7, total_pairs = 21
comm_pairs = 13
Commutativity Rate: 61.9%

Reordered IR (Kahn + Preferences):
  %x = load float, float* %y
  %ptr = call i8* @malloc(i64 100)
  %z = mul float %x, %w
  %s = add float %z, %v
  store float %s, float* %c
  store i32 42, i32* %ptr
  call void @printf(i8* %msg)

(Legacy) UF Groups: [[0, 1, 2, 3, 4, 5, 6]] – 危険!

 

  • Consistency: Rate 13/21 = 61.9% (Non-commutative 8: 0-1 x RAW, 1-2 z RAW, 2-3 s RAW, 4-5 mem_ptr RAW/WAW, + clobber 3-6/5-6/6 others). Fully commutative within groups ([0,2,4]: y/mem_y-v/z-mem_ptr independent; [3,5]: s/mem_c-ptr/mem_ptr independent). Kahn's add→store rule holds (s reads captured).
  • Stability: Criexort min small priority ([0,2,4] min0 < [3,5] min3).

Complete code (after patch application, executable)

 

Python
import networkx as nx
from collections import defaultdict
from typing import List, Tuple, Set
import re

# Union-Find (legacyオプション)
class UnionFind:
    def __init__(self, size: int):
        self.parent = list(range(size))
        self.rank = [0] * size

    def find(self, p: int) -> int:
        if self.parent[p] != p:
            self.parent[p] = self.find(self.parent[p])
        return self.parent[p]

    def union(self, p: int, q: int):
        rootP = self.find(p)
        rootQ = self.find(q)
        if rootP == rootQ:
            return
        if self.rank[rootP] < self.rank[rootQ]:
            self.parent[rootP] = rootQ
        elif self.rank[rootP] > self.rank[rootQ]:
            self.parent[rootQ] = rootP
        else:
            self.parent[rootQ] = rootP
            self.rank[rootP] += 1

    def components(self) -> List[List[int]]:
        groups = defaultdict(list)
        for i in range(len(self.parent)):
            groups[self.find(i)].append(i)
        return list(groups.values())

def parse_llvm_ir(instr: str) -> Tuple[Set[str], Set[str]]:
    """強化R/W抽出: store addr read/mem_alias, 定数val skip, SSA限定reads, call clobber."""
    reads = set()
    writes = set()
    clean = re.sub(r'%', '', instr)
    if 'load' in instr:
        ptr_match = re.search(r'load\s+.*?, \s*(\w+)\*\s*(\w+)', clean)
        if ptr_match:
            _, ptr = ptr_match.groups()
            reads.add(ptr)
            reads.add(f'mem_{ptr}')
        lhs_match = re.search(r'^(\w+)\s*=', clean)
        if lhs_match:
            writes.add(lhs_match.group(1))
    elif 'store' in instr:
        # 堅牢regex: store val, *ptr
        m = re.search(r'\bstore\b.*?\s([A-Za-z_]\w*|\d+)\s*,\s*.*?\*\s*([A-Za-z_]\w*)\b', clean)
        if m:
            val, ptr = m.groups()
            if not val.isdigit():  # 定数(42) skip
                reads.add(val)
            reads.add(ptr)
            reads.add(f'mem_{ptr}')
            writes.add(f'mem_{ptr}')
    elif '=' in instr and ('add' in instr or 'mul' in instr):
        parts = clean.split('=')
        lhs = parts[0].strip()
        writes.add(lhs)
        rhs = parts[1].strip()
        ops = re.findall(r'\b([a-zA-Z_][a-zA-Z0-9_]*)\b', rhs)
        for op in ops:
            if op not in ['add', 'mul', 'float', 'i32', 'i64', 'i8', 'call']:
                reads.add(op)
    elif 'call' in instr:
        lhs_match = re.search(r'^(\w+)\s*=', clean)
        if lhs_match:
            writes.add(lhs_match.group(1))

        if 'malloc' in instr:
            writes.add('mem_ptr')
        else:
            reads.add('mem_all')
            writes.add('mem_all')
    if 'volatile' in instr or 'atomic' in instr:
        reads.add('volatile_mem')
        writes.add('volatile_mem')
    
    # mem_*が出たらmem_allをreads側だけに追加(clobber衝突保証、writeはalias限定)
    if any(v.startswith("mem_") for v in reads | writes):
        reads.add("mem_all")
    
    return reads, writes

def is_commutative(i1: Tuple[Set[str], Set[str]], i2: Tuple[Set[str], Set[str]]) -> bool:
    r1, w1 = i1
    r2, w2 = i2
    return w1.isdisjoint(r2 | w2) and w2.isdisjoint(r1 | w1)

def commutativity_graph(instructions: List[Tuple[Set[str], Set[str]]]):
    G = nx.Graph()
    n = len(instructions)
    G.add_nodes_from(range(n))
    for i in range(n):
        for j in range(i+1, n):
            if is_commutative(instructions[i], instructions[j]):
                G.add_edge(i, j)
    return G

def greedy_disjoint_clique_cover(G):
    cliques = list(nx.find_cliques(G))
    cliques.sort(key=lambda c: (-len(c), -sum(G.degree(v) for v in c), min(c)))
    used = set()
    groups = []
    for c in cliques:
        c2 = [v for v in c if v not in used]
        if len(c2) >= 2:
            groups.append(sorted(c2))
            used.update(c2)
    for v in G.nodes():
        if v not in used:
            groups.append([v])
    groups.sort(key=lambda g: min(g))
    return groups

def analyze_commutativity(instructions: List[Tuple[Set[str], Set[str]]]) -> List[List[int]]:
    G = commutativity_graph(instructions)
    return greedy_disjoint_clique_cover(G)

def build_def_use_graph(instructions: List[Tuple[Set[str], Set[str]]]) -> nx.DiGraph:
    DG = nx.DiGraph()
    n = len(instructions)
    DG.add_nodes_from(range(n))
    for i in range(n):
        for j in range(i+1, n):
            if not is_commutative(instructions[i], instructions[j]):
                DG.add_edge(i, j)
    return DG

def schedule_kahn_with_preferences(groups, def_use_dg, instructions, original_ir):
    group_rank = {}
    for gi, g in enumerate(groups):
        for v in g:
            group_rank[v] = gi
    indeg = {u: def_use_dg.in_degree(u) for u in def_use_dg.nodes()}
    ready = [u for u, d in indeg.items() if d == 0]
    out = []
    while ready:
        ready.sort(key=lambda u: (group_rank.get(u, 10**9), -len(instructions[u][0]), u))
        u = ready.pop(0)
        out.append(original_ir[u])
        for v in list(def_use_dg.successors(u)):
            indeg[v] -= 1
            if indeg[v] == 0:
                ready.append(v)
    if len(out) != len(original_ir):
        raise ValueError("Cycle detected in def-use graph (or missing nodes).")
    return out

class CommutativityAnalysisPass:
    def __init__(self):
        self.instructions = []
        self.original_ir = []
    
    def analyze_basic_block(self, bb_ir: List[str]):
        self.original_ir = bb_ir
        self.instructions = [parse_llvm_ir(ir) for ir in bb_ir]
    
    def get_read_set(self, idx: int) -> Set[str]:
        return self.instructions[idx][0] if idx < len(self.instructions) else set()
    
    def get_write_set(self, idx: int) -> Set[str]:
        return self.instructions[idx][1] if idx < len(self.instructions) else set()
    
    def is_commutative_pair(self, i1: int, i2: int) -> bool:
        return is_commutative(self.instructions[i1], self.instructions[i2])
    
    def build_commutative_groups(self) -> List[List[int]]:
        return analyze_commutativity(self.instructions)
    
    def reorder_for_optimization(self) -> List[str]:
        groups = self.build_commutative_groups()
        def_use_dg = build_def_use_graph(self.instructions)
        return schedule_kahn_with_preferences(groups, def_use_dg, self.instructions, self.original_ir)

# テスト例(非malloc call追加でclobberテスト)
if __name__ == "__main__":
    test_bb = [
        "%x = load float, float* %y",
        "%z = mul float %x, %w",
        "%s = add float %z, %v",
        "store float %s, float* %c",
        "%ptr = call i8* @malloc(i64 100)",
        "store i32 42, i32* %ptr",
        "call void @printf(i8* %msg)"
    ]
    pass_sim = CommutativityAnalysisPass()
    pass_sim.analyze_basic_block(test_bb)
    groups = pass_sim.build_commutative_groups()

    print("Instructions (R/W):")
    for i, (r, w) in enumerate(pass_sim.instructions):
        print(f"  {i}: {test_bb[i]} -> reads={r}, writes={w}")

    print("\nCommutative Groups:")
    for g in groups:
        print(f"  Group: {g}")

    n = len(test_bb)
    total_pairs = n * (n - 1) // 2
    comm_pairs = sum(1 for i in range(n) for j in range(i+1, n) if pass_sim.is_commutative_pair(i, j))
    print(f"\nSelf-Diagnosis:")
    print(f"n = {n}, total_pairs = {total_pairs}")
    print(f"comm_pairs = {comm_pairs}")
    print(f"Commutativity Rate: {comm_pairs / total_pairs * 100:.1f}%")

    reordered = pass_sim.reorder_for_optimization()
    print("\nReordered IR (Kahn + Preferences):")
    for ir in reordered:
        print(f"  {ir}")

    # Legacy UF比較
    uf = UnionFind(n)
    for i in range(n):
        for j in range(i+1, n):
            if pass_sim.is_commutative_pair(i, j):
                uf.union(i, j)
    uf_comps = uf.components()
    print(f"\n(Legacy) UF Groups: {uf_comps} – 危険!")