Move collect_bicolor_runs() functionality to rustworkx-core #1166

Open
Tracked by #1121
mtreinish opened this issue Apr 19, 2024 · 1 comment
Labels
rustworkx-core Issues tracking adding functionality to rustworkx-core

mtreinish (Member) commented Apr 19, 2024

The collect_bicolor_runs() function is only exposed via a Python interface currently:

/// Collect runs that match a filter function given edge colors
///
/// A bicolor run is a list of groups of nodes connected by edges of exactly
/// two colors. In addition, all nodes in the group must match the given
/// condition. Each node in the graph can appear in only a single group
/// in the bicolor run.
///
/// :param PyDiGraph graph: The graph to find runs in
/// :param filter_fn: The filter function to use for matching nodes. It takes
///     in one argument, the node data payload/weight object, and will return a
///     boolean whether the node matches the conditions or not.
///     If it returns ``True``, it will continue the bicolor chain.
///     If it returns ``False``, it will stop the bicolor chain.
///     If it returns ``None`` it will skip that node.
/// :param color_fn: The function that gives the color of the edge. It takes
///     in one argument, the edge data payload/weight object, and will
///     return a non-negative integer, the edge color. If the color is None,
///     the edge is ignored.
///
/// :returns: a list of groups with exactly two edge colors, where each group
///     is a list of node data payload/weight for the nodes in the bicolor run
/// :rtype: list
#[pyfunction]
#[pyo3(text_signature = "(graph, filter_fn, color_fn)")]
pub fn collect_bicolor_runs(
    py: Python,
    graph: &digraph::PyDiGraph,
    filter_fn: PyObject,
    color_fn: PyObject,
) -> PyResult<Vec<Vec<PyObject>>> {
    let mut pending_list: Vec<Vec<PyObject>> = Vec::new();
    let mut block_id: Vec<Option<usize>> = Vec::new();
    let mut block_list: Vec<Vec<PyObject>> = Vec::new();
    let filter_node = |node: &PyObject| -> PyResult<Option<bool>> {
        let res = filter_fn.call1(py, (node,))?;
        res.extract(py)
    };
    let color_edge = |edge: &PyObject| -> PyResult<Option<usize>> {
        let res = color_fn.call1(py, (edge,))?;
        res.extract(py)
    };
    let nodes = match algo::toposort(&graph.graph, None) {
        Ok(nodes) => nodes,
        Err(_err) => return Err(DAGHasCycle::new_err("Sort encountered a cycle")),
    };
    // Utility for ensuring pending_list has the color index
    macro_rules! ensure_vector_has_index {
        ($pending_list: expr, $block_id: expr, $color: expr) => {
            if $color >= $pending_list.len() {
                $pending_list.resize($color + 1, Vec::new());
                $block_id.resize($color + 1, None);
            }
        };
    }
    for node in nodes {
        if let Some(is_match) = filter_node(&graph.graph[node])? {
            let raw_edges = graph
                .graph
                .edges_directed(node, petgraph::Direction::Outgoing);
            // Compute the color of each outgoing edge, propagating any
            // error raised by color_fn
            let colors = raw_edges
                .map(|edge| {
                    let edge_weight = edge.weight();
                    color_edge(edge_weight)
                })
                .collect::<PyResult<Vec<Option<usize>>>>()?;
            // Drop edges for which color_fn returned None
            let colors = colors.into_iter().flatten().collect::<Vec<usize>>();
            if colors.len() <= 2 && is_match {
                if colors.len() == 1 {
                    let c0 = colors[0];
                    ensure_vector_has_index!(pending_list, block_id, c0);
                    if let Some(c0_block_id) = block_id[c0] {
                        block_list[c0_block_id].push(graph.graph[node].clone_ref(py));
                    } else {
                        pending_list[c0].push(graph.graph[node].clone_ref(py));
                    }
                } else if colors.len() == 2 {
                    let c0 = colors[0];
                    let c1 = colors[1];
                    ensure_vector_has_index!(pending_list, block_id, c0);
                    ensure_vector_has_index!(pending_list, block_id, c1);
                    if block_id[c0].is_some()
                        && block_id[c1].is_some()
                        && block_id[c0] == block_id[c1]
                    {
                        block_list[block_id[c0].unwrap_or_default()]
                            .push(graph.graph[node].clone_ref(py));
                    } else {
                        let mut new_block: Vec<PyObject> =
                            Vec::with_capacity(pending_list[c0].len() + pending_list[c1].len() + 1);
                        // Clear the pending lists and add their nodes to the new block
                        new_block.append(&mut pending_list[c0]);
                        new_block.append(&mut pending_list[c1]);
                        new_block.push(graph.graph[node].clone_ref(py));
                        // Create new block, assign its id to color pair
                        block_id[c0] = Some(block_list.len());
                        block_id[c1] = Some(block_list.len());
                        block_list.push(new_block);
                    }
                }
            } else {
                for color in colors {
                    ensure_vector_has_index!(pending_list, block_id, color);
                    if let Some(color_block_id) = block_id[color] {
                        block_list[color_block_id].append(&mut pending_list[color]);
                    }
                    block_id[color] = None;
                    pending_list[color].clear();
                }
            }
        }
    }
    Ok(block_list)
}

We should port it to rustworkx-core so that Rust users can use the function too.

One tweak that probably makes sense for the rustworkx-core version: instead of returning a Vec of Vecs of node weights, it should return an iterator of Vecs of node ids. This would be more flexible and performant for Rust users, and on the Python side of rustworkx we can just collect the iterator into a Vec (for backwards compatibility).
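To make the proposed shape concrete, here is a rough, dependency-free sketch of what an id-returning version could look like. It is not the rustworkx-core API: the name `collect_bicolor_runs_ids`, the `topo_order` and `out_colors` parameters, the generic error type `E`, and the use of plain `usize` ids in place of petgraph node indices are all assumptions made so the example compiles on its own. The caller is assumed to supply a topological order and, for each node, the colors of its outgoing edges with ignored edges already dropped. The grouping logic mirrors the Python-facing function above, and the groups are still built eagerly before being handed back as an iterator.

```rust
/// Hypothetical sketch: collect bicolor runs as groups of node ids.
///
/// `filter_fn` returns Some(true) to continue a run, Some(false) to stop it,
/// and None to skip the node entirely. `out_colors` returns the colors of the
/// node's outgoing edges (edges without a color are assumed to be omitted).
pub fn collect_bicolor_runs_ids<F, C, E>(
    topo_order: &[usize],
    mut filter_fn: F,
    mut out_colors: C,
) -> Result<impl Iterator<Item = Vec<usize>>, E>
where
    F: FnMut(usize) -> Result<Option<bool>, E>,
    C: FnMut(usize) -> Result<Vec<usize>, E>,
{
    // Per-color list of nodes waiting for a block, and per-color active block.
    let mut pending: Vec<Vec<usize>> = Vec::new();
    let mut block_id: Vec<Option<usize>> = Vec::new();
    let mut blocks: Vec<Vec<usize>> = Vec::new();

    for &node in topo_order {
        let is_match = match filter_fn(node)? {
            Some(is_match) => is_match,
            None => continue, // skipped node
        };
        let colors = out_colors(node)?;
        // Make sure both per-color vectors can be indexed by every color.
        for &c in &colors {
            if c >= pending.len() {
                pending.resize(c + 1, Vec::new());
                block_id.resize(c + 1, None);
            }
        }
        match (is_match, colors.as_slice()) {
            // Matching node without colored outgoing edges: nothing to do.
            (true, []) => {}
            // One color: join the active block for it, or wait in pending.
            (true, [c0]) => match block_id[*c0] {
                Some(b) => blocks[b].push(node),
                None => pending[*c0].push(node),
            },
            // Two colors: extend the shared block, or start a new one.
            (true, [c0, c1]) => match (block_id[*c0], block_id[*c1]) {
                (Some(a), Some(b)) if a == b => blocks[a].push(node),
                _ => {
                    let mut new_block = Vec::new();
                    new_block.append(&mut pending[*c0]);
                    new_block.append(&mut pending[*c1]);
                    new_block.push(node);
                    block_id[*c0] = Some(blocks.len());
                    block_id[*c1] = Some(blocks.len());
                    blocks.push(new_block);
                }
            },
            // Non-matching node, or more than two colors: close the runs.
            _ => {
                for &c in &colors {
                    if let Some(b) = block_id[c] {
                        blocks[b].append(&mut pending[c]);
                    }
                    block_id[c] = None;
                    pending[c].clear();
                }
            }
        }
    }
    Ok(blocks.into_iter())
}
```

With this shape, a Python wrapper could collect the iterator and map each id back to its node weight, which would keep the existing return value unchanged.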

ElePT commented Apr 26, 2024

I'm interested!
