// Message.cpp : Implementation of cMessage #include "stdafx.h" #include "DecalNet.h" #include "NetService.h" #include "Message.h" #include "MessageLoaders.h" #include "MessageParsers.h" #include "MessageRoot.h" cMessage::cFieldList::iterator cMessage::cLoadContext::lookupField( cMessage::cMessageElement *pElement ) { for( cLoadContext *pContext = this; pContext != NULL; pContext = pContext->m_pParent ) { for( cFieldList::iterator i = pContext->m_pMessage->m_fields.begin() + pContext->m_dwOffset; i != pContext->m_pMessage->m_fields.end(); i += i->m_nOwns ) { if( i->m_pSchema == pElement ) return i; } } return m_pMessage->m_fields.end(); } void cMessage::cMessageSchema::loadSchema( MSXML::IXMLDOMDocumentPtr &pDoc, DWORD dwSchema ) { TCHAR szQuery[ 255 ]; ::_stprintf( szQuery, _T( "/schema/messages/message[@type='%04X']" ), dwSchema ); MSXML::IXMLDOMElementPtr pMessage = pDoc->selectSingleNode( szQuery ); if( pMessage.GetInterfacePtr() == NULL ) // Nothing here, so we create a valid but empty message return; cElementParser::cContext c( &m_members ); c.parseChildren( pMessage ); } void cMessage::init() { // Load the schema CComBSTR strTemplate( _T( "%decal%\\messages.xml" ) ), strPath; m_pService->m_pDecal->MapPath( strTemplate, &strPath ); g_pXML.CreateInstance( __uuidof( MSXML::DOMDocument ) ); BOOL bSuccess = g_pXML->load( strPath.m_str ); if( ! bSuccess ) { USES_CONVERSION; std::string szXML; DecryptXML( OLE2A( strPath.m_str ), szXML ); if( szXML != "" ) bSuccess = g_pXML->loadXML( _bstr_t( szXML.c_str() ) ); } // Initialize our schema helper objects cFieldLoader::init(); cElementParser::init(); } void cMessage::DecryptXML( const char *szPath, std::string &szXML ) { if( szPath == NULL ) { szXML = ""; return; } FILE *f = fopen( szPath, "rb" ); if( f == NULL ) { szXML = ""; return; } szXML.clear(); unsigned char szBuffer[1025]; try { CCryptProv crypt; if( crypt.Initialize( PROV_RSA_FULL, "Decal_Memlocs", MS_DEF_PROV ) == NTE_BAD_KEYSET ) crypt.Initialize( PROV_RSA_FULL, "Decal_Memlocs", MS_DEF_PROV, CRYPT_NEWKEYSET ); CCryptMD5Hash hash; hash.Initialize( crypt ); hash.AddString( DECAL_KEY ); CCryptDerivedKey key; key.Initialize( crypt, hash ); DWORD dwDecLen = 0; while( ! feof(f) ) { memset( szBuffer, 0, sizeof( szBuffer ) ); dwDecLen = fread( szBuffer, 1, 1024, f ); key.Decrypt( feof(f), (BYTE *) szBuffer, &dwDecLen ); szXML += (char *)szBuffer; } key.Destroy(); hash.Destroy(); crypt.Release(); } catch( ... ) { // crap... szXML = ""; } fclose( f ); } void cMessage::term() { g_schema.clear(); cElementParser::term(); cFieldLoader::term(); if( g_pXML.GetInterfacePtr() != NULL ) g_pXML.Release(); } ///////////////////////////////////////////////////////////////////////////// // cMessage cMessage::cMessage() : m_nType( 0 ), m_pStartCrack( NULL ), m_pEndCrack( NULL ), m_pEndData( NULL ), m_pSchema( NULL ), m_pRoot( NULL ) { } cMessage::cFieldList::iterator cMessage::getFieldFromElement( cMessageElement *pElement ) { for( cFieldList::iterator i = m_fields.begin(); i != m_fields.end(); i += i->m_nOwns ) { cMessageElement *pSchema = i->m_pSchema; if( pSchema == pElement ) break; } return i; } void cMessage::crackMessage( BYTE *pBody, DWORD dwSize ) { m_pStartCrack = pBody; m_pEndData = pBody + dwSize + sizeof( DWORD ); m_nType = *reinterpret_cast< long * >( pBody ); m_fields.clear(); m_pEndCrack = m_pStartCrack + sizeof( long ); m_pSchema = NULL; } bool cMessage::loadNextElement() { if( m_iLoaded == m_pSchema->m_members.end() ) return false; // Attempt to load what we've got cLoadContext context( this ); if( !m_iLoaded->get()->load( context ) ) { m_iLoaded = m_pSchema->m_members.end(); return false; } ++ m_iLoaded; return true; } void cMessage::loadAllElements() { if( m_pSchema == NULL ) { // First look up the message to see if it's already decoded cMessageSchemaMap::iterator i_schema = g_schema.find( m_nType ); if( i_schema == g_schema.end() ) { // Make a new one m_pSchema = new cMessageSchema; m_pSchema->loadSchema( g_pXML, m_nType ); g_schema.insert( cMessageSchemaMap::value_type( m_nType, VSBridge::auto_ptr< cMessageSchema >( m_pSchema ) ) ); } else m_pSchema = i_schema->second.get(); // At this point we have "a" schema of some quality // set up the cursors m_iLoaded = m_pSchema->m_members.begin(); } if( m_iLoaded == m_pSchema->m_members.end() ) return; cLoadContext context( this ); while ( m_iLoaded != m_pSchema->m_members.end () ) { if ( m_iLoaded->get()->load ( context ) ) ++ m_iLoaded; else m_iLoaded = m_pSchema->m_members.end(); } } STDMETHODIMP cMessage::get_Type(long *pVal) { _ASSERTE( pVal != NULL ); *pVal = *reinterpret_cast< long * >( m_pStartCrack ); return S_OK; } STDMETHODIMP cMessage::get_Data(VARIANT *pVal) { long nSize = m_pEndData - m_pStartCrack - sizeof( DWORD ); if( nSize == 0 ) { // Special case, this message is entirely cracked - return NULL pVal->vt = VT_NULL; pVal->intVal = 0; return S_OK; } // We've got some data to share SAFEARRAYBOUND sab = { nSize, 0 }; SAFEARRAY *pArray = ::SafeArrayCreate( VT_UI1, 1, &sab ); ::SafeArrayAllocData( pArray ); LPVOID pvData; ::SafeArrayAccessData( pArray, &pvData ); ::memcpy( pvData, m_pStartCrack, nSize ); ::SafeArrayUnaccessData( pArray ); pVal->vt = VT_ARRAY | VT_UI1; pVal->parray = pArray; return S_OK; } STDMETHODIMP cMessage::get_Begin(IMessageIterator **pVal) { if( m_pRoot == NULL ) { // Create a new message root object CComObject< cMessageRoot > *pRoot; CComObject< cMessageRoot >::CreateInstance( &pRoot ); // Do an extra addref so this object sticks around after the client is done with it pRoot->AddRef(); pRoot->init( this ); m_pRoot = pRoot; } m_pRoot->Reset(); if( m_pSchema == NULL ) { // First look up the message to see if it's already decoded cMessageSchemaMap::iterator i_schema = g_schema.find( m_nType ); if( i_schema == g_schema.end() ) { // Make a new one m_pSchema = new cMessageSchema; m_pSchema->loadSchema( g_pXML, m_nType ); g_schema.insert( cMessageSchemaMap::value_type( m_nType, VSBridge::auto_ptr< cMessageSchema >( m_pSchema ) ) ); } else m_pSchema = i_schema->second.get(); // At this point we have "a" schema of some quality // set up the cursors m_iLoaded = m_pSchema->m_members.begin(); } return m_pRoot->QueryInterface( IID_IMessageIterator, reinterpret_cast< void ** >( pVal ) ); } STDMETHODIMP cMessage::get_Member(VARIANT vName, VARIANT *pVal) { loadAllElements (); ::VariantInit (pVal); if( vName.vt == VT_BSTR ) { _bstr_t bstrName = vName; // Iterate over the fields and return our match - in this loop we'll do incremental // cracking, so it's a little messy. When we hit the end, we look to see if there are // more uncracked fields int nFieldCount = m_fields.size(); for( int nField = 0; nField != nFieldCount; nField += m_fields[ nField ].m_nOwns ) { if( m_fields[ nField ].m_pSchema->m_strName == bstrName ) { m_fields[ nField ].m_pSchema->getValue( this, m_fields.begin() + nField, pVal ); return S_OK; } } pVal->vt = VT_EMPTY; return S_OK; } // Attempt to convert it into an index HRESULT hRes = ::VariantChangeType( &vName, &vName, 0, VT_I4 ); if( FAILED( hRes ) ) { _ASSERTE( FALSE ); return hRes; } // Check if the value is in range long nIndex = vName.lVal; if( nIndex < 0 ) { _ASSERTE( nIndex >= 0 ); return E_INVALIDARG; } // Now, one problem is we aren't exactly sure how big the array is, so we have to walk // through the fields, skipping appropriately for( cFieldList::iterator i = m_fields.begin(); i != m_fields.end(); i += i->m_nOwns, -- nIndex ) { if( nIndex == 0 ) { // We've found the index - extract the value i->m_pSchema->getValue( this, i, pVal ); return S_OK; } } // The index was too high _ASSERTE( FALSE ); return E_INVALIDARG; } STDMETHODIMP cMessage::get_MemberName(long nIndex, BSTR *pVal) { loadAllElements (); _ASSERTE( nIndex >= 0 ); _ASSERTE( pVal != NULL ); USES_CONVERSION; for( cFieldList::iterator i = m_fields.begin(); i != m_fields.end(); i += i->m_nOwns, -- nIndex ) { if( nIndex == 0 ) { // We've found the index - extract the value *pVal = OLE2BSTR( i->m_pSchema->m_strName ); return S_OK; } } // The index was too high return E_INVALIDARG; } STDMETHODIMP cMessage::get_Count(long *pVal) { loadAllElements (); _ASSERTE( pVal != NULL ); *pVal = 0; for( cFieldList::iterator i = m_fields.begin(); i != m_fields.end(); i += i->m_nOwns, ++ ( *pVal ) ); return S_OK; }