125 lines
3.6 KiB
Python
125 lines
3.6 KiB
Python
import argparse
|
|
import logging
|
|
from functools import wraps
|
|
|
|
import paho.mqtt.client as mqtt
|
|
import libvirt
|
|
|
|
# Human-readable names for libvirt domain state codes. The integer key is the
# first element of the tuple returned by ``domain.state(0)`` (see
# LibvirtSpejsIOTClient.publish_state). Codes are consecutive from 0, so the
# mapping is built from list positions.
DOMAIN_STATES = dict(enumerate([
    "nostate",
    "running",
    "blocked",
    "paused",
    "shutdown",
    "shutoff",
    "crashed",
    "pmsuspended",
]))


# Command-line interface of the adapter.
parser = argparse.ArgumentParser(
    description='libvirt-to-spejsiot adapter',
)
# Passed straight to libvirt.open(); when omitted the value is None.
parser.add_argument(
    '--libvirt',
    help='libvirt connection URI',
)
# Host handed to mqtt.Client.connect().
parser.add_argument(
    '--broker',
    default='mqtt.waw.hackerspace.pl',
    help='MQTT broker',
)


def report(fn):
    """Wrap *fn* so that its exceptions are logged instead of propagated.

    The wrapper returns whatever *fn* returns. If *fn* raises an ordinary
    exception (a subclass of ``Exception``), the wrapper logs it with its
    traceback via ``logging.exception`` and returns None.

    Interpreter-control exceptions (``KeyboardInterrupt``, ``SystemExit``,
    ``GeneratorExit``) are deliberately NOT caught. They still propagate, so
    the process can be interrupted or shut down while inside a wrapped
    callback.
    """
    @wraps(fn)
    def wrapped(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception:
            logging.exception('An error occured')
            return None

    return wrapped


class LibvirtSpejsIOTClient(mqtt.Client):
    """MQTT client exposing selected libvirt domains as SpejsIoT nodes.

    For every domain named in ``config.DOMAINS`` the client publishes the
    domain state, retained, to ``<TOPIC_PREFIX><domain>/state``. It accepts
    state commands on ``<TOPIC_PREFIX><domain>/state/set``.

    NOTE(review): topics are built by plain concatenation, so TOPIC_PREFIX
    presumably ends with '/'. Confirm against the config module.
    """

    # Set per instance from config.TOPIC_PREFIX in __init__.
    topic_prefix = None

    def __init__(self, config, *args, **kwargs):
        """Create the client.

        :param config: object providing ``TOPIC_PREFIX`` (prepended to every
            topic) and ``DOMAINS`` (names of the libvirt domains this adapter
            may report on and control).
        Remaining positional/keyword arguments are forwarded to
        ``mqtt.Client``.
        """
        super(LibvirtSpejsIOTClient, self).__init__(*args, **kwargs)

        self.logger = logging.getLogger(self.__class__.__name__)
        self.config = config
        self.topic_prefix = self.config.TOPIC_PREFIX
        # libvirt connection. run() opens it before connecting to the broker,
        # because on_connect already uses it.
        self.conn = None

    @report
    def on_connect(self, client, userdata, flags, rc):
        """paho-mqtt connect callback: subscribe to commands, publish states.

        A non-zero ``rc`` means the broker refused the connection. In that
        case nothing is subscribed or published.
        """
        if rc != 0:
            self.logger.error('Connection refused (rc=%r)', rc)
            return

        self.subscribe(self.topic_prefix + '+/+/set')
        self.logger.info('Connected')

        # Publish the current state of every configured domain. Each lookup
        # is handled separately so one missing domain does not prevent the
        # others from being reported.
        for name in self.config.DOMAINS:
            try:
                dom = self.conn.lookupByName(name)
            except libvirt.libvirtError:
                self.logger.exception('Cannot look up domain %r', name)
                continue
            self.publish_state(dom)

    @report
    def on_message(self, client, userdata, msg):
        """paho-mqtt message callback: apply a state command to a domain.

        Expects topics ``<TOPIC_PREFIX><domain>/state/set``. Accepted payloads
        are ``running``, ``paused``, ``pmsuspended``, ``shutoff``/``shutdown``,
        ``kill`` and ``reset``. Other payloads are logged and ignored.
        """
        topic = msg.topic[len(self.topic_prefix):]
        self.logger.info('mqtt -> %r %r', topic, msg.payload)

        node, attrib, _ = topic.split('/')

        payload = msg.payload
        if isinstance(payload, bytes):
            # On Python 3, paho delivers payloads as bytes. Without decoding,
            # the comparisons with str literals below would never match.
            payload = payload.decode('utf-8', 'replace')

        # Empty payloads carry no command. This also covers the
        # retained-message clear that this handler publishes below, which
        # arrives back on the same subscribed topic.
        if not payload:
            return

        if node not in self.config.DOMAINS:
            self.logger.warning('Forbidden domain')
            return

        if attrib != 'state':
            self.logger.warning('Unknown attribute')
            return

        # We don't like persistence here: clear any retained copy of this
        # command so it is not picked up again later.
        self.publish(msg.topic, '', retain=True)

        dom = self.conn.lookupByName(node)
        state = dom.state(0)[0]
        if payload == 'running':
            # "running" resumes or wakes a domain where possible; otherwise it
            # starts it.
            if state == libvirt.VIR_DOMAIN_PAUSED:
                dom.resume()
            elif state == libvirt.VIR_DOMAIN_PMSUSPENDED:
                dom.pMWakeup()
            else:
                dom.create()
        elif payload == 'paused':
            dom.suspend()
        elif payload == 'pmsuspended':
            dom.pMSuspendForDuration(0, 0)
        elif payload in ['shutoff', 'shutdown']:
            dom.shutdown()
        elif payload == 'kill':
            dom.destroy()
        elif payload == 'reset':
            dom.reset()
        else:
            self.logger.warning('Unknown command %r', payload)

    @report
    def lifecycle_callback(self, connection, domain, event, detail, console):
        """libvirt lifecycle event callback: republish the domain's state.

        ``console`` receives the opaque argument passed to
        ``domainEventRegisterAny`` in run() (the libvirt connection). It is
        not used here.
        """
        self.logger.info('%s -> %r / %r', domain.name(), event, detail)
        self.publish_state(domain)

    def publish_state(self, domain):
        """Publish *domain*'s state name, retained, if it is a configured domain.

        The state name comes from DOMAIN_STATES. A code missing from that
        table yields None.
        """
        domname = domain.name()
        if domname not in self.config.DOMAINS:
            return

        state = DOMAIN_STATES.get(domain.state(0)[0])
        self.publish('%s%s/state' % (self.topic_prefix, domname), state,
                     qos=0, retain=True)

    def run(self, args):
        """Open libvirt, connect to MQTT, then dispatch libvirt events forever.

        :param args: parsed command line. ``args.libvirt`` is the libvirt URI
            (may be None) and ``args.broker`` is the MQTT broker host.
        """
        # Register libvirt's default event loop implementation before opening
        # the connection so that domain events can be delivered.
        libvirt.virEventRegisterDefaultImpl()

        self.conn = libvirt.open(args.libvirt)
        self.conn.domainEventRegisterAny(
            None, libvirt.VIR_DOMAIN_EVENT_ID_LIFECYCLE,
            self.lifecycle_callback, self.conn)

        # MQTT network traffic runs in paho's background thread (loop_start).
        # This thread only services the libvirt event loop.
        self.connect(args.broker)
        self.loop_start()

        while True:
            self.logger.debug('*beep*')
            libvirt.virEventRunDefaultImpl()


if __name__ == "__main__":
|
|
logging.basicConfig(level=logging.DEBUG)
|
|
args = parser.parse_args()
|
|
|
|
import config
|
|
LibvirtSpejsIOTClient(config).run(args)
|