I'm trying to find cycles in an undirected, unweighted graph whose edges are given in [node, node] format. Here is the code I wrote:
def find_cycles(graph):
    """Collect the node lists that close back onto the current DFS path.

    A depth-first search is started from every node of ``graph``. Whenever
    the search meets a neighbor that is already on the current path, the
    slice of the path from that neighbor to the current node is recorded.

    Recorded lists are compared as plain lists, so the same loop reached
    from a different start node or in the opposite direction is stored
    again, and stepping back to the node the search just came from yields
    a two-node entry.

    Args:
        graph: Mapping from a node to the list of its adjacent nodes.

    Returns:
        The list of recorded node lists, in discovery order.
    """
    cycles = []

    def dfs(node, visited, path):
        # `visited` mirrors `path` as a set for fast membership tests.
        visited.add(node)
        path.append(node)
        for neighbor in graph.get(node, []):
            if neighbor in visited:
                # Cycle detected: keep the path from `neighbor` onwards.
                start_index = path.index(neighbor)
                cycle = path[start_index:]
                if cycle not in cycles:
                    cycles.append(cycle)
            else:
                dfs(neighbor, visited, path)
        # Backtrack so that sibling branches may pass through `node` again.
        visited.remove(node)
        path.pop()

    for node in graph.keys():
        dfs(node, set(), [])
    return cycles
# Adjacency lists: node -> list of adjacent nodes.
graph = {}

# First input line is "n m"; only m (the number of edge lines) is used.
n, m = map(int, input().split())
for _ in range(m):
    a, b = map(int, input().split())
    # Make sure both endpoints have an adjacency list (b first, then a).
    graph.setdefault(b, [])
    graph.setdefault(a, [])
    # The edge is undirected: store it under both endpoints, once each.
    if a not in graph[b]:
        graph[b].append(a)
    if b not in graph[a]:
        graph[a].append(b)

ans = find_cycles(graph)
print(ans)
print(len(ans))
In the test case:
10 10
3 6
9 3
1 7
1 2
4 7
7 6
2 9
2 6
3 4
6 0
I know that the shortest cycle has length 4, but my code prints an incorrect list containing 92 items, the shortest of which has length 2. What is wrong with my code?
I've modified your code very slightly. Instead of checking `if cycle not in cycles`, I'm maintaining a set of the cycles we have already seen, each stored with its nodes in sorted order. If the sorted form of a new cycle is not present in the set, then I add the cycle to the list of cycles. I also discard any "cycle" with only two nodes, which is just a single edge walked forward and back.
With this, I get 7 cycles, and I think it is correct.
# Sample input, one string per line: "n m" first, then one "a b" edge per line.
data = [
    "10 10",
    "3 6",
    "9 3",
    "1 7",
    "1 2",
    "4 7",
    "7 6",
    "2 9",
    "2 6",
    "3 4",
    "6 0",
]
def find_cycles(graph):
    """Return each simple cycle of an undirected graph exactly once.

    A depth-first search is started from every node. Whenever the search
    meets a neighbor that is already on the current path, the slice of the
    path from that neighbor to the current node forms a closed walk.

    Args:
        graph: Mapping from a node to the list of its adjacent nodes.
            Intended for undirected graphs, i.e. each edge appears in the
            adjacency list of both of its endpoints.

    Returns:
        A list of cycles in first-discovery order. Each cycle is a list of
        at least three nodes in traversal order; the closing edge from the
        last node back to the first is implied.
    """
    cycles = []
    # Edge sets of the cycles reported so far. A cycle is identified by its
    # undirected edges, which makes the key independent of start node and
    # direction. The sorted node tuple is not a safe key: in a complete graph
    # on four nodes, three different 4-cycles visit exactly the same nodes.
    seen = set()

    def dfs(node, visited, path):
        # `visited` mirrors `path` as a set for fast membership tests.
        visited.add(node)
        path.append(node)
        for neighbor in graph.get(node, []):
            if neighbor not in visited:
                dfs(neighbor, visited, path)
                continue
            cycle = path[path.index(neighbor):]
            if len(cycle) <= 2:
                # One or two nodes: a self-loop or a single edge walked
                # forward and back, not a cycle.
                continue
            # Pair every node with its successor, wrapping around at the end.
            key = frozenset(
                frozenset(edge) for edge in zip(cycle, cycle[1:] + cycle[:1])
            )
            if key not in seen:
                seen.add(key)
                cycles.append(cycle)
        # Backtrack so that sibling branches may pass through `node` again.
        visited.remove(node)
        path.pop()

    for node in graph:
        dfs(node, set(), [])
    return cycles
# Adjacency lists: node -> list of adjacent nodes.
graph = {}

# Walk the sample lines with an iterator so `data` itself is left intact.
lines = iter(data)

# First line is "n m"; only m (the number of edge lines) is used.
n, m = map(int, next(lines).split())
for _ in range(m):
    a, b = map(int, next(lines).split())
    # Make sure both endpoints have an adjacency list (b first, then a).
    graph.setdefault(b, [])
    graph.setdefault(a, [])
    # The edge is undirected: store it under both endpoints, once each.
    if a not in graph[b]:
        graph[b].append(a)
    if b not in graph[a]:
        graph[a].append(b)

ans = find_cycles(graph)
print(ans)
print(len(ans))
Output:
[[3, 9, 2, 1, 7, 4], [6, 3, 9, 2, 1, 7], [6, 3, 9, 2], [6, 3, 4, 7, 1, 2], [6, 3, 4, 7], [6, 7, 1, 2], [6, 7, 4, 3, 9, 2]]
7