#!/usr/bin/env python

import rosbag

import datetime
from tf.msg import tfMessage
from argparse import ArgumentParser
from geometry_msgs.msg import Quaternion
from numpy import mean, array, hypot, diff, convolve, arange, sin, cos, ones, pi, matrix
from tf.transformations import euler_from_quaternion,quaternion_from_euler,quaternion_multiply,quaternion_matrix
import tf

# scipy is a hard requirement: explain how to get it and stop if it is missing.
try:
  from scipy import optimize
except ImportError:
  print("This script requires scipy be available.")
  print("On Ubuntu: sudo apt-get install python-scipy")
  # `exit` is a helper injected by the `site` module for interactive use and
  # is not guaranteed to exist; raising SystemExit works everywhere.
  raise SystemExit(1)

# Plotting is an optional extra; without matplotlib the script still runs and
# `pyplot` is left as None so later code can test for it.
try:
  import matplotlib.pyplot as pyplot
  # Axes3D is not referenced by name in this script; the import is presumably
  # what makes the '3d' projection available on older matplotlib -- confirm
  # before removing.
  from mpl_toolkits.mplot3d import Axes3D
except ImportError:
  pyplot = None

def _parse_bool(text):
  """Convert a command-line word such as 'true' or '0' into a bool.

  Raises ValueError for unrecognised words, which argparse turns into a
  usage error. The empty string maps to False.
  """
  word = text.strip().lower()
  if word in ('1', 'true', 't', 'yes', 'y', 'on'):
    return True
  if word in ('', '0', 'false', 'f', 'no', 'n', 'off'):
    return False
  raise ValueError("expected a boolean, got %r" % (text,))


parser = ArgumentParser(description='Process bag file for compass calibration. Pass a bag containing /imu/mag topic, with the compass facing upright, being slowly rotated in a clockwise direction for 30-120 seconds.')
parser.add_argument('bag', metavar='FILE', type=str, help='input bag file')
parser.add_argument('outfile', metavar='OUTFILE', type=str, help='output yaml file',
                    nargs="?", default="/tmp/calibration.yaml")
# type=bool would treat every non-empty word (including "False") as True, so
# the value is parsed explicitly. A bare `--plots` is accepted and means True.
parser.add_argument('--plots', type=_parse_bool, nargs='?', const=True, default=False,
                    help='Show plots if matplotlib available.')
args = parser.parse_args()

# Plots are opt-in: drop the pyplot handle unless the user asked for them.
if not args.plots:
  pyplot = None

# Collect every magnetometer reading as an (x, y, z) tuple of floats. The
# `with` block closes the bag file once reading is done.
with rosbag.Bag(args.bag) as bag:
  vecs = [(float(msg.vector.x), float(msg.vector.y), float(msg.vector.z))
          for _, msg, _ in bag.read_messages(topics=("/imu/mag",))]

print ("Using " + str(len(vecs)) + " samples.")

def calc_R(xc, yc, xs=None, ys=None):
    """Return the distance of each 2D point from the centre (xc, yc).

    xs, ys -- optional point coordinates; when omitted, the module-level
              sample sequences x and y are used.

    Returns a numpy array holding one distance per point.
    """
    # Convert to arrays so the subtraction also works when the samples are
    # plain tuples/lists and the centre is given as built-in floats.
    px = array(x if xs is None else xs)
    py = array(y if ys is None else ys)
    return hypot(px - xc, py - yc)

def f_2(c):
    """Residuals for the circle fit centred at c=(xc, yc).

    For every sample point, returns its distance from c minus the mean of
    all those distances.
    """
    radii = calc_R(*c)
    mean_radius = radii.mean()
    return radii - mean_radius


# A circle fit needs at least three points; stop with a clear message rather
# than failing later inside the unpacking or the optimizer.
if len(vecs) < 3:
  raise SystemExit("Need at least 3 /imu/mag samples to fit a circle, got %d." % len(vecs))

# Split the samples into per-axis arrays. calc_R() reads x and y as globals.
x, y, z = array(vecs).T

# Start the least-squares search from the centroid of the samples.
center_estimate = mean(x), mean(y)
center, ier = optimize.leastsq(f_2, center_estimate)
# leastsq signals that a solution was found with a status flag of 1, 2, 3 or 4.
if ier not in (1, 2, 3, 4):
  print("Warning: circle fit did not converge (leastsq status " + str(ier) + "); calibration may be inaccurate.")
radius = calc_R(*center).mean()
# The fit is 2D; the z component of the centre is the mean of the z samples.
center = (center[0], center[1], mean(z))

print("Magnetic circle centered at " + str(center) + ", with radius " + str(radius))

with open(args.outfile, "w") as f:
  f.write("# Generated from %s on %s.\n" % (args.bag, datetime.date.today()))
  f.write("use_mag: true\n")
  f.write("mag_bias_x: {:.9f}\n".format(center[0]))
  f.write("mag_bias_y: {:.9f}\n".format(center[1]))
  f.write("mag_bias_z: {:.9f}\n".format(center[2]))

print("Calibration file written to " + args.outfile)

if pyplot:
  # Points on the fitted circle, drawn at the height of the centre.
  a = arange(0, 2*pi + pi/50, pi/50)
  circle_points = (center[0] + cos(a) * radius,
                   center[1] + sin(a) * radius,
                   center[2] * ones(len(a)))

  fig = pyplot.figure()
  ax2 = fig.add_subplot(111, projection='3d')
  ax2.view_init(elev=80, azim=5)
  ax2.scatter(x, y, z)
  ax2.plot(*circle_points, c="red")
  pyplot.show()

