use super::Outcome;
use crate::optimization;
use libfirm_rs::{
nodes::{Add, Div, Mod, Mul, Node, NodeTrait},
Graph, Mode, Tarval, TarvalKind,
};
/// State of one node-local optimization run over a single graph.
///
/// The pass walks the graph and offers every `Add`, `Mul`, `Mod` and `Div`
/// node to the matching `try_optimize_*` method of this type.
pub struct NodeLocal {
/// The graph that is walked and rewritten by this pass instance.
graph: Graph,
/// Set to `Outcome::Changed` as soon as one node has been replaced;
/// starts out as `Outcome::Unchanged`.
changed: Outcome,
}
impl optimization::Local for NodeLocal {
    /// Runs the node-local rewrites on `graph` and reports whether any node
    /// was replaced.
    fn optimize_function(graph: Graph) -> Outcome {
        let mut pass = Self::new(graph);
        pass.run()
    }
}
impl NodeLocal {
/// Creates a pass instance for `graph` whose change flag is still
/// `Outcome::Unchanged`.
fn new(graph: Graph) -> Self {
    let changed = Outcome::Unchanged;
    Self { graph, changed }
}
/// Rewrites a self-addition `x + x` into the left shift `x << 1`.
///
/// The node is left untouched when its two operands differ, or when the
/// operand mode has no unsigned counterpart according to `mode_to_unsigned`
/// (the shift amount constant is created in that unsigned mode).
fn try_optimize_add(&mut self, add: Add) {
    let left = add.left();
    let right = add.right();
    if left != right {
        return;
    }
    let shift_operand_mode = match mode_to_unsigned(right.mode()) {
        Some(mode) => mode,
        None => return,
    };
    // Log only once the rewrite is certain to happen, and report the shift
    // amount that is actually emitted (1, since `x + x == x * 2 == x << 1`).
    log::debug!(
        "LO: self addition {:?} [left:{:?},right:{:?}] \
         replaced by '{:?} << 1'",
        add,
        left,
        right,
        left
    );
    let tarval_1 = self.graph.new_const(Tarval::val(1, shift_operand_mode));
    let shl = add.block().new_shl(left, tarval_1);
    Graph::exchange(add, shl);
    self.changed = Outcome::Changed;
}
/// Rewrites a multiplication by a constant `±2^k` into `other << k`,
/// wrapped in a `Minus` node when the constant is negative.
///
/// The left operand is inspected first: if it is a `Const`, only that
/// constant is considered, otherwise the right operand is tried. The node is
/// left untouched when no `Const` operand with a `TarvalKind::Long` value is
/// found, when the magnitude of the value is not a power of two (or cannot
/// be represented, see `checked_abs`), or when the constant's mode has no
/// unsigned counterpart according to `mode_to_unsigned`.
fn try_optimize_mul(&mut self, mul: Mul) {
    let left = mul.left();
    let right = mul.right();

    // (constant operand, remaining operand, constant value)
    let (power_node, other, power) = match (left, right) {
        (Node::Const(constant), _) => match constant.tarval().kind() {
            TarvalKind::Long(val) => (left, right, val),
            _ => return,
        },
        (_, Node::Const(constant)) => match constant.tarval().kind() {
            TarvalKind::Long(val) => (right, left, val),
            _ => return,
        },
        _ => return,
    };

    let abs_power = match power.checked_abs() {
        Some(abs) => abs as u64,
        None => return,
    };
    let has_minus = power < 0;
    if !abs_power.is_power_of_two() {
        return;
    }

    // For a power of two the index of the single set bit is the exponent.
    let shift_amount = abs_power.trailing_zeros();
    let shift_operand_mode = match mode_to_unsigned(power_node.mode()) {
        Some(mode) => mode,
        None => return,
    };

    let shift_amount_tarval = Tarval::val(i64::from(shift_amount), shift_operand_mode);
    let shift_amount_node = self.graph.new_const(shift_amount_tarval);
    let shl = mul.block().new_shl(other, shift_amount_node);
    let shl_end = if has_minus {
        Node::Minus(mul.block().new_minus(shl))
    } else {
        Node::Shl(shl)
    };

    log::debug!(
        "LO: Mul2Shift: {:?} [left:{:?},right:{:?}] replaced by '{:?} << {}'",
        mul,
        left,
        right,
        shl,
        shift_amount
    );
    Graph::exchange(mul, shl_end);
    self.changed = Outcome::Changed;
}
/// Replaces a signed 32-bit remainder by a constant power of two with a
/// shift/mask sequence.
///
/// The rewrite is applied only when all of the following hold; otherwise the
/// node is left untouched and no new nodes are created:
/// * the divisor is a `Const` (optionally wrapped in a `Conv`) with a
///   `TarvalKind::Long` value that fits into `i32` and whose magnitude is a
///   power of two `2^k`;
/// * the dividend is a `Conv` of a value `x` in mode `Is`;
/// * the result projection exists, has exactly one user, and that user is in
///   mode `Is` (it is the node that gets replaced);
/// * the memory projection exists.
///
/// With `>>s` denoting `Shrs`, the replacement computes
/// `x - (((x + ((x >>s 31) & (2^k - 1))) >>s k) << k)`.
/// The sign of the divisor does not enter the computation.
fn try_optimize_mod(&mut self, modulo: Mod) {
    log::debug!(
        "LO: Mod2Shift: {:?}[left:{:?},right:{:?}] ",
        modulo,
        modulo.left(),
        modulo.right(),
    );
    let divisor = match modulo.right() {
        Node::Conv(conv) => {
            log::debug!("Mod2Shift: with conv {:?}!", conv);
            if let Node::Const(divisor) = conv.op() {
                divisor
            } else {
                return;
            }
        }
        Node::Const(divisor) => divisor,
        _ => {
            return;
        }
    };
    log::debug!("Mod2Shift: divisor {:?}!", divisor);
    let divisor_value = match divisor.tarval().kind() {
        TarvalKind::Long(value) => value,
        _ => return,
    };
    log::debug!("Mod2Shift: divisor value {:?}!", divisor_value);
    if divisor_value > std::i32::MAX.into() || divisor_value < std::i32::MIN.into() {
        log::debug!("Mod2Shift: aborting since value is not in i32 range!",);
        return;
    }
    // Cannot overflow: the value was just checked to lie in the i32 range.
    let abs_divisor_value = divisor_value.abs() as u64;
    if !abs_divisor_value.is_power_of_two() {
        return;
    }
    log::debug!("Mod2Shift: is power of two!");
    // Index of the single set bit, i.e. k in 2^k; at most 31 here.
    let shift_amount = 64 - 1 - abs_divisor_value.leading_zeros();

    // Validate the whole pattern before creating any node.
    let modulo_proj_res = match modulo.out_proj_res() {
        Some(res) => res,
        None => return,
    };
    if modulo_proj_res.out_nodes().len() != 1 {
        return;
    }
    let modulo_end = modulo_proj_res.out_nodes().nth(0).unwrap();
    // Only look through a `Conv` on the dividend; stripping any other
    // single-input node would make the rewrite operate on the wrong value.
    let real_left = match modulo.left() {
        Node::Conv(conv) => conv.op(),
        _ => return,
    };
    // The constants below (31, `mj_int` mask) are only valid for a 32-bit
    // signed value, and the replaced user must have the replacement's mode.
    if real_left.mode() != Mode::Is() || modulo_end.mode() != Mode::Is() {
        return;
    }
    log::debug!(
        "LO: Mod2Shift: memory edge through modulo {:?} is {:?} -> {:?}",
        modulo,
        modulo.mem(),
        modulo.out_proj_m(),
    );
    let modulo_proj_mem = match modulo.out_proj_m() {
        Some(mem) => mem,
        None => return,
    };

    let shift_amount_node = self
        .graph
        .new_const(Tarval::val(i64::from(shift_amount), Mode::Iu()));
    let const_31 = self.graph.new_const(Tarval::val(31, Mode::Iu()));
    // Computed in i64: for k == 31 (divisor `i32::MIN`) the expression
    // `(1i32 << 31) - 1` would overflow i32.
    let mask_const = (1i64 << shift_amount) - 1;
    let mask_const_node = self.graph.new_const(Tarval::mj_int(mask_const));

    let block = modulo.block();
    // `x >>s 31` is all ones for negative x and zero otherwise, so the mask
    // selects a bias of `2^k - 1` for negative dividends only.
    let shr_by_31 = block.new_shrs(real_left, const_31);
    let binary_and = block.new_and(shr_by_31, mask_const_node);
    let add_binary_and = block.new_add(real_left, binary_and);
    let shift_to_result = block.new_shrs(add_binary_and, shift_amount_node);
    let mul_result_by_divisor = block.new_shl(shift_to_result, shift_amount_node);
    let modulo_subst_end = block.new_sub(real_left, mul_result_by_divisor);

    // Route the memory chain around the Mod node, then replace its user.
    Graph::exchange(modulo_proj_mem, modulo.mem());
    log::debug!(
        "LO: Mod2Shift: {:?}[left:{:?},right:{:?}] replaced",
        modulo,
        modulo.left(),
        modulo.right(),
    );
    Graph::exchange(modulo_end, modulo_subst_end);
    self.changed = Outcome::Changed;
}
/// Replaces a signed 32-bit division by a constant power of two with a
/// shift sequence.
///
/// The rewrite is applied only when all of the following hold; otherwise the
/// node is left untouched and no new nodes are created:
/// * the divisor is a `Const` (optionally wrapped in a `Conv`) with a
///   `TarvalKind::Long` value that fits into `i32` and whose magnitude is a
///   power of two `2^k`;
/// * the dividend is a `Conv` of a value `x` in mode `Is`;
/// * the result projection exists, has exactly one user, and that user is in
///   mode `Is` (it is the node that gets replaced);
/// * the memory projection exists.
///
/// With `>>s` denoting `Shrs`, the replacement computes
/// `(x + ((x >>s 31) & (2^k - 1))) >>s k`, wrapped in a `Minus` node when the
/// divisor is negative.
fn try_optimize_div(&mut self, div: Div) {
    log::debug!(
        "LO: Div2Shift: {:?}[left:{:?},right:{:?}] ",
        div,
        div.left(),
        div.right(),
    );
    let divisor = match div.right() {
        Node::Conv(conv) => {
            log::debug!("Div2Shift: with conv {:?}!", conv);
            if let Node::Const(divisor) = conv.op() {
                divisor
            } else {
                return;
            }
        }
        Node::Const(divisor) => divisor,
        _ => {
            return;
        }
    };
    log::debug!("Div2Shift: divisor {:?}!", divisor);
    let divisor_value = match divisor.tarval().kind() {
        TarvalKind::Long(value) => value,
        _ => return,
    };
    log::debug!("Div2Shift: divisor value {:?}!", divisor_value);
    if divisor_value > std::i32::MAX.into() || divisor_value < std::i32::MIN.into() {
        log::debug!("Div2Shift: aborting since value is not in i32 range!",);
        return;
    }
    // Cannot overflow: the value was just checked to lie in the i32 range.
    let abs_divisor_value = divisor_value.abs() as u64;
    let has_minus = divisor_value < 0;
    if !abs_divisor_value.is_power_of_two() {
        return;
    }
    log::debug!("Div2Shift: is power of two!");
    // Index of the single set bit, i.e. k in 2^k; at most 31 here.
    let shift_amount = 64 - 1 - abs_divisor_value.leading_zeros();

    // Validate the whole pattern before creating any node.
    let div_proj_res = match div.out_proj_res() {
        Some(res) => res,
        None => return,
    };
    if div_proj_res.out_nodes().len() != 1 {
        return;
    }
    let div_end = div_proj_res.out_nodes().nth(0).unwrap();
    // Only look through a `Conv` on the dividend; stripping any other
    // single-input node would make the rewrite operate on the wrong value.
    let real_left = match div.left() {
        Node::Conv(conv) => conv.op(),
        _ => return,
    };
    // The constants below (31, `mj_int` mask) are only valid for a 32-bit
    // signed value, and the replaced user must have the replacement's mode.
    if real_left.mode() != Mode::Is() || div_end.mode() != Mode::Is() {
        return;
    }
    log::debug!(
        "LO: Div2Shift: memory edge through div {:?} is {:?} -> {:?}",
        div,
        div.mem(),
        div.out_proj_m(),
    );
    let div_proj_mem = match div.out_proj_m() {
        Some(mem) => mem,
        None => return,
    };

    let shift_amount_node = self
        .graph
        .new_const(Tarval::val(i64::from(shift_amount), Mode::Iu()));
    let const_31 = self.graph.new_const(Tarval::val(31, Mode::Iu()));
    // Computed in i64: for k == 31 (divisor `i32::MIN`) the expression
    // `(1i32 << 31) - 1` would overflow i32.
    let mask_const = (1i64 << shift_amount) - 1;
    let mask_const_node = self.graph.new_const(Tarval::mj_int(mask_const));

    let block = div.block();
    // `x >>s 31` is all ones for negative x and zero otherwise, so the mask
    // selects a bias of `2^k - 1` for negative dividends only; the bias makes
    // the arithmetic shift round toward zero.
    let shr_by_31 = block.new_shrs(real_left, const_31);
    let binary_and = block.new_and(shr_by_31, mask_const_node);
    let add_binary_and = block.new_add(real_left, binary_and);
    let shift_to_result = block.new_shrs(add_binary_and, shift_amount_node);
    let shr_end = if has_minus {
        Node::Minus(div.block().new_minus(shift_to_result))
    } else {
        Node::Shrs(shift_to_result)
    };

    // Route the memory chain around the Div node, then replace its user.
    Graph::exchange(div_proj_mem, div.mem());
    log::debug!(
        "LO: Div2Shift: {:?}[left:{:?},right:{:?}] replaced by '>> {}'",
        div,
        div.left(),
        div.right(),
        shift_amount
    );
    Graph::exchange(div_end, shr_end);
    self.changed = Outcome::Changed;
}
/// Logs the visited node and forwards `Add`, `Mul`, `Mod` and `Div` nodes to
/// their `try_optimize_*` method; every other node kind is ignored.
fn visit_node(&mut self, current_node: Node) {
    log::debug!("LO: visiting {:?}", current_node);
    if let Node::Add(add) = current_node {
        self.try_optimize_add(add);
    } else if let Node::Mul(mul) = current_node {
        self.try_optimize_mul(mul);
    } else if let Node::Mod(modulo) = current_node {
        self.try_optimize_mod(modulo);
    } else if let Node::Div(div) = current_node {
        self.try_optimize_div(div);
    }
}
/// Executes the pass: resets the change flag, calls `assure_outs`, visits
/// every node via `walk_topological`, and — only if something was replaced —
/// calls `remove_unreachable_code` and `remove_bads` on the graph.
///
/// Returns the resulting change flag.
fn run(&mut self) -> Outcome {
    self.changed = Outcome::Unchanged;
    self.graph.assure_outs();
    self.graph
        .walk_topological(|visited| self.visit_node(*visited));

    let outcome = self.changed;
    if outcome == Outcome::Changed {
        self.graph.remove_unreachable_code();
        self.graph.remove_bads();
    }
    outcome
}
}
/// Maps a signed integer mode to the unsigned mode used for shift amounts:
/// `Is` becomes `Iu` and `Ls` becomes `Lu`. Every other mode yields `None`.
fn mode_to_unsigned(mode: Mode) -> Option<Mode> {
    if mode == Mode::Is() {
        return Some(Mode::Iu());
    }
    if mode == Mode::Ls() {
        return Some(Mode::Lu());
    }
    None
}