diff --git a/html/html.go b/html/html.go index deb592a..1a66a12 100644 --- a/html/html.go +++ b/html/html.go @@ -7,6 +7,8 @@ import ( type HtmlNode html.Node +type NodeComparer func(n *HtmlNode) bool + func (n *HtmlNode) GetAttribute(key string) ([]string, bool) { if n != nil { for _, attr := range n.Attr { @@ -18,6 +20,12 @@ func (n *HtmlNode) GetAttribute(key string) ([]string, bool) { return []string{}, false } +func (n *HtmlNode) HasAttributeVal(attr string, val string) NodeComparer { + return func(n *HtmlNode) bool { + return n.CheckAttribute(attr, val) + } +} + func (n *HtmlNode) CheckAttribute(attr string, val string) bool { if n != nil && n.Type == html.ElementNode { attrvals, ok := n.GetAttribute(attr) @@ -32,7 +40,13 @@ func (n *HtmlNode) CheckAttribute(attr string, val string) bool { return false } -func (n *HtmlNode) Find(f func(n *HtmlNode) bool) *HtmlNode { +func (f NodeComparer) And(f2 NodeComparer) NodeComparer { + return func(n *HtmlNode) bool { + return f(n) && f2(n) + } +} + +func (n *HtmlNode) Find(f NodeComparer) *HtmlNode { if n != nil { if f(n) { return n @@ -47,7 +61,7 @@ func (n *HtmlNode) Find(f func(n *HtmlNode) bool) *HtmlNode { return nil } -func (n *HtmlNode) FindAll(f func(n *HtmlNode) bool) []*HtmlNode { +func (n *HtmlNode) FindAll(f NodeComparer) []*HtmlNode { all := []*HtmlNode{} if n != nil { if f(n) { @@ -63,9 +77,25 @@ func (n *HtmlNode) FindAll(f func(n *HtmlNode) bool) []*HtmlNode { return all } +func IsText() NodeComparer { + return func(n *HtmlNode) bool { return n.Type == html.TextNode } +} + +func IsTag(name string) NodeComparer { + return func(n *HtmlNode) bool { return n.Type == html.ElementNode && n.Data == name } +} + +func HasClass(name string) NodeComparer { + return func(n *HtmlNode) bool { return n.CheckAttribute("class", name) } +} + +func HasID(id string) NodeComparer { + return func(n *HtmlNode) bool { return n.CheckAttribute("id", id) } +} + func (n *HtmlNode) GetElementById(id string) *HtmlNode { if n != nil { - return n.Find(func(n *HtmlNode) bool { return n.CheckAttribute("id", id) }) + return n.Find(HasID(id)) } else { return nil } @@ -73,7 +103,7 @@ func (n *HtmlNode) GetElementById(id string) *HtmlNode { func (n *HtmlNode) GetElementsByClass(class string) []*HtmlNode { if n != nil { - return n.FindAll(func(n *HtmlNode) bool { return n.CheckAttribute("class", class) }) + return n.FindAll(HasClass(class)) } else { return nil }