/*
 *  Copyright 2001-2005 Internet2
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* SAMLAssertion.cpp - SAML assertion implementation

   Scott Cantor
   5/27/02

   $History:$
*/

#include "internal.h"

#include <ctime>

#include <xsec/enc/XSECCryptoException.hpp>
#include <xsec/framework/XSECException.hpp>

using namespace std;
using namespace saml;

// Builds an assertion from its individual parts instead of from XML.
//
// Strings (issuer, id, advice references) are duplicated with XML::assign and
// the date/time values are copied and parsed. Conditions, statements and
// advice assertions are not copied: each one is re-parented to this object and
// stored as given. Arbitrary advice elements are imported (deep) into a
// scratch document that is created the first time one is needed.
SAMLAssertion::SAMLAssertion(
    const XMLCh* issuer,
    const SAMLDateTime* notBefore,
    const SAMLDateTime* notOnOrAfter,
    const Iterator<SAMLCondition*>& conditions,
    const Iterator<SAMLStatement*>& statements,
    const Iterator<const XMLCh*>& adviceRefs,
    const Iterator<SAMLAssertion*>& adviceAssertions,
    const Iterator<DOMElement*>& adviceElements,
    const XMLCh* id,
    const SAMLDateTime* issueInstant
    ) : m_issuer(NULL), m_issueInstant(NULL), m_notBefore(NULL), m_notOnOrAfter(NULL), m_scratch(NULL)
{
    RTTI(SAMLAssertion);

    // Minor version 0 is used in compatibility mode, 1 otherwise.
    if (SAMLConfig::getConfig().compatibility_mode)
        m_minor=0;
    else
        m_minor=1;

    m_id=XML::assign(id);
    m_issuer=XML::assign(issuer);

    // Each optional timestamp is copied and then parsed.
    if (issueInstant) {
        m_issueInstant=new SAMLDateTime(*issueInstant);
        m_issueInstant->parseDateTime();
    }
    if (notBefore) {
        m_notBefore=new SAMLDateTime(*notBefore);
        m_notBefore->parseDateTime();
    }
    if (notOnOrAfter) {
        m_notOnOrAfter=new SAMLDateTime(*notOnOrAfter);
        m_notOnOrAfter->parseDateTime();
    }

    // Child objects are attached to this assertion as they are stored.
    while (conditions.hasNext()) {
        SAMLCondition* cond=conditions.next();
        m_conditions.push_back(static_cast<SAMLCondition*>(cond->setParent(this)));
    }

    while (statements.hasNext()) {
        SAMLStatement* stmt=statements.next();
        m_statements.push_back(static_cast<SAMLStatement*>(stmt->setParent(this)));
    }

    while (adviceRefs.hasNext()) {
        const XMLCh* ref=adviceRefs.next();
        m_adviceRefs.push_back(XML::assign(ref));
    }

    while (adviceAssertions.hasNext()) {
        SAMLAssertion* nested=adviceAssertions.next();
        m_adviceAssertions.push_back(static_cast<SAMLAssertion*>(nested->setParent(this)));
    }

    // Foreign advice elements are deep-copied into the scratch document.
    while (adviceElements.hasNext()) {
        if (!m_scratch)
            m_scratch=DOMImplementationRegistry::getDOMImplementation(NULL)->createDocument();
        DOMNode* imported=m_scratch->importNode(adviceElements.next(),true);
        m_adviceElements.push_back(static_cast<DOMElement*>(imported));
    }
}

// Builds an assertion by unmarshalling an existing DOM element.
// All pointer members start out NULL and are filled in by fromDOM().
SAMLAssertion::SAMLAssertion(DOMElement* e)
    : m_issuer(NULL),
      m_issueInstant(NULL),
      m_notBefore(NULL),
      m_notOnOrAfter(NULL),
      m_scratch(NULL)
{
    RTTI(SAMLAssertion);

    fromDOM(e);
}

// Builds an assertion from a stream. The stream is handed to the
// SAMLSignedObject base constructor (presumably it populates m_document --
// confirm against the base class), after which the document element is
// unmarshalled with fromDOM().
SAMLAssertion::SAMLAssertion(istream& in)
    : SAMLSignedObject(in),
      m_issuer(NULL),
      m_issueInstant(NULL),
      m_notBefore(NULL),
      m_notOnOrAfter(NULL),
      m_scratch(NULL)
{
    RTTI(SAMLAssertion);

    DOMElement* docRoot=m_document->getDocumentElement();
    fromDOM(docRoot);
}

// Builds an assertion from a stream, forwarding the requested minor version
// to the SAMLSignedObject base constructor, and then unmarshals the document
// element of m_document with fromDOM().
SAMLAssertion::SAMLAssertion(istream& in, int minor)
    : SAMLSignedObject(in,minor),
      m_issuer(NULL),
      m_issueInstant(NULL),
      m_notBefore(NULL),
      m_notOnOrAfter(NULL),
      m_scratch(NULL)
{
    RTTI(SAMLAssertion);

    DOMElement* docRoot=m_document->getDocumentElement();
    fromDOM(docRoot);
}

// Tears down the assertion: releases the scratch document (if any), frees the
// issuer and advice reference strings when this object holds its own copies,
// and deletes the timestamps and every child condition, statement and advice
// assertion.
SAMLAssertion::~SAMLAssertion()
{
    if (m_scratch)
        m_scratch->release();

    // Strings are only released when m_bOwnStrings says they are ours.
    if (m_bOwnStrings) {
        XMLString::release(&m_issuer);
        vector<const XMLCh*>::const_iterator ref=m_adviceRefs.begin();
        while (ref!=m_adviceRefs.end()) {
            XMLCh* owned=const_cast<XMLCh*>(*ref);
            XMLString::release(&owned);
            ++ref;
        }
    }

    delete m_issueInstant;
    delete m_notBefore;
    delete m_notOnOrAfter;

    vector<SAMLCondition*>::const_iterator cond=m_conditions.begin();
    while (cond!=m_conditions.end()) {
        delete *cond;
        ++cond;
    }

    vector<SAMLStatement*>::const_iterator stmt=m_statements.begin();
    while (stmt!=m_statements.end()) {
        delete *stmt;
        ++stmt;
    }

    vector<SAMLAssertion*>::const_iterator nested=m_adviceAssertions.begin();
    while (nested!=m_adviceAssertions.end()) {
        delete *nested;
        ++nested;
    }
}

void SAMLAssertion::ownStrings()
{
    if (!m_bOwnStrings) {
        SAMLSignedObject::ownStrings();
        m_issuer=XML::assign(m_issuer);
        for (vector<const XMLCh*>::iterator i=m_adviceRefs.begin(); i!=m_adviceRefs.end(); i++)
            (*i)=XML::assign(*i);
        m_bOwnStrings = true;
    }
}

// Populates this assertion from the DOM element e.
//
// After the base-class SAMLObject::fromDOM(e) runs, the root attributes
// (MajorVersion, MinorVersion, AssertionID, Issuer, IssueInstant) are read and
// each child element is dispatched by name: saml:Conditions, saml:Advice,
// ds:Signature, and anything else is treated as a statement.
//
// Throws MalformedException if strict DOM checking is enabled and e is not
// saml:Assertion, if MajorVersion is not 1, or if the signature cannot be
// loaded; throws UnsupportedExtensionException if no implementation is found
// for a condition or statement element. checkValidity() is run at the end.
void SAMLAssertion::fromDOM(DOMElement* e)
{
    SAMLObject::fromDOM(e);

    if (SAMLConfig::getConfig().strict_dom_checking && !XML::isElementNamed(e,XML::SAML_NS,L(Assertion)))
        throw MalformedException(SAMLException::RESPONDER,"SAMLAssertion::fromDOM() missing saml:Assertion at root");
    
    // Only major version 1 is accepted.
    if (XMLString::parseInt(e->getAttributeNS(NULL,L(MajorVersion)))!=1)
        throw MalformedException(SAMLException::VERSIONMISMATCH,"SAMLAssertion::fromDOM() detected incompatible assertion major version");
    
    // The ID and issuer point at the attribute values returned by the DOM;
    // no copy is made here.
    m_minor=XMLString::parseInt(e->getAttributeNS(NULL,L(MinorVersion)));
    m_id=const_cast<XMLCh*>(e->getAttributeNS(NULL,L(AssertionID)));
    m_issuer=const_cast<XMLCh*>(e->getAttributeNS(NULL,L(Issuer)));
    m_issueInstant=new SAMLDateTime(e->getAttributeNS(NULL,L(IssueInstant)));
    m_issueInstant->parseDateTime();

    DOMElement* n=XML::getFirstChildElement(e);
    while (n) {
        // The top level children may be one of three different types.
        if (XML::isElementNamed(n,XML::SAML_NS,L(Conditions))) {

            // NOTE(review): presence is tested with hasAttributeNS(NULL,...) but the
            // value is read with the non-namespace getAttribute(); presumably
            // equivalent for unprefixed attributes -- confirm.
            if (n->hasAttributeNS(NULL,L(NotBefore))) {
                m_notBefore=new SAMLDateTime(n->getAttribute(L(NotBefore)));
                m_notBefore->parseDateTime();
            }

            if (n->hasAttributeNS(NULL,L(NotOnOrAfter))) {
                m_notOnOrAfter=new SAMLDateTime(n->getAttribute(L(NotOnOrAfter)));
                m_notOnOrAfter->parseDateTime();
            }
            
            // Iterate over conditions.
            DOMElement* cond=XML::getFirstChildElement(n);
            while (cond) {
                SAMLCondition* pcond=SAMLCondition::getInstance(cond);
                if (!pcond)
                    throw UnsupportedExtensionException("SAMLAssertion::fromDOM() unable to locate implementation for condition type");
                pcond->setParent(this);
                m_conditions.push_back(pcond);
                cond=XML::getNextSiblingElement(cond);
            }
        }
        else if (XML::isElementNamed(n,XML::SAML_NS,L(Advice))) {
            // Advice children are sorted into references, nested assertions,
            // and everything else (kept as raw DOM elements).
            DOMElement* child=XML::getFirstChildElement(n);
            while (child) {
                if (XML::isElementNamed(child, XML::SAML_NS, L(AssertionIDReference)) && child->hasChildNodes()) {
                    // Stores the node value of the first child as-is (no copy).
                    m_adviceRefs.push_back(child->getFirstChild()->getNodeValue());
                }
                else if (XML::isElementNamed(child, XML::SAML_NS, L(Assertion))) {
                    SAMLAssertion* a=new SAMLAssertion(child);
                    a->setParent(this);
                    m_adviceAssertions.push_back(a);
                }
                else {
                    m_adviceElements.push_back(child);
                }
                child=XML::getNextSiblingElement(child);
            }
        }
        else if (XML::isElementNamed(n,XML::XMLSIG_NS,L(Signature))) {
            // Hand the ds:Signature element to the XML-Security library and
            // remember which element it came from.
            SAMLInternalConfig& conf=dynamic_cast<SAMLInternalConfig&>(SAMLConfig::getConfig());
            try {
                m_signature=conf.m_xsec->newSignatureFromDOM(n->getOwnerDocument(),n);
                m_signature->load();
                m_sigElement=n;
            }
            // Both XML-Security exception types are logged and rethrown as
            // MalformedException. Note that the caught exception is named e
            // and hides the function parameter e inside the handlers.
            catch(XSECException& e) {
                auto_ptr_char temp(e.getMsg());
                SAML_log.error("caught an XMLSec exception: %s",temp.get());
                throw MalformedException("caught an XMLSec exception while parsing signature: $1",params(1,temp.get()));
            }
            catch(XSECCryptoException& e) {
                SAML_log.error("caught an XMLSec crypto exception: %s",e.getMsg());
                throw MalformedException("caught an XMLSec crypto exception while parsing signature: $1",params(1,e.getMsg()));
            }
        }
        else {
            // Any other child is handed to the statement factory.
            SAMLStatement* pstate=SAMLStatement::getInstance(n);
            if (!pstate)
                throw UnsupportedExtensionException("SAMLAssertion::fromDOM() unable to locate implementation for statement type");
            pstate->setParent(this);
            m_statements.push_back(pstate);
        }
        n=XML::getNextSiblingElement(n);
    }
    checkValidity();
}

void SAMLAssertion::insertSignature()
{
    m_root->appendChild(getSignatureElement());
}

// Records a new minor version, then takes ownership of the object's strings
// and marks the object dirty.
void SAMLAssertion::setMinorVersion(int minor)
{
    m_minor = minor;

    ownStrings();
    setDirty();
}

void SAMLAssertion::setIssuer(const XMLCh* issuer)
{
    if (XML::isEmpty(issuer))
        throw SAMLException("issuer cannot be null or empty");
        
    if (m_bOwnStrings)
        XMLString::release(&m_issuer);
    else {
        m_issuer=NULL;
        ownStrings();
    }
    m_issuer=XML::assign(issuer);
    setDirty();
}

void SAMLAssertion::setIssueInstant(const SAMLDateTime* instant)
{
    delete m_issueInstant;
    m_issueInstant=NULL;
    if (instant) {
        m_issueInstant=new SAMLDateTime(*instant);
        m_issueInstant->parseDateTime();
    }
    ownStrings();
    setDirty();
}

void SAMLAssertion::setNotBefore(const SAMLDateTime* notBefore)
{
    delete m_notBefore;
    m_notBefore=NULL;
    if (notBefore) {
        m_notBefore=new SAMLDateTime(*notBefore);
        m_notBefore->parseDateTime();
    }
    ownStrings();
    setDirty();
}

void SAMLAssertion::setNotOnOrAfter(const SAMLDateTime* notOnOrAfter)
{
    delete m_notOnOrAfter;
    m_notOnOrAfter=NULL;
    if (notOnOrAfter) {
        m_notOnOrAfter=new SAMLDateTime(*notOnOrAfter);
        m_notOnOrAfter->parseDateTime();
    }
    ownStrings();
    setDirty();
}

// Replaces the whole condition list: every existing condition is removed
// (and destroyed by removeCondition), then each supplied one is added.
void SAMLAssertion::setConditions(const Iterator<SAMLCondition*>& conditions)
{
    // Always remove the front entry until nothing is left.
    for (; !m_conditions.empty(); )
        removeCondition(0);

    for (; conditions.hasNext(); )
        addCondition(conditions.next());
}

// Appends a condition, making this assertion its parent.
// Throws SAMLException if condition is NULL.
void SAMLAssertion::addCondition(SAMLCondition* condition)
{
    if (!condition)
        throw SAMLException("condition cannot be null");

    condition->setParent(this);
    m_conditions.push_back(condition);

    ownStrings();
    setDirty();
}

void SAMLAssertion::removeCondition(unsigned long index)
{
    SAMLCondition* kill=m_conditions[index];
    m_conditions.erase(m_conditions.begin()+index);
    delete kill;
    ownStrings();
    setDirty();
}

// Replaces the whole statement list: every existing statement is removed
// (and destroyed by removeStatement), then each supplied one is added.
void SAMLAssertion::setStatements(const Iterator<SAMLStatement*>& statements)
{
    // Always remove the front entry until nothing is left.
    for (; !m_statements.empty(); )
        removeStatement(0);

    for (; statements.hasNext(); )
        addStatement(statements.next());
}

// Appends a statement, making this assertion its parent.
// Throws SAMLException if statement is NULL.
void SAMLAssertion::addStatement(SAMLStatement* statement)
{
    if (!statement)
        throw SAMLException("statement cannot be null");

    statement->setParent(this);
    m_statements.push_back(statement);

    ownStrings();
    setDirty();
}

void SAMLAssertion::removeStatement(unsigned long index)
{
    SAMLStatement* kill=m_statements[index];
    m_statements.erase(m_statements.begin()+index);
    delete kill;
    ownStrings();
    setDirty();
}

// Replaces all advice assertion-ID references with the supplied ones.
void SAMLAssertion::setAdvice(const Iterator<const XMLCh*>& advice)
{
    // Empty the current list from the front.
    for (; !m_adviceRefs.empty(); )
        removeAdviceRef(0);

    for (; advice.hasNext(); )
        addAdvice(advice.next());
}

// Replaces all embedded advice assertions with the supplied ones.
void SAMLAssertion::setAdvice(const Iterator<SAMLAssertion*>& advice)
{
    // Empty the current list from the front.
    for (; !m_adviceAssertions.empty(); )
        removeAdviceAssertion(0);

    for (; advice.hasNext(); )
        addAdvice(advice.next());
}

// Replaces all arbitrary advice elements with the supplied ones.
void SAMLAssertion::setAdvice(const Iterator<DOMElement*>& advice)
{
    // Empty the current list from the front.
    for (; !m_adviceElements.empty(); )
        removeAdviceElement(0);

    for (; advice.hasNext(); )
        addAdvice(advice.next());
}

void SAMLAssertion::addAdvice(const XMLCh* advice)
{
    if (XML::isEmpty(advice))
        throw SAMLException("Advice assertion reference cannot be null or empty");
    
    ownStrings();
    m_adviceRefs.push_back(XML::assign(advice));
    setDirty();
}

// Appends an embedded assertion to the advice, making this assertion its
// parent. Throws SAMLException if advice is NULL.
void SAMLAssertion::addAdvice(SAMLAssertion* advice)
{
    if (!advice)
        throw SAMLException("advice assertion cannot be null");

    ownStrings();

    SAMLAssertion* attached=static_cast<SAMLAssertion*>(advice->setParent(this));
    m_adviceAssertions.push_back(attached);
    setDirty();
}

// Appends an arbitrary element to the advice. The element is deep-imported
// into m_document when there is one, otherwise into the scratch document
// (created on first use); the imported copy is what gets stored.
//
// Throws SAMLException if advice is NULL, already has a parent node, or lies
// in the SAML namespace.
void SAMLAssertion::addAdvice(DOMElement* advice)
{
    // compareString() yields 0 on equality, so non-zero means "not SAML".
    const bool acceptable=
        advice &&
        !advice->getParentNode() &&
        XMLString::compareString(advice->getNamespaceURI(),XML::SAML_NS)!=0;
    if (!acceptable)
        throw SAMLException("advice element must have no parent and must not be in the SAML namespace");

    ownStrings();

    DOMDocument* target=m_document;
    if (!target) {
        if (!m_scratch)
            m_scratch=DOMImplementationRegistry::getDOMImplementation(NULL)->createDocument();
        target=m_scratch;
    }
    m_adviceElements.push_back(static_cast<DOMElement*>(target->importNode(advice,true)));

    setDirty();
}

void SAMLAssertion::removeAdviceRef(unsigned long index)
{
    if (m_bOwnStrings) {
        XMLCh* ch=const_cast<XMLCh*>(m_adviceRefs[index]);
        XMLString::release(&ch);
    }
    m_adviceRefs.erase(m_adviceRefs.begin()+index);
    ownStrings();
    setDirty();
}

void SAMLAssertion::removeAdviceAssertion(unsigned long index)
{
    delete m_adviceAssertions[index];
    m_adviceAssertions.erase(m_adviceAssertions.begin()+index);
    ownStrings();
    setDirty();
}

void SAMLAssertion::removeAdviceElement(unsigned long index)
{
    DOMElement* advice=m_adviceElements[index];
    m_adviceElements.erase(m_adviceElements.begin()+index);
    if (advice) {
        if (advice->getParentNode())
            advice->getParentNode()->removeChild(advice);
        advice->release();
    }
    ownStrings();
    setDirty();
}

// Creates the saml:Assertion root element in doc. The default, saml and samlp
// namespace declarations are always added; the xsi and xsd declarations are
// added only when xmlns is true.
DOMElement* SAMLAssertion::buildRoot(DOMDocument* doc, bool xmlns) const
{
    DOMElement* root=doc->createElementNS(XML::SAML_NS,L(Assertion));

    // Unconditional declarations.
    root->setAttributeNS(XML::XMLNS_NS,L(xmlns),XML::SAML_NS);
    root->setAttributeNS(XML::XMLNS_NS,L_QNAME(xmlns,saml),XML::SAML_NS);
    root->setAttributeNS(XML::XMLNS_NS,L_QNAME(xmlns,samlp),XML::SAMLP_NS);

    if (!xmlns)
        return root;

    // Schema-related declarations, on request only.
    root->setAttributeNS(XML::XMLNS_NS,L_QNAME(xmlns,xsi),XML::XSI_NS);
    root->setAttributeNS(XML::XMLNS_NS,L_QNAME(xmlns,xsd),XML::XSD_NS);
    return root;
}

// Produces the DOM form of this assertion and returns m_root.
//
// SAMLObject::toDOM() is called first; m_root is then treated as the assertion
// element. When the object is dirty the element is filled in from the current
// member values (attributes, Conditions, Advice, statements) and the object is
// marked clean. When it is not dirty and xmlns is true, only the namespace
// declarations are (re)applied to the existing root.
//
// Although const, this method fills in m_id and m_issueInstant when they are
// missing and may replace entries of m_adviceElements.
DOMNode* SAMLAssertion::toDOM(DOMDocument* doc, bool xmlns) const
{
    SAMLObject::toDOM(doc,xmlns);
    DOMElement* a=static_cast<DOMElement*>(m_root);
    
    if (m_bDirty) {
        // From here on, build nodes in the document that owns the root.
        doc=a->getOwnerDocument();
        static const XMLCh One[]={chDigit_1, chNull};
        static const XMLCh Zero[]={chDigit_0, chNull};
        a->setAttributeNS(NULL,L(MajorVersion),One);
        // Any non-zero minor version is written out as "1".
        a->setAttributeNS(NULL,L(MinorVersion),m_minor==0 ? Zero : One);
    
        // Only generate a new ID if we don't have one already.
        if (!m_id) {
            SAMLIdentifier id;
            m_id=XMLString::replicate(id);
        }
        a->setAttributeNS(NULL,L(AssertionID),m_id);
        // For minor version 1 the AssertionID attribute is flagged as an ID attribute.
        if (m_minor==1)
            a->setIdAttributeNS(NULL,L(AssertionID));
        a->setAttributeNS(NULL,L(Issuer),m_issuer);
    
        // Default the issue instant to the current time when none was set.
        if (!m_issueInstant) {
            m_issueInstant=new SAMLDateTime(time(NULL));
            m_issueInstant->parseDateTime();
        }
        a->setAttributeNS(NULL,L(IssueInstant),m_issueInstant->getRawData());
    
        // A Conditions element is emitted only if there is something to put in it.
        if (m_notBefore || m_notOnOrAfter || !m_conditions.empty()) {
            DOMElement* c=doc->createElementNS(XML::SAML_NS,L(Conditions));
            if (m_notBefore)
                c->setAttributeNS(NULL,L(NotBefore),m_notBefore->getRawData());
            if (m_notOnOrAfter)
                c->setAttributeNS(NULL,L(NotOnOrAfter),m_notOnOrAfter->getRawData());
            for (vector<SAMLCondition*>::const_iterator i=m_conditions.begin(); i!=m_conditions.end(); i++)
                c->appendChild((*i)->toDOM(doc,false));
            a->appendChild(c);
        }

        // The Advice element is created lazily by whichever of the three
        // advice kinds is non-empty first: references, assertions, elements.
        DOMElement* advice=NULL;
        if (!m_adviceRefs.empty()) {
            if (!advice)
                advice=doc->createElementNS(XML::SAML_NS,L(Advice));
            Iterator<const XMLCh*> refs(m_adviceRefs);
            while (refs.hasNext()) {
                DOMElement* ref=doc->createElementNS(XML::SAML_NS, L(AssertionIDReference));
                ref->appendChild(doc->createTextNode(refs.next()));
                advice->appendChild(ref);
            }
        }
        if (!m_adviceAssertions.empty()) {
            if (!advice)
                advice=doc->createElementNS(XML::SAML_NS,L(Advice));
            Iterator<SAMLAssertion*> asns(m_adviceAssertions);
            while (asns.hasNext())
                advice->appendChild(asns.next()->toDOM(doc,false));
        }
        if (!m_adviceElements.empty()) {
            if (!advice)
                advice=doc->createElementNS(XML::SAML_NS,L(Advice));
            for (vector<DOMElement*>::iterator els=m_adviceElements.begin(); els!=m_adviceElements.end(); els++) {
                // An element owned by another document is deep-imported into
                // doc; the original is detached and released, and the stored
                // pointer is replaced by the imported copy.
                if ((*els)->getOwnerDocument() != doc) {
                    DOMElement* copy=static_cast<DOMElement*>(doc->importNode((*els),true));
                    if ((*els)->getParentNode())
                        (*els)->getParentNode()->removeChild(*els);
                    (*els)->release();
                    (*els)=copy;
                }
                advice->appendChild(*els);
            }
        }
        if (advice)
            a->appendChild(advice);
    
        // Statements follow Conditions and Advice.
        for (vector<SAMLStatement*>::const_iterator j=m_statements.begin(); j!=m_statements.end(); j++)
            a->appendChild((*j)->toDOM(doc,false));
        
        setClean();
    }
    else if (xmlns) {
        // Nothing to rebuild; just apply the namespace declarations to the root.
        DECLARE_DEF_NAMESPACE(a,XML::SAML_NS);
        DECLARE_NAMESPACE(a,saml,XML::SAML_NS);
        DECLARE_NAMESPACE(a,samlp,XML::SAMLP_NS);
        DECLARE_NAMESPACE(a,xsi,XML::XSI_NS);
        DECLARE_NAMESPACE(a,xsd,XML::XSD_NS);
    }

    return m_root;
}

// Verifies the minimum content of an assertion: an issuer must be set and at
// least one statement must be present. Throws MalformedException otherwise.
void SAMLAssertion::checkValidity() const
{
    const bool complete=m_issuer && !m_statements.empty();
    if (complete)
        return;

    throw MalformedException("Assertion is invalid, must have Issuer and at least one Statement");
}

// Returns a newly allocated copy of this assertion. The copy is built through
// the parts constructor from this object's issuer, validity bounds, ID and
// issue instant, the clone() results of the condition, statement and advice
// assertion iterators, and the advice references and elements. The minor
// version is applied to the copy afterwards.
SAMLObject* SAMLAssertion::clone() const
{
    SAMLAssertion* duplicate = new SAMLAssertion(
        m_issuer,
        m_notBefore,
        m_notOnOrAfter,
        getConditions().clone(),
        getStatements().clone(),
        m_adviceRefs,
        getAdviceAssertions().clone(),
        m_adviceElements,
        m_id,
        m_issueInstant
        );

    // The constructor picks the minor version from configuration, so it is
    // overridden here with this object's value.
    duplicate->setMinorVersion(m_minor);
    return duplicate;
}
