381 lines
13 KiB
Standard ML
381 lines
13 KiB
Standard ML
(*
|
|
* Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
|
|
*
|
|
* SPDX-License-Identifier: BSD-2-Clause
|
|
*)
|
|
|
|
(*
|
|
* Abstract data structures for representing and performing basic analysis on
|
|
* programs.
|
|
*)
|
|
structure Prog =
|
|
struct
|
|
|
|
(*
 * Shorthand for the Varset operations used throughout this structure.
 *
 * INTER, MINUS and UNION are declared as left-associative infix operators
 * at precedence 1.
 *)
infix 1 INTER MINUS UNION
val empty_set = Varset.empty
val make_set = Varset.make
val union_sets = Varset.union_sets
fun lhs INTER rhs = Varset.inter lhs rhs
(* The operands are handed to Varset.subtract in reverse order.
 * NOTE(review): presumably "Varset.subtract b a" removes the elements of "b"
 * from "a", so that "a MINUS b" is set difference -- confirm against Varset. *)
fun lhs MINUS rhs = Varset.subtract rhs lhs
fun lhs UNION rhs = Varset.union lhs rhs
|
|
|
|
(*
 * Abstract syntax tree of a parsed program.
 *
 * Type parameters:
 *   'a -- payload stored at every node of the tree;
 *   'e -- payload stored once for each expression;
 *   'm -- payload stored only in modification clauses;
 *   'c -- meta-data attached to calls.
 *
 * A Call carries a list of 'e payloads, a further single 'e payload, an 'm
 * payload and the 'c call data, in that order.
 *)
datatype ('a, 'e, 'm, 'c) prog =
    Init of 'a * 'm
  | Modify of 'a * 'e * 'm
  | Guard of 'a * 'e
  | Throw of 'a
  | Call of 'a * 'e list * 'e * 'm * 'c
  | Spec of 'a * 'e
  | Fail of 'a
  | While of 'a * 'e * ('a, 'e, 'm, 'c) prog
  | Condition of 'a * 'e * ('a, 'e, 'm, 'c) prog * ('a, 'e, 'm, 'c) prog
  | Seq of 'a * ('a, 'e, 'm, 'c) prog * ('a, 'e, 'm, 'c) prog
  | Catch of 'a * ('a, 'e, 'm, 'c) prog * ('a, 'e, 'm, 'c) prog
  | RecGuard of 'a * ('a, 'e, 'm, 'c) prog
|
|
|
|
(* Two-valued tag describing a call.
 * NOTE(review): the constructor names suggest a call that decreases an
 * existing recursion measure versus one that starts a new measure -- confirm
 * against the users of this type; nothing in the visible part of this
 * structure consumes it. *)
datatype call_type = DecMeasure | NewMeasure
|
|
|
|
(*
 * Return the 'a payload stored at the root node of a program.
 *
 * Only the outermost constructor is inspected; sub-programs are ignored.
 *)
fun get_node_data (Init (data, _)) = data
  | get_node_data (Modify (data, _, _)) = data
  | get_node_data (Guard (data, _)) = data
  | get_node_data (Throw data) = data
  | get_node_data (Call (data, _, _, _, _)) = data
  | get_node_data (Spec (data, _)) = data
  | get_node_data (Fail data) = data
  | get_node_data (While (data, _, _)) = data
  | get_node_data (Condition (data, _, _, _)) = data
  | get_node_data (Seq (data, _, _)) = data
  | get_node_data (Catch (data, _, _)) = data
  | get_node_data (RecGuard (data, _)) = data
|
|
|
|
(*
 * Pair up the payloads of two programs that have the same shape.
 *
 * Every payload of the result is the pair (payload from progA, payload from
 * progB); the argument-expression lists of a Call are combined with
 * Utils.zip.  Sub-programs are zipped recursively.
 *
 * If the two root constructors differ, the offending pair is rendered with
 * @{make_string} and passed to Utils.invalid_input (presumably raising an
 * exception -- confirm against Utils).
 *)
fun zip_progs progA progB =
  case (progA, progB) of
    (Init (nA, mA), Init (nB, mB)) =>
      Init ((nA, nB), (mA, mB))
  | (Modify (nA, xA, mA), Modify (nB, xB, mB)) =>
      Modify ((nA, nB), (xA, xB), (mA, mB))
  | (Guard (nA, xA), Guard (nB, xB)) =>
      Guard ((nA, nB), (xA, xB))
  | (Throw nA, Throw nB) =>
      Throw (nA, nB)
  | (Call (nA, argsA, retA, mA, cA), Call (nB, argsB, retB, mB, cB)) =>
      Call ((nA, nB), Utils.zip argsA argsB, (retA, retB), (mA, mB), (cA, cB))
  | (Spec (nA, xA), Spec (nB, xB)) =>
      Spec ((nA, nB), (xA, xB))
  | (Fail nA, Fail nB) =>
      Fail (nA, nB)
  | (While (nA, xA, loopA), While (nB, xB, loopB)) =>
      While ((nA, nB), (xA, xB), zip_progs loopA loopB)
  | (Condition (nA, xA, thenA, elseA), Condition (nB, xB, thenB, elseB)) =>
      Condition ((nA, nB), (xA, xB), zip_progs thenA thenB, zip_progs elseA elseB)
  | (Seq (nA, fstA, sndA), Seq (nB, fstB, sndB)) =>
      Seq ((nA, nB), zip_progs fstA fstB, zip_progs sndA sndB)
  | (Catch (nA, tryA, handlerA), Catch (nB, tryB, handlerB)) =>
      Catch ((nA, nB), zip_progs tryA tryB, zip_progs handlerA handlerB)
  | (RecGuard (nA, innerA), RecGuard (nB, innerB)) =>
      RecGuard ((nA, nB), zip_progs innerA innerB)
  | mismatch =>
      Utils.invalid_input "structurally identical programs" (@{make_string} mismatch)
|
|
|
|
(*
 * Rebuild a program with every payload transformed.
 *
 *   node_fn -- applied to the 'a payload of every node;
 *   expr_fn -- applied to every 'e payload (each element of a Call's
 *              argument list, and its separate return expression);
 *   mod_fn  -- applied to every 'm payload;
 *   call_fn -- applied to the 'c payload of every Call.
 *
 * The shape of the program is preserved.
 *)
fun map_prog node_fn expr_fn mod_fn call_fn prog =
  let
    (* The four mapping functions are fixed for the whole traversal. *)
    fun go (Init (a, m)) = Init (node_fn a, mod_fn m)
      | go (Modify (a, e, m)) = Modify (node_fn a, expr_fn e, mod_fn m)
      | go (Guard (a, e)) = Guard (node_fn a, expr_fn e)
      | go (Throw a) = Throw (node_fn a)
      | go (Call (a, args, ret, m, c)) =
          Call (node_fn a, map expr_fn args, expr_fn ret, mod_fn m, call_fn c)
      | go (Spec (a, e)) = Spec (node_fn a, expr_fn e)
      | go (Fail a) = Fail (node_fn a)
      | go (While (a, e, body)) = While (node_fn a, expr_fn e, go body)
      | go (Condition (a, e, lhs, rhs)) =
          Condition (node_fn a, expr_fn e, go lhs, go rhs)
      | go (Seq (a, lhs, rhs)) = Seq (node_fn a, go lhs, go rhs)
      | go (Catch (a, lhs, rhs)) = Catch (node_fn a, go lhs, go rhs)
      | go (RecGuard (a, body)) = RecGuard (node_fn a, go body)
  in
    go prog
  end
|
|
|
|
(*
 * Thread an accumulator through every payload of a program, in pre-order.
 *
 * For each node the node payload is folded first, then the node's own
 * expression / modification / call payloads (in constructor order), and
 * finally the sub-programs from left to right.  "v" is the initial
 * accumulator; the final accumulator is returned.
 *)
fun fold_prog node_fn expr_fn mod_fn call_fn prog v =
  let
    (* Curried on the accumulator so that "go sub" below is a plain partial
     * application, exactly like the recursive calls it stands for. *)
    fun go p acc =
      case p of
        Init (a, m) => (node_fn a #> mod_fn m) acc
      | Modify (a, e, m) => (node_fn a #> expr_fn e #> mod_fn m) acc
      | Guard (a, e) => (node_fn a #> expr_fn e) acc
      | Throw a => node_fn a acc
      | Call (a, args, ret, m, c) =>
          (node_fn a #> fold expr_fn args #> expr_fn ret #> mod_fn m #> call_fn c) acc
      | Spec (a, e) => (node_fn a #> expr_fn e) acc
      | Fail a => node_fn a acc
      | While (a, e, body) => (node_fn a #> expr_fn e #> go body) acc
      | Condition (a, e, lhs, rhs) =>
          (node_fn a #> expr_fn e #> go lhs #> go rhs) acc
      | Seq (a, lhs, rhs) => (node_fn a #> go lhs #> go rhs) acc
      | Catch (a, lhs, rhs) => (node_fn a #> go lhs #> go rhs) acc
      | RecGuard (a, body) => (node_fn a #> go body) acc
  in
    go prog v
  end
|
|
|
|
(*
 * Perform a liveness analysis on the given program.
 *
 * Each node's data will contain the set of live variables _prior_ to the
 * block being executed.
 *
 * For instance:
 *
 *   Condition (a < 3)       -- [a, b, c] live
 *     Modify (x := b)       -- [b, x] live
 *     Modify (x := c)       -- [c, x] live
 *   Modify (ret := x)       -- [x] live
 *
 * "calc_live_vars prog output_vars" expects "prog" to carry, for every
 * expression, a triple whose second component is the set of variables read,
 * and for every modification an optional written variable.  "output_vars"
 * is the set of variables live after the whole program has finished.
 *)
local
  (* The variable written by a node as a set; NONE means nothing is written. *)
  fun set_from_some NONE = empty_set
    | set_from_some (SOME x) = make_set [x]

  (*
   * One pass of the analysis.
   *
   *   term       -- program annotated with the live sets of the previous pass
   *                 ("old" below); sets only ever grow from pass to pass;
   *   succ_live  -- variables live once "term" terminates normally;
   *   throw_live -- variables live at the handler reached if "term" throws.
   *)
  fun calc_live_vars' term succ_live throw_live =
    case term of
      Init (old, written_vars) =>
        Init (old UNION (succ_live MINUS (set_from_some written_vars)), written_vars)
    | Modify (old, read_vars, written_vars) =>
        Modify (old UNION read_vars UNION (succ_live MINUS (set_from_some written_vars)),
          read_vars, written_vars)
    | Call (old, read_vars, ret_read_vars, written_vars, call_data) =>
        Call (old UNION (union_sets read_vars)
                  UNION ret_read_vars
                  UNION (succ_live MINUS (set_from_some written_vars)),
          read_vars, ret_read_vars, written_vars, call_data)
    | Guard (old, read_vars) =>
        Guard (old UNION succ_live UNION read_vars, read_vars)
    | Throw _ =>
        (* Control continues at the enclosing handler, not at the successor. *)
        Throw throw_live
    | Spec (old, read_vars) =>
        (* Successor liveness is not propagated through a Spec. *)
        Spec (old UNION read_vars, read_vars)
    | Fail _ =>
        Fail succ_live
    | While (old, read_vars, body) =>
        let
          (* After the body, control either re-enters the loop (live set
           * "old") or leaves it (live set "succ_live"). *)
          val new_body = calc_live_vars' body (succ_live UNION old) throw_live
          val body_live = get_node_data new_body
        in
          (* The loop may execute zero times, so everything live after the
           * loop is also live before it -- including variables the body
           * would have overwritten. *)
          While (old UNION body_live UNION read_vars UNION succ_live, read_vars, new_body)
        end
    | Condition (old, read_vars, lhs, rhs) =>
        let
          val new_lhs = calc_live_vars' lhs succ_live throw_live
          val lhs_live = get_node_data new_lhs
          val new_rhs = calc_live_vars' rhs succ_live throw_live
          val rhs_live = get_node_data new_rhs
        in
          Condition (old UNION lhs_live UNION rhs_live UNION read_vars,
            read_vars, new_lhs, new_rhs)
        end
    | Seq (old, lhs, rhs) =>
        let
          (* Backwards analysis: the second statement is processed first and
           * its live-in set becomes the first statement's live-out set. *)
          val new_rhs = calc_live_vars' rhs succ_live throw_live
          val rhs_live = get_node_data new_rhs
          val new_lhs = calc_live_vars' lhs rhs_live throw_live
          val lhs_live = get_node_data new_lhs
        in
          Seq (old UNION lhs_live, new_lhs, new_rhs)
        end
    | Catch (old, lhs, rhs) =>
        let
          (* The handler's live-in set is what is live when the protected
           * block throws. *)
          val new_rhs = calc_live_vars' rhs succ_live throw_live
          val rhs_live = get_node_data new_rhs
          val new_lhs = calc_live_vars' lhs succ_live rhs_live
          val lhs_live = get_node_data new_lhs
        in
          Catch (old UNION lhs_live, new_lhs, new_rhs)
        end
    | RecGuard (old, body) =>
        let
          val new_body = calc_live_vars' body succ_live throw_live
          val body_live = get_node_data new_body
        in
          RecGuard (old UNION body_live, new_body)
        end
in
  fun calc_live_vars prog output_vars =
    let
      (* Start with an empty live set on every node and keep only the second
       * component of each expression triple (used as the read set above). *)
      val init = map_prog (fn _ => empty_set) (fn (_, a, _) => a) (fn x => x) (fn x => x) prog
    in
      (* NOTE(review): Utils.fixpoint presumably re-applies the pass until two
       * consecutive results compare equal -- confirm against Utils.  An
       * uncaught throw at top level sees an empty live set. *)
      Utils.fixpoint (fn x => calc_live_vars' x output_vars empty_set) (op =) init
    end
end
|
|
|
|
|
|
(*
 * Get the variables read by each block of code.
 *
 * Each node's data will contain the set of variables read in the given
 * block, i.e. by the node itself and by all of its sub-programs.  The
 * previous node data is discarded.
 *
 * The expression payloads of the input are the read sets of the individual
 * expressions; for a Call these are the per-argument read sets plus the read
 * set of the return expression.
 *)
fun get_read_vars term =
  case term of
    Init (_, written_vars) =>
      Init (empty_set, written_vars)
  | Modify (_, read_vars, written_vars) =>
      Modify (read_vars, read_vars, written_vars)
  | Call (_, read_vars, ret_read_vars, written_vars, call_data) =>
      (* A call reads its arguments and whatever its return expression
       * reads; the liveness analysis in this structure counts both as
       * reads of the call, so this summary must as well. *)
      Call ((union_sets read_vars) UNION ret_read_vars,
        read_vars, ret_read_vars, written_vars, call_data)
  | Guard (_, read_vars) =>
      Guard (read_vars, read_vars)
  | Throw _ =>
      Throw empty_set
  | Spec (_, read_vars) =>
      Spec (read_vars, read_vars)
  | Fail _ =>
      Fail empty_set
  | While (_, read_vars, body) =>
      let
        val new_body = get_read_vars body
        val new_reads = get_node_data new_body
      in
        (* Loop condition reads plus everything read in the body. *)
        While (new_reads UNION read_vars, read_vars, new_body)
      end
  | Condition (_, read_vars, lhs, rhs) =>
      let
        val new_lhs = get_read_vars lhs
        val new_rhs = get_read_vars rhs
        val new_reads = get_node_data new_lhs UNION get_node_data new_rhs
      in
        Condition (new_reads UNION read_vars, read_vars, new_lhs, new_rhs)
      end
  | Seq (_, lhs, rhs) =>
      let
        val new_lhs = get_read_vars lhs
        val new_rhs = get_read_vars rhs
        val new_reads = get_node_data new_lhs UNION get_node_data new_rhs
      in
        Seq (new_reads, new_lhs, new_rhs)
      end
  | Catch (_, lhs, rhs) =>
      let
        val new_lhs = get_read_vars lhs
        val new_rhs = get_read_vars rhs
        val new_reads = get_node_data new_lhs UNION get_node_data new_rhs
      in
        Catch (new_reads, new_lhs, new_rhs)
      end
  | RecGuard (_, body) =>
      let
        val new_body = get_read_vars body
        val new_reads = get_node_data new_body
      in
        RecGuard (new_reads, new_body)
      end
|
|
|
|
(*
 * Get the variables modified by each block of code.
 *
 * Each node's data will contain "SOME s", where "s" is the set of variables
 * modified by the node and all of its sub-programs, or "NONE" when the set
 * is unknown and must be treated as "every variable" (a Spec node, or
 * anything containing one).  The previous node data is discarded.
 *)
fun get_modified_vars term =
  let
    (* Union of two results, where NONE stands for the universal set and
     * therefore absorbs everything. *)
    fun merge (SOME xs) (SOME ys) = SOME (xs UNION ys)
      | merge _ _ = NONE

    (* The optional written variable as a set; NONE gives the empty set. *)
    fun written NONE = empty_set
      | written (SOME var) = make_set [var]

    fun go (Init (_, w)) = Init (SOME (written w), w)
      | go (Modify (_, reads, w)) = Modify (SOME (written w), reads, w)
      | go (Call (_, arg_reads, ret_reads, w, call_data)) =
          Call (SOME (written w), arg_reads, ret_reads, w, call_data)
      | go (Guard (_, reads)) = Guard (SOME empty_set, reads)
      | go (Throw _) = Throw (SOME empty_set)
      | go (Spec (_, reads)) = Spec (NONE, reads)
      | go (Fail _) = Fail (SOME empty_set)
      | go (While (_, reads, body)) =
          let
            val body' = go body
          in
            While (get_node_data body', reads, body')
          end
      | go (Condition (_, reads, lhs, rhs)) =
          let
            val lhs' = go lhs
            val rhs' = go rhs
          in
            Condition (merge (get_node_data lhs') (get_node_data rhs'), reads, lhs', rhs')
          end
      | go (Seq (_, lhs, rhs)) =
          let
            val lhs' = go lhs
            val rhs' = go rhs
          in
            Seq (merge (get_node_data lhs') (get_node_data rhs'), lhs', rhs')
          end
      | go (Catch (_, lhs, rhs)) =
          let
            val lhs' = go lhs
            val rhs' = go rhs
          in
            Catch (merge (get_node_data lhs') (get_node_data rhs'), lhs', rhs')
          end
      | go (RecGuard (_, body)) =
          let
            val body' = go body
          in
            RecGuard (get_node_data body', body')
          end
  in
    go term
  end
|
|
|
|
end
|