/*
    Windows NT Security functions library.
    Copyright (C) 1995  Jeremy R. Allison

    This library is free software; you can redistribute it and/or
    modify it under the terms of the GNU Library General Public
    License as published by the Free Software Foundation; either
    version 2 of the License, or any later version.

    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Library General Public License for more details.

    You should have received a copy of the GNU Library General Public
    License along with this library; if not, write to the Free
    Software Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.

	$Log: acl.cpp,v $
// Revision 1.1  1995/06/30  18:32:55  jra
// Initial revision
//
*/

#include "seclib.h"
#include "acl.h"
#include "autoheap.h"
#include"list.h"

//
// Functions that implement the aceEntry class.
//

//
// Construct an ACE from scratch.
//
// Builds a new ACCESS_ALLOWED_ACE (type == ACCESS_ALLOWED_ACE_TYPE) or
// ACCESS_DENIED_ACE (type == ACCESS_DENIED_ACE_TYPE) that holds a copy of
// sid together with the given access mask and ACE flags, and stores it in
// u_ace_. The ACE is allocated as a raw char array, so it must be released
// with delete [] (char *).
//
// Returns TRUE on success. On failure returns FALSE with the thread's last
// error set and both u_ace_ pointers left null:
//   ERROR_INVALID_PARAMETER - type is not one of the two supported kinds
//                             (reporting success here would leave a null
//                             ACE for getSid()/getMask() to dereference);
//   ERROR_OUTOFMEMORY       - the ACE buffer could not be allocated;
//   (CopySid's error)       - the SID could not be copied into the ACE.
//
BOOL aceEntry::makeAce( const SIDobj& sid, BYTE type, ACCESS_MASK mask, BYTE flag) {
	u_ace_.palAce_ = 0;
	u_ace_.padAce_ = 0;

	if(type != ACCESS_ALLOWED_ACE_TYPE && type != ACCESS_DENIED_ACE_TYPE) {
		SetLastError(ERROR_INVALID_PARAMETER);
		return FALSE;
	}

	// The structure's trailing SidStart DWORD only marks where the SID
	// begins, so it is replaced by the SID's real length.
	DWORD aceSize = ((type == ACCESS_ALLOWED_ACE_TYPE) ? sizeof(ACCESS_ALLOWED_ACE)
													   : sizeof(ACCESS_DENIED_ACE))
					+ sid.sidSize() - sizeof(DWORD);

	// Pre-standard compilers return 0 from a failed new; standard ones throw
	// std::bad_alloc. Handle both so an allocation failure is reported through
	// the BOOL / last-error protocol instead of escaping callers that only
	// catch DWORD error codes.
	char *buf = 0;
	try {
		buf = new char [ aceSize ];
	}
	catch (...) {
		buf = 0;
	}
	if(buf == 0) {
		SetLastError(ERROR_OUTOFMEMORY);
		return FALSE;
	}

	if(type == ACCESS_ALLOWED_ACE_TYPE) {
		ACCESS_ALLOWED_ACE *pace = (ACCESS_ALLOWED_ACE *)buf;
		if(!CopySid(sid.sidSize(), (SID *)&pace->SidStart, sid.getSid())) {
			delete [] buf;	// CopySid has already set the last error.
			return FALSE;
		}
		pace->Mask = mask;
		pace->Header.AceType = ACCESS_ALLOWED_ACE_TYPE;
		pace->Header.AceFlags = flag;
		pace->Header.AceSize = aceSize;
		u_ace_.palAce_ = pace;
	} else {
		ACCESS_DENIED_ACE *pace = (ACCESS_DENIED_ACE *)buf;
		if(!CopySid(sid.sidSize(), (SID *)&pace->SidStart, sid.getSid())) {
			delete [] buf;	// CopySid has already set the last error.
			return FALSE;
		}
		pace->Mask = mask;
		pace->Header.AceType = ACCESS_DENIED_ACE_TYPE;
		pace->Header.AceFlags = flag;
		pace->Header.AceSize = aceSize;
		u_ace_.padAce_ = pace;
	}
	return TRUE;
}

//
// Create an ACE entry from a name.
//
// Throws (DWORD)ERROR_INVALID_PARAMETER when type is neither
// ACCESS_ALLOWED_ACE_TYPE nor ACCESS_DENIED_ACE_TYPE. Throws the value of
// GetLastError() when makeAce() fails. NOTE(review): SIDobj's constructor
// may throw as well; its behavior is defined outside this file.
//
aceEntry::aceEntry(const TCHAR *user, const TCHAR *machine, ACCESS_MASK mask, BYTE type, BYTE flags)
		:type_(type) {
	// Only the two supported ACE kinds get past this point.
	switch (type_) {
	case ACCESS_ALLOWED_ACE_TYPE:
	case ACCESS_DENIED_ACE_TYPE:
		break;
	default:
		SetLastErrorEx(ERROR_INVALID_PARAMETER, SLE_ERROR);
		throw (DWORD)ERROR_INVALID_PARAMETER;
	}

	// Build the SID for user on machine, then wrap it in a freshly allocated ACE.
	SIDobj account(user, machine);
	if(makeAce(account, type, mask, flags))
		return;
	throw GetLastError();
}

//
// Create an ACE entry from a SIDobj.
//

// Throws (DWORD)ERROR_INVALID_PARAMETER when type is neither
// ACCESS_ALLOWED_ACE_TYPE nor ACCESS_DENIED_ACE_TYPE. Throws the value of
// GetLastError() when makeAce() cannot build the ACE.
//
aceEntry::aceEntry(const SIDobj &sid, ACCESS_MASK mask, BYTE type, BYTE flags)
		:type_(type) {
	// Reject any ACE kind this class does not know how to store.
	switch (type_) {
	case ACCESS_ALLOWED_ACE_TYPE:
	case ACCESS_DENIED_ACE_TYPE:
		break;
	default:
		SetLastErrorEx(ERROR_INVALID_PARAMETER, SLE_ERROR);
		throw (DWORD)ERROR_INVALID_PARAMETER;
	}

	// Copy the caller's SID into a newly allocated ACE.
	if(makeAce(sid, type, mask, flags))
		return;
	throw GetLastError();
}

//
// Destructor
//
aceEntry::~aceEntry() {
	// Both ACE kinds are allocated by makeAce() as raw char arrays, so the
	// buffer is released the same way whichever kind this entry holds.
	char *block = (type_ == ACCESS_ALLOWED_ACE_TYPE) ? (char *)u_ace_.palAce_
													 : (char *)u_ace_.padAce_;
	delete [] block;
}

//
// Get the SID from the ACE.
//
// The returned pointer refers into the ACE buffer that this aceEntry frees
// in its destructor, so it must not outlive the entry.
//
const SID *aceEntry::getSid() const {
	if(type_ != ACCESS_ALLOWED_ACE_TYPE)
		return (SID *)&u_ace_.padAce_->SidStart;
	return (SID *)&u_ace_.palAce_->SidStart;
}

//
// Get the ACCESS_MASK from the ACE.
//
// Reads the Mask field from whichever ACE structure this entry holds.
//
const ACCESS_MASK aceEntry::getMask() const {
	if(type_ != ACCESS_ALLOWED_ACE_TYPE)
		return u_ace_.padAce_->Mask;
	return u_ace_.palAce_->Mask;
}

//
// Get the flags from the ACE.
//
// Reads Header.AceFlags from whichever ACE structure this entry holds.
//
const BYTE aceEntry::getFlags() const {
	if(type_ != ACCESS_ALLOWED_ACE_TYPE)
		return u_ace_.padAce_->Header.AceFlags;
	return u_ace_.palAce_->Header.AceFlags;
}

//
// Copy function called by operator= and copy constructor
//
void aceEntry::deepcopy(const aceEntry& ent) {
	type_ = ent.type_;
	SIDobj sid(ent.getSid());
	if(!makeAce( sid, type_, ent.getMask(), ent.getFlags())) {
		throw GetLastError();
	}
}

//
// operator= function.
//
// Replaces this entry's ACE with an independent copy of ent's ACE.
//
// The currently owned ACE buffer is released only after deepcopy() has
// succeeded. deepcopy() throws a DWORD when makeAce() fails, and the SIDobj
// constructor it calls may also throw (NOTE(review): defined elsewhere). If
// either happens, the previous type and ACE are restored before the exception
// is rethrown. This entry therefore stays valid, and its destructor never
// frees an already-deleted buffer.
//
aceEntry& aceEntry::operator=(const aceEntry& ent) { 
	if( &ent == this)
		return *this;

	// Snapshot the ACE we currently own.
	BYTE oldType = type_;
	ACCESS_ALLOWED_ACE *oldAllowed = u_ace_.palAce_;
	ACCESS_DENIED_ACE *oldDenied = u_ace_.padAce_;

	try {
		deepcopy(ent);
	}
	catch (...) {
		// deepcopy() overwrites type_ before building the ACE, and makeAce()
		// nulls both pointers (freeing anything it allocated) when it fails.
		// Restoring both pointer members is correct whether u_ace_ is a union
		// or a struct.
		type_ = oldType;
		u_ace_.palAce_ = oldAllowed;
		u_ace_.padAce_ = oldDenied;
		throw;
	}

	// The copy is in place; now release the buffer we used to own.
	if(oldType == ACCESS_ALLOWED_ACE_TYPE)
		delete [] (char *)oldAllowed;
	else
		delete [] (char *)oldDenied;
	return *this;
}

//
// operator== function.
//
// Two entries are equal when they have the same ACE type, access mask,
// flags and SID (compared with EqualSid). Returns 1 if equal, 0 otherwise.
//
int aceEntry::operator==(const aceEntry& ent) {
	if(this == &ent)
		return 1;

	// Cheap scalar fields are compared first; EqualSid runs only if they match.
	return (type_ == ent.type_
			&& getMask() == ent.getMask()
			&& getFlags() == ent.getFlags()
			&& EqualSid((SID *)ent.getSid(), (SID *)getSid())) ? 1 : 0;
}

//
// Functions that implement the aclList class.
//

//
// Destructor
//
// Both ACE lists are emptied explicitly with RemoveAll(), allowed entries
// first. (What RemoveAll() does with each element is defined in list.h.)
//
aclList::~aclList() {
	this->allowedList_.RemoveAll();
	this->deniedList_.RemoveAll();
}

//
// Add an allowed ACE for a username.
//
// Returns TRUE on success. If building the entry throws a non-zero DWORD,
// that code becomes the last error and FALSE is returned.
//
BOOL aclList::AddAllowedAce(const TCHAR *name, const TCHAR *machine, ACCESS_MASK mask, BYTE flag) {
	try {
		// aceEntry's constructor throws a DWORD error code when it cannot
		// build the ACE; otherwise the entry is passed to Insert().
		aceEntry entry(name, machine, mask, ACCESS_ALLOWED_ACE_TYPE, flag);
		(void)allowedList_.Insert(entry);
	}
	catch (DWORD code) {
		// A zero code is not treated as a failure.
		if(code != 0) {
			SetLastError(code);
			return FALSE;
		}
	}
	return TRUE;
}

//
// Add an allowed ACE for a SIDobj&.
//
// Returns TRUE on success. If building the entry throws a non-zero DWORD,
// that code becomes the last error and FALSE is returned.
//
BOOL aclList::AddAllowedAce(const SIDobj& sid, ACCESS_MASK mask, BYTE flag) {
	try {
		// aceEntry's constructor throws a DWORD error code when it cannot
		// build the ACE; otherwise the entry is passed to Insert().
		aceEntry entry(sid, mask, ACCESS_ALLOWED_ACE_TYPE, flag);
		(void)allowedList_.Insert(entry);
	}
	catch (DWORD code) {
		// A zero code is not treated as a failure.
		if(code != 0) {
			SetLastError(code);
			return FALSE;
		}
	}
	return TRUE;
}

//
// Add a denied ACE for a username.
//
// Returns TRUE on success. If building the entry throws a non-zero DWORD,
// that code becomes the last error and FALSE is returned.
//
BOOL aclList::AddDeniedAce(const TCHAR *name, const TCHAR *machine, ACCESS_MASK mask, BYTE flag) {
	try {
		// aceEntry's constructor throws a DWORD error code when it cannot
		// build the ACE; otherwise the entry is passed to Insert().
		aceEntry entry(name, machine, mask, ACCESS_DENIED_ACE_TYPE, flag);
		(void)deniedList_.Insert(entry);
	}
	catch (DWORD code) {
		// A zero code is not treated as a failure.
		if(code != 0) {
			SetLastError(code);
			return FALSE;
		}
	}
	return TRUE;
}

//
// Add a denied ACE for a SIDobj&.
//
// Returns TRUE on success. If building the entry throws a non-zero DWORD,
// that code becomes the last error and FALSE is returned.
//
BOOL aclList::AddDeniedAce(const SIDobj& sid, ACCESS_MASK mask, BYTE flag) {
	try {
		// aceEntry's constructor throws a DWORD error code when it cannot
		// build the ACE; otherwise the entry is passed to Insert().
		aceEntry entry(sid, mask, ACCESS_DENIED_ACE_TYPE, flag);
		(void)deniedList_.Insert(entry);
	}
	catch (DWORD code) {
		// A zero code is not treated as a failure.
		if(code != 0) {
			SetLastError(code);
			return FALSE;
		}
	}
	return TRUE;
}

//
// Remove an allowed ACE for a username.
//
// Builds an entry equal to the one to remove and asks the allowed list to
// remove it. Returns FALSE with ERROR_INVALID_PARAMETER when Remove()
// returns non-zero (presumably "no matching entry" — see list.h). Returns
// FALSE with the thrown code when building the entry throws a non-zero DWORD.
// Otherwise returns TRUE.
//
BOOL aclList::RemoveAllowedAce(const TCHAR *name, const TCHAR *machine, ACCESS_MASK mask, BYTE flag) {
	try {
		aceEntry target(name, machine, mask, ACCESS_ALLOWED_ACE_TYPE, flag);
		if(allowedList_.Remove(target)) {
			SetLastErrorEx( ERROR_INVALID_PARAMETER, SLE_ERROR);
			return FALSE;
		}
	}
	catch (DWORD code) {
		// A zero code is not treated as a failure.
		if(code != 0) {
			SetLastError(code);
			return FALSE;
		}
	}
	return TRUE;
}

//
// Remove an allowed ACE for a SIDobj&.
//
// Builds an entry equal to the one to remove and asks the allowed list to
// remove it. Returns FALSE with ERROR_INVALID_PARAMETER when Remove()
// returns non-zero (presumably "no matching entry" — see list.h). Returns
// FALSE with the thrown code when building the entry throws a non-zero DWORD.
// Otherwise returns TRUE.
//
BOOL aclList::RemoveAllowedAce(const SIDobj& sid, ACCESS_MASK mask, BYTE flag) {
	try {
		aceEntry target(sid, mask, ACCESS_ALLOWED_ACE_TYPE, flag);
		if(allowedList_.Remove(target)) {
			SetLastErrorEx(ERROR_INVALID_PARAMETER, SLE_ERROR);
			return FALSE;
		}
	}
	catch (DWORD code) {
		// A zero code is not treated as a failure.
		if(code != 0) {
			SetLastError(code);
			return FALSE;
		}
	}
	return TRUE;
}

//
// Remove a denied ACE for a username.
//
// Builds an entry equal to the one to remove and asks the denied list to
// remove it. Returns FALSE with ERROR_INVALID_PARAMETER when Remove()
// returns non-zero (presumably "no matching entry" — see list.h). Returns
// FALSE with the thrown code when building the entry throws a non-zero DWORD.
// Otherwise returns TRUE.
//
BOOL aclList::RemoveDeniedAce(const TCHAR *name, const TCHAR *machine, ACCESS_MASK mask, BYTE flag) {
	try {
		aceEntry target(name, machine, mask, ACCESS_DENIED_ACE_TYPE, flag);
		if(deniedList_.Remove(target)) {
			SetLastErrorEx(ERROR_INVALID_PARAMETER, SLE_ERROR);
			return FALSE;
		}
	}
	catch (DWORD code) {
		// A zero code is not treated as a failure.
		if(code != 0) {
			SetLastError(code);
			return FALSE;
		}
	}
	return TRUE;
}

//
// Remove a denied ACE for a SIDobj&.
//
// Builds an entry equal to the one to remove and asks the denied list to
// remove it. Returns FALSE with ERROR_INVALID_PARAMETER when Remove()
// returns non-zero (presumably "no matching entry" — see list.h). Returns
// FALSE with the thrown code when building the entry throws a non-zero DWORD.
// Otherwise returns TRUE.
//
BOOL aclList::RemoveDeniedAce(const SIDobj& sid, ACCESS_MASK mask, BYTE flag) {
	try {
		aceEntry target(sid, mask, ACCESS_DENIED_ACE_TYPE, flag);
		if(deniedList_.Remove(target)) {
			SetLastErrorEx(ERROR_INVALID_PARAMETER, SLE_ERROR);
			return FALSE;
		}
	}
	catch (DWORD code) {
		// A zero code is not treated as a failure.
		if(code != 0) {
			SetLastError(code);
			return FALSE;
		}
	}
	return TRUE;
}

//
// Calculate the size needed for an ACL containing the added ACEs.
//
// The fixed part is the ACL header plus, for every entry, its ACE structure
// minus the SidStart placeholder DWORD. The variable part is each entry's
// SID length as reported by GetLengthSid.
//
DWORD aclList::calculateACLSize() const {
	DWORD total = sizeof(ACL)
		+ allowedList_.Count() * (sizeof(ACCESS_ALLOWED_ACE) - sizeof(DWORD))
		+ deniedList_.Count() * (sizeof(ACCESS_DENIED_ACE) - sizeof(DWORD));

	// Distinct iterator names keep this valid under pre-standard for-scope rules.
	for(ListIterator<aceEntry> allowedIt(&allowedList_); !allowedIt.Done(); allowedIt.Next())
		total += GetLengthSid((SID *)allowedIt.Get().getSid());

	for(ListIterator<aceEntry> deniedIt(&deniedList_); !deniedIt.Done(); deniedIt.Next())
		total += GetLengthSid((SID *)deniedIt.Get().getSid());

	return total;
}

//
// Builds an ACL in the caller-supplied buffer from the denied and allowed
// lists.
//
// Parameters:
//   pacl  - caller-allocated buffer that receives the ACL. It may be null
//           when the caller only wants the required size (e.g. *psize == 0).
//   psize - in: size of the pacl buffer in bytes. When that is too small it
//           is overwritten with the required size.
//
// Returns TRUE when the ACL was built. Returns FALSE with the last error set
// in these cases:
//   - ERROR_INVALID_PARAMETER: psize is null, or pacl is null while *psize
//     claims the buffer is large enough.
//   - ERROR_INSUFFICIENT_BUFFER: the buffer is too small; *psize is updated.
//   - InitializeAcl or AddAce fails; their error is left in place.
//
BOOL aclList::createACLfromList(ACL *pacl, DWORD *psize) const {

	// Without psize there is nowhere to read the buffer size from, nor to
	// report the required size to.
	if(psize == 0) {
		SetLastErrorEx(ERROR_INVALID_PARAMETER, SLE_ERROR);
		return FALSE;
	}

	DWORD size = calculateACLSize();

	//
	// If there isn't enough room just set the required size then return.
	//
	if(*psize < size ) {
		*psize = size;
		SetLastErrorEx(ERROR_INSUFFICIENT_BUFFER, SLE_WARNING);
		return FALSE;
	}

	// A size query may pass a null buffer, but a null buffer that is claimed
	// to be large enough cannot be written to.
	if(pacl == 0) {
		SetLastErrorEx(ERROR_INVALID_PARAMETER, SLE_ERROR);
		return FALSE;
	}

	if(!InitializeAcl( pacl, *psize, ACL_REVISION))
		return FALSE;

	// Add the deniedList ACEs, then the allowedList ACEs. MAXDWORD appends
	// each ACE at the end of the ACL.
	// NB. The deniedList *MUST* be done first: access checks walk the ACEs in
	// order, so deny entries must precede allow entries to take effect.
	{
		ListIterator<aceEntry> iter(&deniedList_);
		for( ; !iter.Done(); iter.Next()) {
			if(!AddAce( pacl, ACL_REVISION, MAXDWORD, (void *)iter.Get().getAce(),
						(DWORD)iter.Get().getAceSize()))
				return FALSE;
		}
	}
	{
		ListIterator<aceEntry> iter(&allowedList_);
		for( ; !iter.Done(); iter.Next()) {
			if(!AddAce( pacl, ACL_REVISION, MAXDWORD, (void *)iter.Get().getAce(),
						(DWORD)iter.Get().getAceSize()))
				return FALSE;
		}
	}
	return TRUE;
}
