Go언어로 구현된 코드와 한글로 작성된 설명을 인터넷에서 찾기 힘들기에, Go언어를 공부하고 있는 누군가에게는 약간의 도움이 되길 바라며, Graph Search, Shortest Paths, and Data Structures 강의에서 프로그래밍 숙제로 구현했던 Dijkstra 알고리즘 코드를 공유해본다.
수준이 낮은 코드를 인터넷에 공유하는 것은 부끄러운 일이지만, 이정도 코드밖에 만들 수 없는 것이 현재 나의 실력이라는 것을 인정해야만 더 높은 수준으로 나아갈 수 있다고 믿기에, 앞으로도 기회가 닿는대로 Go언어로 작성된 코드와 시행착오를 공유해보고자 한다.
저장소: http://gogs.reshout.com/reshout/dijkstra-heap
소스트리 구조는 아래와 같다.
├── dijkstraData.txt
├── graph
│ ├── graph.go
│ └── graph_test.go
├── heap
│ ├── heap.go
│ └── heap_test.go
└── main.go
graph
패지와 heap
패키지는 알고리즘 구현에 필요한 자료구조를 별도의 패키지로 구현한 것인데, 유닛테스트 코드를 포함하며 100%에 가까운 커버리지를 확보하였다. 유닛테스트를 통해 확보한 자료구조의 완성도 덕분에 이를 활용한 알고리즘 코드를 완성한 후에는 한 번에 원하는 답을 얻을 수 있었다. Go언어는 소스코드와 같은 폴더에 유닛테스트 코드(e.g., graph_test.go
)를 작성할 수 있어서 유닛테스트를 작성하는 습관을 들이는데 도움이 되는 것 같다.
$ go test -cover ./...
? dijkstra-heap [no test files]
ok dijkstra-heap/graph (cached) coverage: 94.4% of statements
ok dijkstra-heap/heap (cached) coverage: 97.3% of statements
main.go
에 구현된 Dijkstra 알고리즘은 Dasgupta가 쓴 알고리즘 책 115쪽에 있는 pseudocode를 약간 변형한 것으로 다음과 같다.
func dijkstra(g *graph.Graph) {
for _, u := range g.GetVertices() {
u.Distance = maxInt
}
g.GetVertex(1).Distance = 0
h := heap.NewHeap()
for _, u := range g.GetVertices() {
h.Push(u)
}
for !h.IsEmpty() {
u := h.Pop().(*graph.Vertex)
for _, edge := range u.OutgoingEdges {
v := edge.Head
if v.Distance > u.Distance+edge.Length {
v.Distance = u.Distance + edge.Length
h.Delete(v)
h.Push(v)
}
}
}
}
heap
패키지는 다른 프로젝트에서도 사용할 수 있도록 범용적인 인터페이스를 제공한다. 아래와 같이 Less
함수를 구현한 어떤 객체도 Push(obj Object)
, Pop() Object
할 수 있도록 하였다.
package heap
type Object interface {
Less(obj Object) bool
}
graph
패키지는 heap
패키지에 의존적이다. Vertex
를 heap
에 넣어놓고 Distance
기준으로 최소값을 추출해야 하기 때문이다.
package graph
import "dijkstra-heap/heap"
type Vertex struct {
Label int
Distance int
OutgoingEdges []*Edge
}
func (v *Vertex) Less(obj heap.Object) bool {
return v.Distance < obj.(*Vertex).Distance
}
graph
패키지는 이 프로젝트 전용이다. 알고리즘에서 사용할 연산에 최적화된 실행시간을 제공하려면 내부 자료구조나 내부 알고리즘도 달라질 수 밖에 없다.
그래프를 구성하는 Vertex
들과 Edge
들을 저장하는 다양한 방법이 존재할 수 있다. Dijkstra 알고리즘에서는 Heap
에서 꺼낸 Vertex
에서 밖으로 연결된 Edge
들을 탐색하는 경우가 많으므로 Vertex
가 OutgoingEdges
를 갖도록 하였다.
graph.go
의 전체 코드는 아래와 같다.
package graph
import "dijkstra-heap/heap"
type Vertex struct {
Label int
Distance int
OutgoingEdges []*Edge
}
func (v *Vertex) Less(obj heap.Object) bool {
return v.Distance < obj.(*Vertex).Distance
}
type Edge struct {
Tail *Vertex
Head *Vertex
Length int
Explored bool
}
type Graph struct {
vMap map[int]*Vertex
}
func NewGraph() *Graph {
g := &Graph{}
g.vMap = map[int]*Vertex{}
return g
}
func (g *Graph) AddEdge(tailLabel int, headLabel int, len int) {
tail := g.GetVertex(tailLabel)
head := g.GetVertex(headLabel)
edge := &Edge{tail, head, len, false}
tail.OutgoingEdges = append(tail.OutgoingEdges, edge)
}
func (g *Graph) GetVertexCount() int {
return len(g.vMap)
}
func (g *Graph) GetVertex(label int) *Vertex {
key := label
if _, exists := g.vMap[key]; !exists {
g.vMap[key] = &Vertex{Label: label, OutgoingEdges: []*Edge{}}
}
return g.vMap[key]
}
func (g *Graph) GetVertices() []*Vertex {
vertices := []*Vertex{}
for label := range g.vMap {
vertex := g.vMap[label]
vertices = append(vertices, vertex)
}
return vertices
}
Dijkstra 알고리즘에서 사용하는 Heap은 중간 객체를 찾아서 삭제하거나 업데이트하는 인터페이스가 필요한데, Go언어에 내장된 container/heap
패키지에서 이를 제공하지 않아 직접 구현했다.
type Heap struct {
objArr []Object
objMap map[Object]int
}
내부 자료구조로 array를 사용해야만 Leaf 노드를 찾아 치환하거나 삭제하는 연산이 O(1)에 가능하다.
중간 객체를 삭제하기 위해선 먼저 트리 안에서 해당 객체를 찾아야 하는데, Binary Search Tree가 아니므로 최악의 경우 모든 노드를 다 찾아보아야 해서 실행시간은 최악의 경우 O(n)이 된다. 이렇게 되면 Heap을 사용해 실행시간을 O(m * n)에서 O(m * log n)으로 향상시키려는 시도가 무의미해진다.
객체의 위치를 O(1)에 찾을 수 있도록 Hash Table을 도입하여 객체가 추가되거나 삭제되거나 위치가 바뀔때마다 objMap
을 업데이트하도록 하여, Delete(obj Object) bool
의 실행시간을 O(log n)에 맞췄다.
heap.go
의 전체 코드는 아래와 같다. 더 간결하게 짤 수도 있을텐데, 지금은 이정도가 내 수준인 것 같다. 좋은 코드를 많이 볼 필요가 있을 것 같다.
package heap
type Object interface {
Less(obj Object) bool
}
type Heap struct {
objArr []Object
objMap map[Object]int
}
func NewHeap() *Heap {
return &Heap{objArr: []Object{}, objMap: map[Object]int{}}
}
func (h *Heap) Push(obj Object) {
idx := h.insertAsLeaf(obj)
h.bubbleUp(idx)
}
func (h *Heap) Pop() Object {
rootIdx := 0
root := h.objArr[rootIdx]
leafIdx := h.getLeafIndex()
h.swap(rootIdx, leafIdx)
h.deleteLeaf()
if len(h.objArr) > 0 {
h.bubbleDown(0)
}
return root
}
func (h *Heap) Delete(obj Object) bool {
if len(h.objArr) <= 0 {
return false
}
objectIdx := h.getObjectIndex(obj)
if objectIdx == -1 {
return false
}
leafIdx := h.getLeafIndex()
h.swap(objectIdx, leafIdx)
h.deleteLeaf()
if objectIdx != leafIdx {
h.bubbleDown(objectIdx)
}
return true
}
func (h *Heap) IsEmpty() bool {
return len(h.objArr) == 0
}
func (h *Heap) getParentIndex(idx int) int {
if idx == 0 {
return -1
}
return (idx - 1) / 2
}
func (h *Heap) getLeftChildIndex(idx int) int {
leftChildIdx := 2*(idx+1) - 1
if len(h.objArr) <= leftChildIdx {
return -1
}
return leftChildIdx
}
func (h *Heap) getRightChildIndex(idx int) int {
rightChildIndex := 2*(idx+1) + 1 - 1
if len(h.objArr) <= rightChildIndex {
return -1
}
return rightChildIndex
}
func (h *Heap) getLeafIndex() int {
return len(h.objArr) - 1
}
func (h *Heap) getObjectIndex(obj Object) int {
idx, exists := h.objMap[obj]
if !exists {
return -1
}
return idx
}
func (h *Heap) insertAsLeaf(obj Object) int {
h.objArr = append(h.objArr, obj)
idx := len(h.objArr) - 1
h.objMap[obj] = idx
return idx
}
func (h *Heap) deleteLeaf() {
leafIdx := h.getLeafIndex()
leafObj := h.objArr[leafIdx]
delete(h.objMap, leafObj)
h.objArr = h.objArr[:leafIdx]
}
func (h *Heap) swap(idx1 int, idx2 int) {
h.objArr[idx1], h.objArr[idx2] = h.objArr[idx2], h.objArr[idx1]
h.objMap[h.objArr[idx1]] = idx1
h.objMap[h.objArr[idx2]] = idx2
}
func (h *Heap) bubbleUp(idx int) {
pIdx := h.getParentIndex(idx)
if pIdx == -1 {
return
}
pObj := h.objArr[pIdx]
obj := h.objArr[idx]
if obj.Less(pObj) {
h.swap(pIdx, idx)
h.bubbleUp(pIdx)
}
}
func (h *Heap) bubbleDown(idx int) {
node := h.objArr[idx]
leftIdx := h.getLeftChildIndex(idx)
leftLess := (leftIdx != -1 && h.objArr[leftIdx].Less(node))
rightIdx := h.getRightChildIndex(idx)
rightLess := (rightIdx != -1 && h.objArr[rightIdx].Less(node))
if leftLess && rightLess {
if h.objArr[leftIdx].Less(h.objArr[rightIdx]) {
h.swap(idx, leftIdx)
h.bubbleDown(leftIdx)
} else {
h.swap(idx, rightIdx)
h.bubbleDown(rightIdx)
}
} else if leftLess {
h.swap(idx, leftIdx)
h.bubbleDown(leftIdx)
} else if rightLess {
h.swap(idx, rightIdx)
h.bubbleDown(rightIdx)
}
}
유닛테스트는 집요할 수록 좋다. 아래는 heap_test.go
의 전체 코드다.
package heap
import (
"testing"
"github.com/stretchr/testify/assert"
)
type testObject struct {
key int
}
func (t testObject) Less(obj Object) bool {
return t.key < obj.(testObject).key
}
func TestHeap(t *testing.T) {
h := NewHeap()
assert.True(t, h.IsEmpty())
h.Push(testObject{11})
h.Push(testObject{12})
h.Push(testObject{13})
h.Push(testObject{6})
h.Push(testObject{8})
h.Push(testObject{20})
h.Push(testObject{15})
h.Push(testObject{10})
h.Push(testObject{5})
h.Push(testObject{3})
assert.False(t, h.IsEmpty())
assert.False(t, h.Delete(testObject{100}))
assert.Equal(t, 3, h.Pop().(testObject).key)
assert.True(t, h.Delete(testObject{10}))
assert.True(t, h.Delete(testObject{6}))
assert.Equal(t, 5, h.Pop().(testObject).key)
assert.Equal(t, 8, h.Pop().(testObject).key)
assert.Equal(t, 11, h.Pop().(testObject).key)
assert.Equal(t, 12, h.Pop().(testObject).key)
assert.Equal(t, 13, h.Pop().(testObject).key)
assert.Equal(t, 15, h.Pop().(testObject).key)
assert.Equal(t, 20, h.Pop().(testObject).key)
assert.False(t, h.Delete(testObject{20}))
assert.True(t, h.IsEmpty())
}
소프트웨어를 설계하거나 코드를 작성할 때면, optimal solution을 가진 신이 등장해서 피드백을 주면 좋겠다는 생각을 하게 된다. 신은 아니더라도 나보다 나은 누군가가 이 코드를 보고 문제점, 개선점을 알려 주었으면 하는 바램이다.