summaryrefslogtreecommitdiff
path: root/scripts/mir_to_dot.py
blob: 0834e2414e96a30b28af0b282a3780f0c43d453d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import re
import argparse

class Link(object):
    """One labelled edge of the emitted DOT control-flow graph.

    ``src`` is a basic-block index and is rendered as a ``bbN`` node name.
    ``dst`` is either a basic-block index (also rendered as ``bbN``) or an
    already-formatted node name string such as ``'return'``.
    ``label`` is the text drawn on the edge.
    """

    def __init__(self, src, dst, label):
        self._src = 'bb%d' % (src,)
        # Non-string destinations are block indices; strings are used verbatim.
        if isinstance(dst, str):
            self._dst = dst
        else:
            self._dst = 'bb%d' % (dst,)
        self._label = label

# Terminator patterns recognised at the end of a basic block (raw strings so
# that "\d" is a regex escape, not an invalid string escape).
_GOTO_RE = re.compile(r'goto bb(\d+);$')
_CALL_RE = re.compile(r'.*goto bb(\d+) else bb(\d+)$')
_IF_RE = re.compile(r'.*goto bb(\d+); } else { goto bb(\d+); }$')
_SWITCH_RE = re.compile(r'(\d+) => bb(\d+),')


def _strip_block_comments(line):
    """Return *line* with (possibly nested) ``/* ... */`` comments removed."""
    out = []
    depth = 0
    i = 0
    while i < len(line):
        two = line[i:i + 2]
        if depth > 0 and two == '*/':
            depth -= 1
            i += 2
            continue
        if two == '/*':
            depth += 1
            i += 2
            continue
        if depth == 0:
            out.append(line[i])
        i += 1
    return ''.join(out)


def _read_basic_blocks(fp):
    """Read statements of each basic block from *fp* (positioned just after ``bb0: {``).

    Returns a list with one list of comment-stripped statements per block.
    Reading stops when the brace that closes the function body is seen.
    """
    bbs = []
    cur_bb_lines = []
    # We are already inside two braces: the function body and bb0's block.
    level = 2
    for line in fp:
        line = line.strip()
        if line == "}":
            level -= 1
            if level == 0:
                break
            bbs.append(cur_bb_lines)
            cur_bb_lines = []
            continue

        if "bb" in line and ": {" in line:
            level += 1
            continue

        stmt = _strip_block_comments(line)
        # Lines starting with '#' are discarded by the DOT parser, so this
        # debug trace can share stdout with the graph.
        print("# %d %s" % (len(bbs), stmt))
        cur_bb_lines.append(stmt)
    return bbs


def _block_links(idx, bb):
    """Return the outgoing ``Link`` edges of block *idx* based on its last statement."""
    if not bb:
        # An empty block has no terminator to classify.
        return []
    term = bb[-1]
    if term == 'return;':
        return [Link(idx, 'return', "return")]
    if term == 'diverge;':
        # Panic/unwind edges are deliberately not drawn.
        return []
    m = _GOTO_RE.match(term)
    if m is not None:
        return [Link(idx, int(m.group(1)), "")]
    m = _CALL_RE.match(term)
    if m is not None:
        # Only the normal-return edge is drawn; the panic edge (group 2) is omitted.
        return [Link(idx, int(m.group(1)), "ret")]
    m = _IF_RE.match(term)
    if m is not None:
        return [Link(idx, int(m.group(1)), "true"),
                Link(idx, int(m.group(2)), "false")]
    return [Link(idx, int(m.group(2)), "var%s" % (m.group(1),))
            for m in _SWITCH_RE.finditer(term)]


def _dot_escape(text):
    """Escape *text* for use inside a double-quoted DOT string.

    Backslashes are escaped first so that an existing ``\\"`` in the input
    cannot turn into a string terminator.
    """
    return text.replace('\\', '\\\\').replace('"', '\\"')


def _emit_dot(bbs, links):
    """Print the control-flow graph for *bbs*/*links* in Graphviz DOT syntax."""
    print("digraph {")
    print("node [shape=box, labeljust=l; fontname=\"mono\"];")
    for l in links:
        print('"%s" -> "%s" [label="%s"];' % (l._src, l._dst, l._label))

    print("")
    for idx, bb in enumerate(bbs):
        # Each label line ends in "\l" so every line, including the block's
        # terminator, is left-justified in the rendered node.
        body = ''.join(_dot_escape(stmt) + '\\l' for stmt in bb)
        print('"bb%i" [label="BB%i:\\l%s"];' % (idx, idx, body))
    print("}")


def main():
    """Render the MIR of one function as a Graphviz DOT graph on stdout.

    Reads a ``*.hir_3_mir.rs`` dump (``--file``, or
    ``output/<crate>.hir_3_mir.rs`` when only ``--crate`` is given). It finds
    the first line matching ``fn <fn-name>`` and emits one node per basic
    block, with edges derived from each block's terminator.

    Example patterns: ``--crate rustc_lint --fn-name '.*expr_refers_to_this_method.*'``
    or ``--crate std --fn-name 'resize.*HashMap'``.
    """
    argp = argparse.ArgumentParser(
        description="Convert a function's MIR dump into a DOT graph.")
    argp.add_argument("--file", type=str,
                      help="MIR dump to read (overrides --crate)")
    argp.add_argument("--crate", type=str,
                      help="crate name; reads output/<crate>.hir_3_mir.rs")
    argp.add_argument("--fn-name", type=str, default='resize.*HashMap',
                      help="regex matched after 'fn ' at the start of a line")
    args = argp.parse_args()

    if not args.file and args.crate is None:
        # Previously this crashed with TypeError on 'output/' + None.
        argp.error("one of --file or --crate is required")

    pat = 'fn ' + args.fn_name
    infile = args.file or ('output/' + args.crate + '.hir_3_mir.rs')
    start_pat = re.compile(pat)

    with open(infile) as fp:
        def_line = None
        for line in fp:
            line = line.strip()
            if start_pat.match(line) is not None:
                print("#  " + line)
                def_line = line
                break

        if def_line is None:
            return

        # Skip the function prologue up to the first basic block.
        for line in fp:
            if line.strip() == "bb0: {":
                break

        bbs = _read_basic_blocks(fp)

    links = []
    for idx, bb in enumerate(bbs):
        links.extend(_block_links(idx, bb))

    _emit_dot(bbs, links)


# Only run when executed as a script, so the module can be imported safely.
if __name__ == "__main__":
    main()