#!/usr/bin/python

# this is an extauth script for use with ejabberd.
# it uses a binary protocol for communicating via stdin/stdout
# see ejabberd.jabber.ru/extauth for more examples
# this script uses the PostgreSQL database from our Django mailadmin app
# for authentication
# see http://einfachkaffee.de/svn/mails/trunk for details
# this script totally relies on the views we created in postgresql for our
# db to function. So, unless you're installing all of this app or are willing
# to rewrite the SQL, this is probably useless to you.
#
#
# Author: Lukas Kolbe <lukas@einfachkaffee.de>


import sys, logging, psycopg, os
sys.stderr = open('/var/log/ejabberd/extauth_err.log', 'a')
from struct import *

# The file given as the first command-line argument must contain exactly one
# line of the form:
# dbname= user= password= host= port=
# (this connection string is passed to psycopg.connect() in dbexec()).

conn_file = open(sys.argv[1])
try:
    # drop the trailing newline that readline() keeps
    conn_string = conn_file.readline().strip()
finally:
    # close the file even if reading it fails
    conn_file.close()

# Password hash scheme used by check_pw()/make_pw(): "crypt" or "md5".
# Anything else makes every authentication fail.
PW_TYPE = "crypt"

# The original note says ejabberd starts one instance of this script per
# served domain. All instances append to the same fixed log file below, so
# their lines may interleave; a per-instance (e.g. pid-based) file name was
# considered but is not implemented.
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s %(levelname)s %(message)s',
                    filename='/var/log/ejabberd/extauth.log',
                    filemode='a')
logging.info('extauth script started, waiting for ejabberd requests')


class EjabberdInputError(Exception):
    """Signals that the request data read from ejabberd is unusable.

    The payload passed to the constructor is kept in ``value``; converting
    the exception to a string yields the repr() of that payload.
    """

    def __init__(self, value):
        self.value = value

    def __str__(self):
        return "%r" % (self.value,)

def check_pw(passwd, passwd_crypt):
    """Return True if the clear-text *passwd* matches the stored hash.

    passwd       -- password as received from ejabberd
    passwd_crypt -- value of xmpp_user.password; its format depends on
                    PW_TYPE ("crypt": crypt() output, "md5": hex digest)

    Returns False for an unknown PW_TYPE and for a missing (None/NULL) or
    empty stored hash.
    """
    # A NULL column would make crypt() raise and take the whole auth script
    # down; an empty stored hash means "no password" and must never match.
    if not passwd_crypt:
        return False
    if PW_TYPE == "crypt":
        from crypt import crypt
        # crypt() re-uses the salt embedded in the stored hash
        return crypt(passwd, passwd_crypt) == passwd_crypt
    elif PW_TYPE == "md5":
        try:
            from hashlib import md5
        except ImportError:
            # hashlib only exists from Python 2.5 on
            from md5 import md5
        return md5(passwd).hexdigest() == passwd_crypt
    else:
        return False

def make_pw(passwd):
    """Hash the clear-text *passwd* for storage in xmpp_user.password.

    The scheme follows PW_TYPE:
      "crypt" -- crypt() with a random two-character salt
      "md5"   -- unsalted hex MD5 digest
    Returns None for any other PW_TYPE.
    """
    if PW_TYPE == "crypt":
        import string
        from crypt import crypt
        from random import SystemRandom
        # Draw a proper two-character salt from the crypt alphabet
        # [a-zA-Z0-9./] using the OS randomness source. A decimal number as
        # salt only ever uses digits (100 of 4096 salts) and can even be a
        # single character.
        salt_alphabet = string.ascii_letters + string.digits + "./"
        rng = SystemRandom()
        salt = rng.choice(salt_alphabet) + rng.choice(salt_alphabet)
        return crypt(passwd, salt)
    elif PW_TYPE == "md5":
        try:
            from hashlib import md5
        except ImportError:
            # hashlib only exists from Python 2.5 on
            from md5 import md5
        return md5(passwd).hexdigest()
    else:
        return None
                            
    
def genanswer(bool):
    """Build the 4-byte reply that is written back to ejabberd.

    Layout: a big-endian 16-bit length (always 2) followed by a big-endian
    16-bit result, which is 1 when *bool* is truthy and 0 otherwise.
    """
    if not bool:
        return pack('>hh', 2, 0)
    return pack('>hh', 2, 1)
	
def ejabberd_out(bool):
    """Write the reply for one request to stdout.

    *bool* is encoded by genanswer(). The four reply bytes are logged at
    debug level. stdout is flushed right away so the reply is not held
    back in the output buffer.
    """
    logging.debug("Ejabberd gets: %s" % bool)
    reply = genanswer(bool)
    codes = (ord(reply[0]), ord(reply[1]), ord(reply[2]), ord(reply[3]))
    logging.debug("sent bytes: %#x %#x %#x %#x" % codes)
    out = sys.stdout
    out.write(reply)
    out.flush()

def ejabberd_in():
    """Read one request from ejabberd on stdin and split it into fields.

    A request is a 2-byte big-endian unsigned length followed by that many
    bytes of ':'-separated text. The result is split into at most four
    fields (method, user, server[, password]) so that a password containing
    ':' stays intact.

    Raises EjabberdInputError when stdin cannot be read, is closed, or
    delivers fewer bytes than announced.
    """
    logging.debug("trying to read 2 bytes from ejabberd:")
    try:
        input_length = sys.stdin.read(2)
    except IOError:
        logging.debug("ioerror")
        # input_length would be unbound here; report it as bad input
        raise EjabberdInputError('Could not read from ejabberd!')
    if len(input_length) != 2:
        logging.debug("ejabberd sent us wrong things!")
        raise EjabberdInputError('Wrong input from ejabberd!')
    logging.debug('got 2 bytes via stdin')

    # The length prefix is unsigned. A signed '>h' turns lengths above 32767
    # negative, and read(-n) would then block until EOF.
    (size,) = unpack('>H', input_length)
    payload = sys.stdin.read(size)
    if len(payload) != size:
        logging.debug("ejabberd sent us a truncated request!")
        raise EjabberdInputError('Truncated input from ejabberd!')
    # only the first three ':' are separators; the password may contain ':'
    return payload.split(':', 3)
	
def log_success(method, username, server, success):
    """Log, at INFO level, whether *method* succeeded for username@server."""
    if success:
        template = "%s successful for %s@%s"
    else:
        template = "%s unsuccessful for %s@%s"
    logging.info(template % (method, username, server))

def dbexec(query, args):
    """Run *query* with *args* on a fresh connection and return its rows.

    A new connection is opened from conn_string for every call. The
    statement is executed and committed, and the result of fetchall() is
    returned. Returns None when:
      - connecting, executing or committing raises psycopg.Error (the
        error is logged), or
      - the statement produced no result set (e.g. UPDATE). DB-API drivers
        raise an Error from fetchall() in that case; the commit has already
        happened.
    The connection is always closed before returning.
    """
    cnx = None
    try:
        try:
            cnx = psycopg.connect(conn_string)
            cr = cnx.cursor()
            cr.execute(query, args)
            cnx.commit()
        except psycopg.Error:
            # sys.exc_info() keeps this valid on every Python version
            logging.error("database error: %s" % sys.exc_info()[1])
            return None
        try:
            return cr.fetchall()
        except psycopg.Error:
            # no result set to fetch (UPDATE etc.) -- expected, stay quiet
            return None
    finally:
        if cnx is not None:
            try:
                cnx.close()
            except psycopg.Error:
                # best effort: a failing close must not mask the result
                pass
	
def auth(username, server, password):
    """Check *password* for username@server against the xmpp_user table.

    Returns the result of check_pw() when exactly one row matches, and
    False otherwise, including when dbexec() returns None.
    """
    logging.debug("%s@%s wants authentication ..." % (username, server))
    result = dbexec("""SELECT password FROM xmpp_user
                       WHERE address=%s AND "domain"=%s;""",
                       [username, server])
    # compare the row count by value; `is` tests object identity
    if result is not None and len(result) == 1:
        return check_pw(password, result[0][0])
    return False

def isuser(username, server):
    """Tell whether username@server exists in the xmpp_user table.

    True only when the COUNT query succeeds and reports exactly one
    matching row; False otherwise.
    """
    logging.debug("do we know %s@%s?" % (username, server))
    rows = dbexec("""SELECT COUNT(full_address) FROM xmpp_user
                       WHERE address=%s AND "domain"=%s;""",
                       [username, server])
    if rows is None:
        return False
    return rows[0][0] == 1

def setpass(username, server, newpassword):
    """Store a new password for username@server.

    The password is hashed with make_pw() and written with an UPDATE. The
    change is then verified by counting the rows that carry the new hash.
    Returns True if exactly one row does. Returns False otherwise,
    including when make_pw() cannot hash for the configured PW_TYPE.
    """
    newpw_crypted = make_pw(newpassword)
    logging.debug("setpass for %s@%s" % (username, server))
    if newpw_crypted is None:
        # unknown PW_TYPE: never overwrite the stored hash with NULL
        logging.debug("password change for %s@%s not successful" % (username, server))
        return False

    dbexec("""UPDATE xmpp_user 
        SET password=%s 
        WHERE address=%s AND "domain"=%s""",
        [newpw_crypted, username, server])

    result = dbexec("""SELECT count(full_address)
        FROM xmpp_user 
        WHERE password=%s AND address=%s AND "domain"=%s""",
        [newpw_crypted, username, server])

    # COUNT() is a bigint, which the driver presumably returns as a Python
    # long. `1L is 1` is False, so the count must be compared with ==.
    if result is not None and result[0][0] == 1:
        logging.debug("password change for %s@%s successful" % (username, server))
        return True
    else:
        logging.debug("password change for %s@%s not successful" % (username, server))
        return False

# Main loop: answer one ejabberd request per iteration until ejabberd_in()
# reports that stdin is closed, unreadable or malformed.
while True:
    logging.debug("start of infinite loop")

    try:
        data = ejabberd_in()
    except EjabberdInputError:
        # sys.exc_info() keeps this valid on every Python version
        logging.info("Exception occured: %s", sys.exc_info()[1])
        break

    method = data[0]
    logging.debug('Method: %s' % method)
    success = False
    # False when the request is for an unsupported method or lacks fields
    handled = True

    try:
        if method == "auth" and len(data) >= 4:
            success = auth(data[1], data[2], data[3])
        elif method == "isuser" and len(data) >= 3:
            success = isuser(data[1], data[2])
        elif method == "setpass" and len(data) >= 4:
            success = setpass(data[1], data[2], data[3])
        else:
            handled = False
    except Exception:
        # A failing handler must not kill the auth process: log and refuse.
        logging.exception("error while handling %s request" % method)
        success = False

    # Every request gets exactly one reply, including unsupported methods
    # (e.g. tryregister/removeuser) and malformed requests. Staying silent
    # would presumably leave ejabberd waiting on this script forever.
    ejabberd_out(success)
    if handled:
        log_success(method, data[1], data[2], success)
    else:
        logging.info("unsupported or malformed request: %s" % method)

    logging.debug("end of infinite loop")

logging.info('extauth script terminating')



