1/*
2 * Copyright (C) 2005 Frerich Raabe <raabe@kde.org>
3 * Copyright (C) 2006, 2009 Apple Inc.
4 * Copyright (C) 2007 Alexey Proskuryakov <ap@webkit.org>
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions
8 * are met:
9 *
10 * 1. Redistributions of source code must retain the above copyright
11 *    notice, this list of conditions and the following disclaimer.
12 * 2. Redistributions in binary form must reproduce the above copyright
13 *    notice, this list of conditions and the following disclaimer in the
14 *    documentation and/or other materials provided with the distribution.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
17 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
18 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
19 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
20 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
21 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
22 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
23 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
25 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
28#include "config.h"
29#include "XPathFunctions.h"
30
31#include "Element.h"
32#include "ProcessingInstruction.h"
33#include "TreeScope.h"
34#include "XMLNames.h"
35#include "XPathUtil.h"
36#include "XPathValue.h"
37#include <wtf/MathExtras.h>
38#include <wtf/text/StringBuilder.h>
39
40namespace WebCore {
41namespace XPath {
42
43static inline bool isWhitespace(UChar c)
44{
45    return c == ' ' || c == '\n' || c == '\r' || c == '\t';
46}
47
48
// Defines a static factory create<Class>() that heap-allocates a default-constructed Class and
// returns it as a Function*; whoever calls the factory takes ownership of the result.
#define DEFINE_FUNCTION_CREATOR(Class) \
    static Function* create##Class() \
    { \
        return new Class; \
    }
50
// An inclusive range of permitted argument counts. Either bound may be Inf, meaning that side is
// unbounded; a default-constructed Interval therefore accepts any count.
class Interval {
public:
    // Sentinel for "no bound on this side".
    static const int Inf = -1;

    Interval();
    // Intentionally not explicit: the function table in createFunctionMap() writes a bare count (e.g. 1)
    // where an Interval is expected.
    Interval(int exactValue);
    Interval(int lowest, int highest);

    bool contains(int value) const;

private:
    int m_min;
    int m_max;
};
65
66struct FunctionRec {
67    typedef Function *(*FactoryFn)();
68    FactoryFn factoryFn;
69    Interval args;
70};
71
72static HashMap<String, FunctionRec>* functionMap;
73
74class FunLast : public Function {
75    virtual Value evaluate() const;
76    virtual Value::Type resultType() const { return Value::NumberValue; }
77public:
78    FunLast() { setIsContextSizeSensitive(true); }
79};
80
81class FunPosition : public Function {
82    virtual Value evaluate() const;
83    virtual Value::Type resultType() const { return Value::NumberValue; }
84public:
85    FunPosition() { setIsContextPositionSensitive(true); }
86};
87
88class FunCount : public Function {
89    virtual Value evaluate() const;
90    virtual Value::Type resultType() const { return Value::NumberValue; }
91};
92
93class FunId : public Function {
94    virtual Value evaluate() const;
95    virtual Value::Type resultType() const { return Value::NodeSetValue; }
96};
97
98class FunLocalName : public Function {
99    virtual Value evaluate() const;
100    virtual Value::Type resultType() const { return Value::StringValue; }
101public:
102    FunLocalName() { setIsContextNodeSensitive(true); } // local-name() with no arguments uses context node.
103};
104
105class FunNamespaceURI : public Function {
106    virtual Value evaluate() const;
107    virtual Value::Type resultType() const { return Value::StringValue; }
108public:
109    FunNamespaceURI() { setIsContextNodeSensitive(true); } // namespace-uri() with no arguments uses context node.
110};
111
112class FunName : public Function {
113    virtual Value evaluate() const;
114    virtual Value::Type resultType() const { return Value::StringValue; }
115public:
116    FunName() { setIsContextNodeSensitive(true); } // name() with no arguments uses context node.
117};
118
119class FunString : public Function {
120    virtual Value evaluate() const;
121    virtual Value::Type resultType() const { return Value::StringValue; }
122public:
123    FunString() { setIsContextNodeSensitive(true); } // string() with no arguments uses context node.
124};
125
126class FunConcat : public Function {
127    virtual Value evaluate() const;
128    virtual Value::Type resultType() const { return Value::StringValue; }
129};
130
131class FunStartsWith : public Function {
132    virtual Value evaluate() const;
133    virtual Value::Type resultType() const { return Value::BooleanValue; }
134};
135
136class FunContains : public Function {
137    virtual Value evaluate() const;
138    virtual Value::Type resultType() const { return Value::BooleanValue; }
139};
140
141class FunSubstringBefore : public Function {
142    virtual Value evaluate() const;
143    virtual Value::Type resultType() const { return Value::StringValue; }
144};
145
146class FunSubstringAfter : public Function {
147    virtual Value evaluate() const;
148    virtual Value::Type resultType() const { return Value::StringValue; }
149};
150
151class FunSubstring : public Function {
152    virtual Value evaluate() const;
153    virtual Value::Type resultType() const { return Value::StringValue; }
154};
155
156class FunStringLength : public Function {
157    virtual Value evaluate() const;
158    virtual Value::Type resultType() const { return Value::NumberValue; }
159public:
160    FunStringLength() { setIsContextNodeSensitive(true); } // string-length() with no arguments uses context node.
161};
162
163class FunNormalizeSpace : public Function {
164    virtual Value evaluate() const;
165    virtual Value::Type resultType() const { return Value::StringValue; }
166public:
167    FunNormalizeSpace() { setIsContextNodeSensitive(true); } // normalize-space() with no arguments uses context node.
168};
169
170class FunTranslate : public Function {
171    virtual Value evaluate() const;
172    virtual Value::Type resultType() const { return Value::StringValue; }
173};
174
175class FunBoolean : public Function {
176    virtual Value evaluate() const;
177    virtual Value::Type resultType() const { return Value::BooleanValue; }
178};
179
180class FunNot : public Function {
181    virtual Value evaluate() const;
182    virtual Value::Type resultType() const { return Value::BooleanValue; }
183};
184
185class FunTrue : public Function {
186    virtual Value evaluate() const;
187    virtual Value::Type resultType() const { return Value::BooleanValue; }
188};
189
190class FunFalse : public Function {
191    virtual Value evaluate() const;
192    virtual Value::Type resultType() const { return Value::BooleanValue; }
193};
194
195class FunLang : public Function {
196    virtual Value evaluate() const;
197    virtual Value::Type resultType() const { return Value::BooleanValue; }
198public:
199    FunLang() { setIsContextNodeSensitive(true); } // lang() always works on context node.
200};
201
202class FunNumber : public Function {
203    virtual Value evaluate() const;
204    virtual Value::Type resultType() const { return Value::NumberValue; }
205public:
206    FunNumber() { setIsContextNodeSensitive(true); } // number() with no arguments uses context node.
207};
208
209class FunSum : public Function {
210    virtual Value evaluate() const;
211    virtual Value::Type resultType() const { return Value::NumberValue; }
212};
213
214class FunFloor : public Function {
215    virtual Value evaluate() const;
216    virtual Value::Type resultType() const { return Value::NumberValue; }
217};
218
219class FunCeiling : public Function {
220    virtual Value evaluate() const;
221    virtual Value::Type resultType() const { return Value::NumberValue; }
222};
223
224class FunRound : public Function {
225    virtual Value evaluate() const;
226    virtual Value::Type resultType() const { return Value::NumberValue; }
227public:
228    static double round(double);
229};
230
// One create<Class>() factory per core library function; createFunctionMap() refers to each by address.

// Node-set functions.
DEFINE_FUNCTION_CREATOR(FunLast)
DEFINE_FUNCTION_CREATOR(FunPosition)
DEFINE_FUNCTION_CREATOR(FunCount)
DEFINE_FUNCTION_CREATOR(FunId)
DEFINE_FUNCTION_CREATOR(FunLocalName)
DEFINE_FUNCTION_CREATOR(FunNamespaceURI)
DEFINE_FUNCTION_CREATOR(FunName)

// String functions.
DEFINE_FUNCTION_CREATOR(FunString)
DEFINE_FUNCTION_CREATOR(FunConcat)
DEFINE_FUNCTION_CREATOR(FunStartsWith)
DEFINE_FUNCTION_CREATOR(FunContains)
DEFINE_FUNCTION_CREATOR(FunSubstringBefore)
DEFINE_FUNCTION_CREATOR(FunSubstringAfter)
DEFINE_FUNCTION_CREATOR(FunSubstring)
DEFINE_FUNCTION_CREATOR(FunStringLength)
DEFINE_FUNCTION_CREATOR(FunNormalizeSpace)
DEFINE_FUNCTION_CREATOR(FunTranslate)

// Boolean functions.
DEFINE_FUNCTION_CREATOR(FunBoolean)
DEFINE_FUNCTION_CREATOR(FunNot)
DEFINE_FUNCTION_CREATOR(FunTrue)
DEFINE_FUNCTION_CREATOR(FunFalse)
DEFINE_FUNCTION_CREATOR(FunLang)

// Number functions.
DEFINE_FUNCTION_CREATOR(FunNumber)
DEFINE_FUNCTION_CREATOR(FunSum)
DEFINE_FUNCTION_CREATOR(FunFloor)
DEFINE_FUNCTION_CREATOR(FunCeiling)
DEFINE_FUNCTION_CREATOR(FunRound)

// The macro is only needed for the definitions above.
#undef DEFINE_FUNCTION_CREATOR
263
264inline Interval::Interval()
265    : m_min(Inf), m_max(Inf)
266{
267}
268
269inline Interval::Interval(int value)
270    : m_min(value), m_max(value)
271{
272}
273
274inline Interval::Interval(int min, int max)
275    : m_min(min), m_max(max)
276{
277}
278
279inline bool Interval::contains(int value) const
280{
281    if (m_min == Inf && m_max == Inf)
282        return true;
283
284    if (m_min == Inf)
285        return value <= m_max;
286
287    if (m_max == Inf)
288        return value >= m_min;
289
290    return value >= m_min && value <= m_max;
291}
292
293void Function::setArguments(const Vector<Expression*>& args)
294{
295    ASSERT(!subExprCount());
296
297    // Some functions use context node as implicit argument, so when explicit arguments are added, they may no longer be context node sensitive.
298    if (m_name != "lang" && !args.isEmpty())
299        setIsContextNodeSensitive(false);
300
301    Vector<Expression*>::const_iterator end = args.end();
302    for (Vector<Expression*>::const_iterator it = args.begin(); it != end; ++it)
303        addSubExpression(*it);
304}
305
306Value FunLast::evaluate() const
307{
308    return Expression::evaluationContext().size;
309}
310
311Value FunPosition::evaluate() const
312{
313    return Expression::evaluationContext().position;
314}
315
316Value FunId::evaluate() const
317{
318    Value a = arg(0)->evaluate();
319    StringBuilder idList; // A whitespace-separated list of IDs
320
321    if (a.isNodeSet()) {
322        const NodeSet& nodes = a.toNodeSet();
323        for (size_t i = 0; i < nodes.size(); ++i) {
324            String str = stringValue(nodes[i]);
325            idList.append(str);
326            idList.append(' ');
327        }
328    } else {
329        String str = a.toString();
330        idList.append(str);
331    }
332
333    TreeScope* contextScope = evaluationContext().node->treeScope();
334    NodeSet result;
335    HashSet<Node*> resultSet;
336
337    unsigned startPos = 0;
338    unsigned length = idList.length();
339    while (true) {
340        while (startPos < length && isWhitespace(idList[startPos]))
341            ++startPos;
342
343        if (startPos == length)
344            break;
345
346        size_t endPos = startPos;
347        while (endPos < length && !isWhitespace(idList[endPos]))
348            ++endPos;
349
350        // If there are several nodes with the same id, id() should return the first one.
351        // In WebKit, getElementById behaves so, too, although its behavior in this case is formally undefined.
352        Node* node = contextScope->getElementById(String(idList.characters() + startPos, endPos - startPos));
353        if (node && resultSet.add(node).isNewEntry)
354            result.append(node);
355
356        startPos = endPos;
357    }
358
359    result.markSorted(false);
360
361    return Value(result, Value::adopt);
362}
363
364static inline String expandedNameLocalPart(Node* node)
365{
366    // The local part of an XPath expanded-name matches DOM local name for most node types, except for namespace nodes and processing instruction nodes.
367    ASSERT(node->nodeType() != Node::XPATH_NAMESPACE_NODE); // Not supported yet.
368    if (node->nodeType() == Node::PROCESSING_INSTRUCTION_NODE)
369        return static_cast<ProcessingInstruction*>(node)->target();
370    return node->localName().string();
371}
372
373static inline String expandedName(Node* node)
374{
375    const AtomicString& prefix = node->prefix();
376    return prefix.isEmpty() ? expandedNameLocalPart(node) : prefix + ":" + expandedNameLocalPart(node);
377}
378
379Value FunLocalName::evaluate() const
380{
381    if (argCount() > 0) {
382        Value a = arg(0)->evaluate();
383        if (!a.isNodeSet())
384            return "";
385
386        Node* node = a.toNodeSet().firstNode();
387        return node ? expandedNameLocalPart(node) : "";
388    }
389
390    return expandedNameLocalPart(evaluationContext().node.get());
391}
392
393Value FunNamespaceURI::evaluate() const
394{
395    if (argCount() > 0) {
396        Value a = arg(0)->evaluate();
397        if (!a.isNodeSet())
398            return "";
399
400        Node* node = a.toNodeSet().firstNode();
401        return node ? node->namespaceURI().string() : "";
402    }
403
404    return evaluationContext().node->namespaceURI().string();
405}
406
407Value FunName::evaluate() const
408{
409    if (argCount() > 0) {
410        Value a = arg(0)->evaluate();
411        if (!a.isNodeSet())
412            return "";
413
414        Node* node = a.toNodeSet().firstNode();
415        return node ? expandedName(node) : "";
416    }
417
418    return expandedName(evaluationContext().node.get());
419}
420
421Value FunCount::evaluate() const
422{
423    Value a = arg(0)->evaluate();
424
425    return double(a.toNodeSet().size());
426}
427
428Value FunString::evaluate() const
429{
430    if (!argCount())
431        return Value(Expression::evaluationContext().node.get()).toString();
432    return arg(0)->evaluate().toString();
433}
434
435Value FunConcat::evaluate() const
436{
437    StringBuilder result;
438    result.reserveCapacity(1024);
439
440    unsigned count = argCount();
441    for (unsigned i = 0; i < count; ++i) {
442        String str(arg(i)->evaluate().toString());
443        result.append(str);
444    }
445
446    return result.toString();
447}
448
449Value FunStartsWith::evaluate() const
450{
451    String s1 = arg(0)->evaluate().toString();
452    String s2 = arg(1)->evaluate().toString();
453
454    if (s2.isEmpty())
455        return true;
456
457    return s1.startsWith(s2);
458}
459
460Value FunContains::evaluate() const
461{
462    String s1 = arg(0)->evaluate().toString();
463    String s2 = arg(1)->evaluate().toString();
464
465    if (s2.isEmpty())
466        return true;
467
468    return s1.contains(s2) != 0;
469}
470
471Value FunSubstringBefore::evaluate() const
472{
473    String s1 = arg(0)->evaluate().toString();
474    String s2 = arg(1)->evaluate().toString();
475
476    if (s2.isEmpty())
477        return "";
478
479    size_t i = s1.find(s2);
480
481    if (i == notFound)
482        return "";
483
484    return s1.left(i);
485}
486
487Value FunSubstringAfter::evaluate() const
488{
489    String s1 = arg(0)->evaluate().toString();
490    String s2 = arg(1)->evaluate().toString();
491
492    size_t i = s1.find(s2);
493    if (i == notFound)
494        return "";
495
496    return s1.substring(i + s2.length());
497}
498
499Value FunSubstring::evaluate() const
500{
501    String s = arg(0)->evaluate().toString();
502    double doublePos = arg(1)->evaluate().toNumber();
503    if (std::isnan(doublePos))
504        return "";
505    long pos = static_cast<long>(FunRound::round(doublePos));
506    bool haveLength = argCount() == 3;
507    long len = -1;
508    if (haveLength) {
509        double doubleLen = arg(2)->evaluate().toNumber();
510        if (std::isnan(doubleLen))
511            return "";
512        len = static_cast<long>(FunRound::round(doubleLen));
513    }
514
515    if (pos > long(s.length()))
516        return "";
517
518    if (pos < 1) {
519        if (haveLength) {
520            len -= 1 - pos;
521            if (len < 1)
522                return "";
523        }
524        pos = 1;
525    }
526
527    return s.substring(pos - 1, len);
528}
529
530Value FunStringLength::evaluate() const
531{
532    if (!argCount())
533        return Value(Expression::evaluationContext().node.get()).toString().length();
534    return arg(0)->evaluate().toString().length();
535}
536
537Value FunNormalizeSpace::evaluate() const
538{
539    if (!argCount()) {
540        String s = Value(Expression::evaluationContext().node.get()).toString();
541        return s.simplifyWhiteSpace();
542    }
543
544    String s = arg(0)->evaluate().toString();
545    return s.simplifyWhiteSpace();
546}
547
548Value FunTranslate::evaluate() const
549{
550    String s1 = arg(0)->evaluate().toString();
551    String s2 = arg(1)->evaluate().toString();
552    String s3 = arg(2)->evaluate().toString();
553    StringBuilder result;
554
555    for (unsigned i1 = 0; i1 < s1.length(); ++i1) {
556        UChar ch = s1[i1];
557        size_t i2 = s2.find(ch);
558
559        if (i2 == notFound)
560            result.append(ch);
561        else if (i2 < s3.length())
562            result.append(s3[i2]);
563    }
564
565    return result.toString();
566}
567
568Value FunBoolean::evaluate() const
569{
570    return arg(0)->evaluate().toBoolean();
571}
572
573Value FunNot::evaluate() const
574{
575    return !arg(0)->evaluate().toBoolean();
576}
577
578Value FunTrue::evaluate() const
579{
580    return true;
581}
582
583Value FunLang::evaluate() const
584{
585    String lang = arg(0)->evaluate().toString();
586
587    const Attribute* languageAttribute = 0;
588    Node* node = evaluationContext().node.get();
589    while (node) {
590        if (node->isElementNode()) {
591            Element* element = toElement(node);
592            if (element->hasAttributes())
593                languageAttribute = element->getAttributeItem(XMLNames::langAttr);
594        }
595        if (languageAttribute)
596            break;
597        node = node->parentNode();
598    }
599
600    if (!languageAttribute)
601        return false;
602
603    String langValue = languageAttribute->value();
604    while (true) {
605        if (equalIgnoringCase(langValue, lang))
606            return true;
607
608        // Remove suffixes one by one.
609        size_t index = langValue.reverseFind('-');
610        if (index == notFound)
611            break;
612        langValue = langValue.left(index);
613    }
614
615    return false;
616}
617
618Value FunFalse::evaluate() const
619{
620    return false;
621}
622
623Value FunNumber::evaluate() const
624{
625    if (!argCount())
626        return Value(Expression::evaluationContext().node.get()).toNumber();
627    return arg(0)->evaluate().toNumber();
628}
629
630Value FunSum::evaluate() const
631{
632    Value a = arg(0)->evaluate();
633    if (!a.isNodeSet())
634        return 0.0;
635
636    double sum = 0.0;
637    const NodeSet& nodes = a.toNodeSet();
638    // To be really compliant, we should sort the node-set, as floating point addition is not associative.
639    // However, this is unlikely to ever become a practical issue, and sorting is slow.
640
641    for (unsigned i = 0; i < nodes.size(); i++)
642        sum += Value(stringValue(nodes[i])).toNumber();
643
644    return sum;
645}
646
647Value FunFloor::evaluate() const
648{
649    return floor(arg(0)->evaluate().toNumber());
650}
651
652Value FunCeiling::evaluate() const
653{
654    return ceil(arg(0)->evaluate().toNumber());
655}
656
657double FunRound::round(double val)
658{
659    if (!std::isnan(val) && !std::isinf(val)) {
660        if (std::signbit(val) && val >= -0.5)
661            val *= 0; // negative zero
662        else
663            val = floor(val + 0.5);
664    }
665    return val;
666}
667
668Value FunRound::evaluate() const
669{
670    return round(arg(0)->evaluate().toNumber());
671}
672
673struct FunctionMapping {
674    const char* name;
675    FunctionRec function;
676};
677
678static void createFunctionMap()
679{
680    static const FunctionMapping functions[] = {
681        { "boolean", { &createFunBoolean, 1 } },
682        { "ceiling", { &createFunCeiling, 1 } },
683        { "concat", { &createFunConcat, Interval(2, Interval::Inf) } },
684        { "contains", { &createFunContains, 2 } },
685        { "count", { &createFunCount, 1 } },
686        { "false", { &createFunFalse, 0 } },
687        { "floor", { &createFunFloor, 1 } },
688        { "id", { &createFunId, 1 } },
689        { "lang", { &createFunLang, 1 } },
690        { "last", { &createFunLast, 0 } },
691        { "local-name", { &createFunLocalName, Interval(0, 1) } },
692        { "name", { &createFunName, Interval(0, 1) } },
693        { "namespace-uri", { &createFunNamespaceURI, Interval(0, 1) } },
694        { "normalize-space", { &createFunNormalizeSpace, Interval(0, 1) } },
695        { "not", { &createFunNot, 1 } },
696        { "number", { &createFunNumber, Interval(0, 1) } },
697        { "position", { &createFunPosition, 0 } },
698        { "round", { &createFunRound, 1 } },
699        { "starts-with", { &createFunStartsWith, 2 } },
700        { "string", { &createFunString, Interval(0, 1) } },
701        { "string-length", { &createFunStringLength, Interval(0, 1) } },
702        { "substring", { &createFunSubstring, Interval(2, 3) } },
703        { "substring-after", { &createFunSubstringAfter, 2 } },
704        { "substring-before", { &createFunSubstringBefore, 2 } },
705        { "sum", { &createFunSum, 1 } },
706        { "translate", { &createFunTranslate, 3 } },
707        { "true", { &createFunTrue, 0 } },
708    };
709
710    functionMap = new HashMap<String, FunctionRec>;
711    for (size_t i = 0; i < WTF_ARRAY_LENGTH(functions); ++i)
712        functionMap->set(functions[i].name, functions[i].function);
713}
714
715Function* createFunction(const String& name, const Vector<Expression*>& args)
716{
717    if (!functionMap)
718        createFunctionMap();
719
720    HashMap<String, FunctionRec>::iterator functionMapIter = functionMap->find(name);
721    FunctionRec* functionRec = 0;
722
723    if (functionMapIter == functionMap->end() || !(functionRec = &functionMapIter->value)->args.contains(args.size()))
724        return 0;
725
726    Function* function = functionRec->factoryFn();
727    function->setArguments(args);
728    function->setName(name);
729    return function;
730}
731
732}
733}
734