package main
import (
"bytes"
"fmt"
"strconv"
)
// BinaryTreeNode is one node of a binary search tree holding an int64 key.
type BinaryTreeNode struct {
	// p links back to the parent node; it is nil for the tree root.
	p *BinaryTreeNode
	// left and right are the child subtrees; nil when a child is absent.
	left, right *BinaryTreeNode
	// value is the key stored in this node.
	value int64
}
// BinaryTree is a binary search tree of int64 values. Its zero value is an
// empty tree that is ready for use.
type BinaryTree struct {
	root *BinaryTreeNode // topmost node; nil while the tree is empty
}
// main builds a sample tree and prints the results of search, minimum,
// subtree replacement, removal, and predecessor/successor lookups.
func main() {
	separator := func() { fmt.Println("—-") }

	tree := &BinaryTree{}
	for _, value := range []int64{20, 13, 35, 8, 15, 28, 40, 4, 10, 14, 16, 30, 39, 2, 5, 9, 11, 29, 32, 38, 6} {
		tree.add(value)
	}
	fmt.Printf("Tree: %s \n", tree)
	separator()

	// find returns only the first match met while descending from the root.
	eight := tree.find(8)
	fmt.Printf("查找8: %s \n", eight)
	fmt.Printf("查找8子树的最小值: %s \n", eight.minimum())
	separator()

	// Graft the subtree rooted at 4 into the slot currently held by 13.
	thirteen := tree.find(13)
	four := tree.find(4)
	tree.transPlant(thirteen, four)
	fmt.Printf("子树替换 4 -> 13: %s \n", tree)
	separator()

	tree.remove(tree.find(20))
	fmt.Printf("移除 20: %s \n", tree)
	separator()

	// In-order predecessor and successor of 28.
	twentyEight := tree.find(28)
	fmt.Printf("查找 28 的前驱: %s \n", twentyEight.front())
	fmt.Printf("查找 28 的后继: %s \n", twentyEight.after())
}
// transPlant replaces the subtree rooted at n with the subtree rooted at r.
// It re-points n's parent (or the tree root when n has no parent) at r and,
// when r is non-nil, sets r's parent to n's parent. n's own p/left/right
// fields are left untouched. n must be non-nil; r may be nil, which simply
// detaches n's subtree.
func (bt *BinaryTree) transPlant(n *BinaryTreeNode, r *BinaryTreeNode) {
	parent := n.p
	switch {
	case parent == nil:
		bt.root = r
	case parent.left == n: // n hangs on the left: put r on the left
		parent.left = r
	default: // n hangs on the right: put r on the right
		parent.right = r
	}
	if r != nil {
		r.p = parent
	}
}
// String implements fmt.Stringer by rendering the tree's values in in-order
// (left subtree, node, right subtree) sequence, each followed by a space.
// An empty tree (nil receiver or nil root) renders as "" instead of
// dereferencing a nil root.
func (bt *BinaryTree) String() string {
	if bt == nil || bt.root == nil {
		return ""
	}
	return readNode(bt.root)
}
// readNode renders the subtree rooted at node in in-order sequence (left
// subtree, node, right subtree) as decimal values, each followed by a single
// space. A nil node renders as "" rather than panicking.
func readNode(node *BinaryTreeNode) string {
	return string(appendInOrder(nil, node))
}

// appendInOrder appends the in-order rendering of node's subtree to dst.
// It writes into one growing byte slice, so each value is copied only a few
// times instead of being re-concatenated at every recursion level.
func appendInOrder(dst []byte, node *BinaryTreeNode) []byte {
	if node == nil {
		return dst
	}
	dst = appendInOrder(dst, node.left)
	dst = strconv.AppendInt(dst, node.value, 10)
	dst = append(dst, ' ')
	return appendInOrder(dst, node.right)
}
// String implements fmt.Stringer by rendering the chain of values from n up
// through its parents to the root, as "v <- parent <- ... <- root".
// A nil node renders as "nil".
func (n *BinaryTreeNode) String() string {
	if n == nil {
		return "nil"
	}
	var buf bytes.Buffer
	for cur := n; cur != nil; cur = cur.p {
		format := "%d"
		if buf.Len() > 0 {
			format = " <- %d"
		}
		if _, err := fmt.Fprintf(&buf, format, cur.value); err != nil {
			return ""
		}
	}
	return buf.String()
}
// add inserts value as a new leaf. At each node, smaller values descend
// left and greater-or-equal values descend right, so duplicates are kept
// and placed in the right subtree of an existing equal value.
func (bt *BinaryTree) add(value int64) {
	leaf := &BinaryTreeNode{value: value}
	if bt.root == nil {
		bt.root = leaf
		return
	}
	parent := bt.root
	for {
		// slot is the child pointer we would follow (or fill) next.
		slot := &parent.left
		if value >= parent.value {
			slot = &parent.right
		}
		if *slot == nil {
			leaf.p = parent
			*slot = leaf
			return
		}
		parent = *slot
	}
}
// find returns the first node holding value encountered while descending
// from the root, or nil when no such node is reached (including an empty
// tree).
func (bt *BinaryTree) find(value int64) *BinaryTreeNode {
	cur := bt.root
	for cur != nil && cur.value != value {
		if value < cur.value {
			cur = cur.left
		} else {
			cur = cur.right
		}
	}
	return cur
}
// minimum returns the leftmost node of the subtree rooted at n, or nil when
// n is nil.
func (n *BinaryTreeNode) minimum() *BinaryTreeNode {
	if n == nil {
		return nil
	}
	leftmost := n
	for leftmost.left != nil {
		leftmost = leftmost.left
	}
	return leftmost
}
// front returns the in-order predecessor of n: the node visited immediately
// before n in an in-order traversal. It returns nil if n is nil or is the
// first node in that order.
//
// If n has a left subtree, the predecessor is that subtree's rightmost node.
// Otherwise it is the nearest ancestor whose right subtree contains n. The
// old implementation returned nil in this second case.
func (n *BinaryTreeNode) front() *BinaryTreeNode {
	if n == nil {
		return nil
	}
	if n.left != nil {
		node := n.left
		for node.right != nil {
			node = node.right
		}
		return node
	}
	// Climb while we are a left child. The first ancestor reached from its
	// right side precedes n. Running off the root means n was the minimum.
	child, parent := n, n.p
	for parent != nil && parent.left == child {
		child, parent = parent, parent.p
	}
	return parent
}
// after returns the in-order successor of n: the node visited immediately
// after n in an in-order traversal. It returns nil if n is nil or is the last
// node in that order.
//
// If n has a right subtree, the successor is that subtree's leftmost node.
// Otherwise it is the nearest ancestor whose left subtree contains n. This
// also covers a root with no right subtree, where the successor is nil.
func (n *BinaryTreeNode) after() *BinaryTreeNode {
	if n == nil {
		return nil
	}
	if n.right != nil {
		return n.right.minimum()
	}
	// Climb while we are a right child. The first ancestor reached from its
	// left side follows n. Running off the root means n was the maximum.
	child, parent := n, n.p
	for parent != nil && parent.right == child {
		child, parent = parent, parent.p
	}
	return parent
}
// remove unlinks node from the tree while keeping the search ordering.
// A nil node is ignored.
//   - With at most one child, node is replaced by that child (possibly nil).
//   - With two children, node is replaced by successor, the leftmost node of
//     its right subtree. If successor is not node's direct right child, it is
//     first spliced out (replaced by its own right child) and then adopts
//     node's right subtree. In both cases it then adopts node's left subtree.
//
// node's own p/left/right fields are not cleared.
func (bt *BinaryTree) remove(node *BinaryTreeNode) {
	switch {
	case node == nil:
		return
	case node.left == nil:
		bt.transPlant(node, node.right)
	case node.right == nil:
		bt.transPlant(node, node.left)
	default:
		successor := node.right.minimum()
		if successor.p != node {
			bt.transPlant(successor, successor.right)
			successor.right = node.right
			successor.right.p = successor
		}
		bt.transPlant(node, successor)
		successor.left = node.left
		successor.left.p = successor
	}
}