Go언어로 구현한 Dijkstra 알고리즘

Go언어로 구현된 코드와 한글로 작성된 설명을 인터넷에서 찾기 힘들기에, Go언어를 공부하고 있는 누군가에게는 약간의 도움이 되길 바라며, Graph Search, Shortest Paths, and Data Structures 강의에서 프로그래밍 숙제로 구현했던 Dijkstra 알고리즘 코드를 공유해본다.

수준이 낮은 코드를 인터넷에 공유하는 것은 부끄러운 일이지만, 이정도 코드밖에 만들 수 없는 것이 현재 나의 실력이라는 것을 인정해야만 더 높은 수준으로 나아갈 수 있다고 믿기에, 앞으로도 기회가 닿는대로 Go언어로 작성된 코드와 시행착오를 공유해보고자 한다.

저장소: http://gogs.reshout.com/reshout/dijkstra-heap

소스트리 구조는 아래와 같다.

├── dijkstraData.txt
├── graph
│   ├── graph.go
│   └── graph_test.go
├── heap
│   ├── heap.go
│   └── heap_test.go
└── main.go

graph 패지와 heap 패키지는 알고리즘 구현에 필요한 자료구조를 별도의 패키지로 구현한 것인데, 유닛테스트 코드를 포함하며 100%에 가까운 커버리지를 확보하였다. 유닛테스트를 통해 확보한 자료구조의 완성도 덕분에 이를 활용한 알고리즘 코드를 완성한 후에는 한 번에 원하는 답을 얻을 수 있었다. Go언어는 소스코드와 같은 폴더에 유닛테스트 코드(e.g., graph_test.go)를 작성할 수 있어서 유닛테스트를 작성하는 습관을 들이는데 도움이 되는 것 같다.

$ go test -cover ./...
?       dijkstra-heap   [no test files]
ok      dijkstra-heap/graph     (cached)        coverage: 94.4% of statements
ok      dijkstra-heap/heap      (cached)        coverage: 97.3% of statements

main.go에 구현된 Dijkstra 알고리즘은 Dasgupta가 쓴 알고리즘 책 115쪽에 있는 pseudocode를 약간 변형한 것으로 다음과 같다.

func dijkstra(g *graph.Graph) {
	for _, u := range g.GetVertices() {
		u.Distance = maxInt
	}
	g.GetVertex(1).Distance = 0

	h := heap.NewHeap()
	for _, u := range g.GetVertices() {
		h.Push(u)
	}

	for !h.IsEmpty() {
		u := h.Pop().(*graph.Vertex)
		for _, edge := range u.OutgoingEdges {
			v := edge.Head
			if v.Distance > u.Distance+edge.Length {
				v.Distance = u.Distance + edge.Length
				h.Delete(v)
				h.Push(v)
			}
		}
	}
}

heap 패키지는 다른 프로젝트에서도 사용할 수 있도록 범용적인 인터페이스를 제공한다. 아래와 같이 Less 함수를 구현한 어떤 객체도 Push(obj Object), Pop() Object 할 수 있도록 하였다.

package heap

type Object interface {
	Less(obj Object) bool
}

graph 패키지는 heap 패키지에 의존적이다. Vertexheap에 넣어놓고 Distance 기준으로 최소값을 추출해야 하기 때문이다.

package graph

import "dijkstra-heap/heap"

type Vertex struct {
	Label         int
	Distance      int
	OutgoingEdges []*Edge
}

func (v *Vertex) Less(obj heap.Object) bool {
	return v.Distance < obj.(*Vertex).Distance
}

graph 패키지는 이 프로젝트 전용이다. 알고리즘에서 사용할 연산에 최적화된 실행시간을 제공하려면 내부 자료구조나 내부 알고리즘도 달라질 수 밖에 없다.

그래프를 구성하는 Vertex들과 Edge들을 저장하는 다양한 방법이 존재할 수 있다. Dijkstra 알고리즘에서는 Heap에서 꺼낸 Vertex에서 밖으로 연결된 Edge들을 탐색하는 경우가 많으므로 VertexOutgoingEdges를 갖도록 하였다.

graph.go의 전체 코드는 아래와 같다.

package graph

import "dijkstra-heap/heap"

type Vertex struct {
	Label         int
	Distance      int
	OutgoingEdges []*Edge
}

func (v *Vertex) Less(obj heap.Object) bool {
	return v.Distance < obj.(*Vertex).Distance
}

type Edge struct {
	Tail     *Vertex
	Head     *Vertex
	Length   int
	Explored bool
}

type Graph struct {
	vMap map[int]*Vertex
}

func NewGraph() *Graph {
	g := &Graph{}
	g.vMap = map[int]*Vertex{}
	return g
}

func (g *Graph) AddEdge(tailLabel int, headLabel int, len int) {
	tail := g.GetVertex(tailLabel)
	head := g.GetVertex(headLabel)
	edge := &Edge{tail, head, len, false}
	tail.OutgoingEdges = append(tail.OutgoingEdges, edge)
}

func (g *Graph) GetVertexCount() int {
	return len(g.vMap)
}

func (g *Graph) GetVertex(label int) *Vertex {
	key := label
	if _, exists := g.vMap[key]; !exists {
		g.vMap[key] = &Vertex{Label: label, OutgoingEdges: []*Edge{}}
	}
	return g.vMap[key]
}

func (g *Graph) GetVertices() []*Vertex {
	vertices := []*Vertex{}
	for label := range g.vMap {
		vertex := g.vMap[label]
		vertices = append(vertices, vertex)
	}
	return vertices
}

Dijkstra 알고리즘에서 사용하는 Heap은 중간 객체를 찾아서 삭제하거나 업데이트하는 인터페이스가 필요한데, Go언어에 내장된 container/heap 패키지에서 이를 제공하지 않아 직접 구현했다.

type Heap struct {
	objArr []Object
	objMap map[Object]int
}

내부 자료구조로 array를 사용해야만 Leaf 노드를 찾아 치환하거나 삭제하는 연산이 O(1)에 가능하다.

중간 객체를 삭제하기 위해선 먼저 트리 안에서 해당 객체를 찾아야 하는데, Binary Search Tree가 아니므로 최악의 경우 모든 노드를 다 찾아보아야 해서 실행시간은 최악의 경우 O(n)이 된다. 이렇게 되면 Heap을 사용해 실행시간을 O(m * n)에서 O(m * log n)으로 향상시키려는 시도가 무의미해진다.

객체의 위치를 O(1)에 찾을 수 있도록 Hash Table을 도입하여 객체가 추가되거나 삭제되거나 위치가 바뀔때마다 objMap을 업데이트하도록 하여, Delete(obj Object) bool의 실행시간을 O(log n)에 맞췄다.

heap.go의 전체 코드는 아래와 같다. 더 간결하게 짤 수도 있을텐데, 지금은 이정도가 내 수준인 것 같다. 좋은 코드를 많이 볼 필요가 있을 것 같다.

package heap

type Object interface {
	Less(obj Object) bool
}

type Heap struct {
	objArr []Object
	objMap map[Object]int
}

func NewHeap() *Heap {
	return &Heap{objArr: []Object{}, objMap: map[Object]int{}}
}

func (h *Heap) Push(obj Object) {
	idx := h.insertAsLeaf(obj)
	h.bubbleUp(idx)
}

func (h *Heap) Pop() Object {
	rootIdx := 0
	root := h.objArr[rootIdx]
	leafIdx := h.getLeafIndex()

	h.swap(rootIdx, leafIdx)
	h.deleteLeaf()
	if len(h.objArr) > 0 {
		h.bubbleDown(0)
	}

	return root
}

func (h *Heap) Delete(obj Object) bool {
	if len(h.objArr) <= 0 {
		return false
	}

	objectIdx := h.getObjectIndex(obj)
	if objectIdx == -1 {
		return false
	}
	leafIdx := h.getLeafIndex()

	h.swap(objectIdx, leafIdx)
	h.deleteLeaf()
	if objectIdx != leafIdx {
		h.bubbleDown(objectIdx)
	}

	return true
}

func (h *Heap) IsEmpty() bool {
	return len(h.objArr) == 0
}

func (h *Heap) getParentIndex(idx int) int {
	if idx == 0 {
		return -1
	}
	return (idx - 1) / 2
}

func (h *Heap) getLeftChildIndex(idx int) int {
	leftChildIdx := 2*(idx+1) - 1
	if len(h.objArr) <= leftChildIdx {
		return -1
	}
	return leftChildIdx
}

func (h *Heap) getRightChildIndex(idx int) int {
	rightChildIndex := 2*(idx+1) + 1 - 1
	if len(h.objArr) <= rightChildIndex {
		return -1
	}
	return rightChildIndex
}

func (h *Heap) getLeafIndex() int {
	return len(h.objArr) - 1
}

func (h *Heap) getObjectIndex(obj Object) int {
	idx, exists := h.objMap[obj]
	if !exists {
		return -1
	}
	return idx
}

func (h *Heap) insertAsLeaf(obj Object) int {
	h.objArr = append(h.objArr, obj)
	idx := len(h.objArr) - 1
	h.objMap[obj] = idx
	return idx
}

func (h *Heap) deleteLeaf() {
	leafIdx := h.getLeafIndex()
	leafObj := h.objArr[leafIdx]
	delete(h.objMap, leafObj)
	h.objArr = h.objArr[:leafIdx]
}

func (h *Heap) swap(idx1 int, idx2 int) {
	h.objArr[idx1], h.objArr[idx2] = h.objArr[idx2], h.objArr[idx1]
	h.objMap[h.objArr[idx1]] = idx1
	h.objMap[h.objArr[idx2]] = idx2
}

func (h *Heap) bubbleUp(idx int) {
	pIdx := h.getParentIndex(idx)
	if pIdx == -1 {
		return
	}
	pObj := h.objArr[pIdx]
	obj := h.objArr[idx]
	if obj.Less(pObj) {
		h.swap(pIdx, idx)
		h.bubbleUp(pIdx)
	}
}

func (h *Heap) bubbleDown(idx int) {
	node := h.objArr[idx]
	leftIdx := h.getLeftChildIndex(idx)
	leftLess := (leftIdx != -1 && h.objArr[leftIdx].Less(node))
	rightIdx := h.getRightChildIndex(idx)
	rightLess := (rightIdx != -1 && h.objArr[rightIdx].Less(node))

	if leftLess && rightLess {
		if h.objArr[leftIdx].Less(h.objArr[rightIdx]) {
			h.swap(idx, leftIdx)
			h.bubbleDown(leftIdx)
		} else {
			h.swap(idx, rightIdx)
			h.bubbleDown(rightIdx)
		}
	} else if leftLess {
		h.swap(idx, leftIdx)
		h.bubbleDown(leftIdx)
	} else if rightLess {
		h.swap(idx, rightIdx)
		h.bubbleDown(rightIdx)
	}
}

유닛테스트는 집요할 수록 좋다. 아래는 heap_test.go의 전체 코드다.

package heap

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

type testObject struct {
	key int
}

func (t testObject) Less(obj Object) bool {
	return t.key < obj.(testObject).key
}

func TestHeap(t *testing.T) {
	h := NewHeap()
	assert.True(t, h.IsEmpty())
	h.Push(testObject{11})
	h.Push(testObject{12})
	h.Push(testObject{13})
	h.Push(testObject{6})
	h.Push(testObject{8})
	h.Push(testObject{20})
	h.Push(testObject{15})
	h.Push(testObject{10})
	h.Push(testObject{5})
	h.Push(testObject{3})
	assert.False(t, h.IsEmpty())
	assert.False(t, h.Delete(testObject{100}))
	assert.Equal(t, 3, h.Pop().(testObject).key)
	assert.True(t, h.Delete(testObject{10}))
	assert.True(t, h.Delete(testObject{6}))
	assert.Equal(t, 5, h.Pop().(testObject).key)
	assert.Equal(t, 8, h.Pop().(testObject).key)
	assert.Equal(t, 11, h.Pop().(testObject).key)
	assert.Equal(t, 12, h.Pop().(testObject).key)
	assert.Equal(t, 13, h.Pop().(testObject).key)
	assert.Equal(t, 15, h.Pop().(testObject).key)
	assert.Equal(t, 20, h.Pop().(testObject).key)
	assert.False(t, h.Delete(testObject{20}))
	assert.True(t, h.IsEmpty())
}

소프트웨어를 설계하거나 코드를 작성할 때면, optimal solution을 가진 신이 등장해서 피드백을 주면 좋겠다는 생각을 하게 된다. 신은 아니더라도 나보다 나은 누군가가 이 코드를 보고 문제점, 개선점을 알려 주었으면 하는 바램이다.

댓글 남기기

댓글 남기기