{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Fitting in different orbital bases\n", "\n", "In this tutorial, we show how to perform orbit fits in the different coordinate bases supported by `orbitize`. Currently, fitting in different bases is only supported with MCMC, so we will use MCMC to perform an orbit fit in an orbital basis distinct from the default one. For a general introduction to MCMC, be sure to check out the [MCMC Introduction tutorial](https://orbitize.readthedocs.io/en/latest/tutorials/MCMC_tutorial.html) first!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The \"standard\" and \"XYZ\" bases" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The default way to define an orbit in `orbitize` is through what we call the 'standard basis', which consists of eight parameters: semi-major axis (sma), eccentricity (ecc), inclination (inc), argument of periastron (aop), position angle of the nodes (pan), epoch of periastron expressed as a fraction of the period past a reference epoch (tau), parallax (plx) and total system mass (mtot). Each orbital element has an associated default prior; to see how to explore and modify these priors, check out the [Modifying priors tutorial](http://orbitize.info/en/latest/tutorials/Modifying_Priors.html).\n", "\n", "An alternative way to define an orbit is through its position and velocity components in XYZ space at a given epoch; we will call this the 'XYZ basis'. The orbit is thus defined by the array ($x$, $y$, $z$, $\\dot{x}$, $\\dot{y}$, $\\dot{z}$, plx, mtot), with position coordinates measured in AU and velocity components in $\\text{km s}^{-1}$. In this basis, the sky-plane coordinates ($x,y$) are the separations of the planet relative to the primary, with the positive $x$ and $y$ directions coinciding with the positive RA and Dec directions, respectively. 
The $z$ direction is the line-of-sight coordinate, such that movement in the positive $z$ direction causes a redshift. The default priors are all uniform." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setting up the Sampler in the XYZ basis" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The easiest way to run an orbit fit in an alternative orbital basis in `orbitize` is through the `orbitize.driver.Driver` interface. The process is the same as initializing a regular `Driver` object, except that the `fitting_basis` keyword (passed through `system_kwargs`) is set to 'XYZ':" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Converting ra/dec data points in data_table to sep/pa. Original data are stored in input_table.\n" ] } ], "source": [ "import numpy as np\n", "\n", "import orbitize\n", "from orbitize import driver\n", "import multiprocessing as mp\n", "\n", "filename = \"{}xyz_test_data.csv\".format(orbitize.DATADIR) # data file with RA/Dec input (currently the only input format that works with the XYZ basis)\n", "\n", "# system parameters\n", "num_secondary_bodies = 1\n", "system_mass = 1.75 # [Msol]\n", "plx = 51.44 # [mas]\n", "mass_err = 0.05 # [Msol]\n", "plx_err = 0.12 # [mas]\n", "\n", "# MCMC parameters\n", "num_temps = 5\n", "num_walkers = 20\n", "num_threads = mp.cpu_count() # or a different number if you prefer\n", "\n", "\n", "my_driver = driver.Driver(\n", "    filename, 'MCMC', num_secondary_bodies, system_mass, plx, mass_err=mass_err, plx_err=plx_err,\n", "    mcmc_kwargs={'num_temps': num_temps, 'num_walkers': num_walkers, 'num_threads': num_threads},\n", "    system_kwargs={'fitting_basis': 'XYZ'}\n", ")\n", "\n", "s = my_driver.sampler" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### (Properly) initializing walkers in the XYZ basis" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the standard basis, at this point we would be ready to use the `s.run_sampler` method to start 
the sampling, but with the XYZ basis we have to make sure that all our walkers are initialized in a valid region of parameter space. This is because randomly generated values of ($x$, $y$, $z$, $\\dot{x}$, $\\dot{y}$, $\\dot{z}$) can result in unbound, invalid orbits with, for example, negative eccentricities (which is not cool). This can be easily corrected with the `s.validate_xyz_positions` method:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "All walker positions validated.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/sblunt/Projects/orbitize/orbitize/basis.py:944: RuntimeWarning: invalid value encountered in arccos\n", " eanom = np.arccos(cos_eanom)\n", "/data/user/sblunt/miniconda3/envs/python3.7/lib/python3.7/site-packages/astropy/units/quantity.py:481: RuntimeWarning: invalid value encountered in sqrt\n", " result = super().__array_ufunc__(function, method, *arrays, **kwargs)\n" ] } ], "source": [ "s.validate_xyz_positions()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "___________________" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After this is done, the sampler can be run and the results saved normally:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/data/user/sblunt/miniconda3/envs/python3.7/lib/python3.7/site-packages/astropy/table/column.py:1020: RuntimeWarning: invalid value encountered in greater\n", " result = getattr(super(), op)(other)\n", "/home/sblunt/Projects/orbitize/orbitize/kepler.py:112: RuntimeWarning: invalid value encountered in sqrt\n", " tanom = 2.*np.arctan(np.sqrt((1.0 + ecc)/(1.0 - ecc))*np.tan(0.5*eanom))\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Starting Burn in\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ 
] }, { "name": "stdout", "output_type": "stream", "text": [ "10/10 steps of burn-in complete\n", "Burn in complete. Sampling posterior now.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "30/30 steps completed\n", "Run complete\n" ] } ], "source": [ "total_orbits = 600 # number of steps x number of walkers (at lowest temperature)\n", "burn_steps = 10 # steps to burn in per walker\n", "thin = 2 # only save every 2nd step\n", "\n", "s.run_sampler(total_orbits, burn_steps=burn_steps, 
thin=thin)\n", "s.results.save_results('my_posterior.hdf5')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Loading and converting results" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can load the results as you normally would. The orbit posteriors are saved in the `results.post` attribute, and the basis you used for the fit in the `results.fitting_basis` attribute:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Converting ra/dec data points in data_table to sep/pa. Original data are stored in input_table.\n", "The used basis for the fit was XYZ\n", "The posteriors are [[-1.55895340e+01 -3.20352269e+01 9.38119986e+00 ... -7.46195838e-03\n", " 5.15299823e+01 1.72400897e+00]\n", " [-1.56081092e+01 -3.20562773e+01 1.49810951e+00 ... 8.53387015e-02\n", " 5.13883909e+01 1.73579787e+00]\n", " [-1.55612462e+01 -3.20801251e+01 -2.96308303e-01 ... 7.18261775e-01\n", " 5.14206555e+01 1.76608956e+00]\n", " ...\n", " [-1.55671475e+01 -3.20089823e+01 3.62094122e+01 ... 3.72598271e-01\n", " 5.15487665e+01 1.74432532e+00]\n", " [-1.55794884e+01 -3.20303979e+01 2.75912018e+01 ... -2.40407388e-01\n", " 5.14895599e+01 1.78876637e+00]\n", " [-1.55837449e+01 -3.20532712e+01 1.93382300e+01 ... -1.24185021e-01\n", " 5.14598433e+01 1.80483891e+00]]\n" ] } ], "source": [ "myResults = orbitize.results.Results() # create empty Results object\n", "myResults.load_results('my_posterior.hdf5') \n", "print('The used basis for the fit was ', myResults.fitting_basis)\n", "print('The posteriors are ', myResults.post)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's convert back to the good old standard basis:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "My posterior in standard basis is [[ 7.89543529 6.92670011 4.84646336 ... 
0.36153132 -6.60018059\n", " 1.52254522]\n", " [ 1.00000429 1.00000659 0.96238341 ... 1.00053929 1.00014326\n", " 0.9999977 ]\n", " [ 1.43499876 1.22274162 0.79516331 ... 1.46937244 2.46248534\n", " 1.23409917]\n", " ...\n", " [-15.56714749 -32.00898233 36.20941224 ... 0.37259827 51.54876654\n", " 1.74432532]\n", " [-15.57948843 -32.03039793 27.59120179 ... -0.24040739 51.48955988\n", " 1.78876637]\n", " [-15.58374489 -32.0532712 19.33822997 ... -0.12418502 51.45984325\n", " 1.80483891]]\n" ] } ], "source": [ "xyz_posterior = myResults.post\n", "\n", "standard_posterior = myResults.system.basis.to_standard_basis(xyz_posterior)\n", "\n", "print('My posterior in standard basis is ', standard_posterior)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And we're done! Enjoy the XYZ basis." ] } ], "metadata": { "interpreter": { "hash": "e899b22145868d3cd465733d82c36c2ae3ac0d3591d6a0807ec2e5e577a9cf5c" }, "kernelspec": { "display_name": "Python 3.7.5 64-bit ('python3.7': conda)", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }