Skip to content

Commit

Permalink
Add 3D
Browse files Browse the repository at this point in the history
  • Loading branch information
hdsingh committed Aug 3, 2020
1 parent b596923 commit f6dc17e
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 1 deletion.
4 changes: 4 additions & 0 deletions xrviz/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import xarray as xr
from .sigslot import SigSlot
from .display import Display
from .display3d import Display3d
from .describe import Describe
from .fields import Fields
from .style import Style
Expand Down Expand Up @@ -53,10 +54,12 @@ def __init__(self, data):
super().__init__()
self.data = data
self.displayer = Display(self.data)
self.displayer3d = Display3d(self.data)
self.describer = Describe(self.data)
self.fields = Fields(self.data)
self.style = Style()
self.coord_setter = CoordSetter(self.data)

self.tabs = pn.Tabs(
pn.Column(
pn.pane.Markdown(TEXT, margin=(0, 10)),
Expand All @@ -67,6 +70,7 @@ def __init__(self, data):
self.coord_setter.panel,
self.fields.panel,
self.style.panel,
self.displayer3d.panel,
background='#f5f5f5',
width_policy='max',
tabs_location='left',
Expand Down
4 changes: 3 additions & 1 deletion xrviz/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ def __init__(self, data, initial_params={}):

self.panel = pn.Column(self.control.panel,
pn.Row(self.plot_button,
self.clear_series_button),
self.clear_series_button,
self.control.displayer3d.plot_button_3d),
self.control.displayer3d.data_cube,
self.output,
self.series_graph, width_policy='max')

Expand Down
134 changes: 134 additions & 0 deletions xrviz/display3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
from .utils3d import *
from xrviz.sigslot import SigSlot
from .selection3d import Selection3d
import numpy as np
import panel as pn
import plotly
import plotly.graph_objects as go
import plotly.express as px
import xarray as xr
import warnings
from textwrap import wrap
import json
pn.extension('plotly')
warnings.filterwarnings('ignore')

class Display3d(SigSlot):
"""Displays a list of data variables for selection.
"""

def __init__(self, data, skip_lat=20, skip_lon=20, skip_lev=2,colorscale="inferno"):
super().__init__()
self.data = data
self.coords = self.data.coords
self.sel3d = Selection3d()
self.skip_lat = int(self.sel3d.kwargs['skip_lat'])
self.skip_lon = int(self.sel3d.kwargs['skip_lon'])
self.skip_lev = int(self.sel3d.kwargs['skip_lev'])
self.X_map, self.Y_map = self.get_map_X_Y()

self.colorscale = self.sel3d.kwargs['cmap_3d']
self.levs_is_sorted = is_sorted(self.data['lev'])
self.X,self.Y, self.Z = np.meshgrid(
self.data['lon'][::skip_lon],
self.data['lat'][::skip_lat],
check_levs(self.levs_is_sorted,self.data['lev'][::skip_lev]),
indexing='xy')

self.name = 'Variables 3D'
self.select_var_3d = pn.widgets.Select(
min_width=100, max_width=200, width_policy='max',
name=self.name)
self.time_select_3d = pn.widgets.Select(
min_width=100, max_width=200, width_policy='max',
name='Time')
self.plot_button_3d = pn.widgets.Button(name='Plot 3D', width=200)
self.data_cube = pn.Row(pn.Spacer(name='Series Graph'))

self.panel = pn.Column(
pn.Column(
pn.Row(self.select_var_3d, self.time_select_3d),
self.sel3d.panel,
),
name='3D Cube',
)

self.set_variables()
self.set_times()

self._register(self.select_var_3d, "variable_selected")
self._register(self.time_select_3d, "time_selected")
self._register(self.plot_button_3d, 'plot_clicked', 'clicks')
self.connect('plot_clicked', self.create_cube)

def set_variables(self,):
self.select_var_3d.options = [var for var in list(self.data.variables) if var not in self.coords]

def set_times(self,):
self.time_select_3d.options = [val for val in self.data.time.data]

def create_cube(self, _):
var =self.select_var_3d.value
tval = self.time_select_3d.value
self.ds = self.data[var].sel(time=tval).data#.compute()

self.ds = np.transpose(self.ds, (2,1,0))
isomin = round(self.ds.min())
isomax = round(self.ds.max())

volume = plotly.graph_objects.Volume(
x=self.X.flatten(),
y=self.Y.flatten(),
z=self.Z.flatten(),
value=get_vals(self.levs_is_sorted,
self.ds,
self.skip_lon,
self.skip_lat,
self.skip_lev).flatten(),
isomin=isomin,
isomax=isomax,
opacity=self.sel3d.kwargs['opacity'], # needs to be small to see through all surfaces
surface_count = self.sel3d.kwargs['surface_count'], # needs to be a large number for good volume rendering,
colorscale = self.sel3d.kwargs['cmap_3d']
)

plotly_layout = go.Layout(
# title = 'India 3D Plot',
autosize = False,
width = self.sel3d.kwargs['frame_width_3d'],
height = self.sel3d.kwargs['frame_height_3d'],
margin = dict(t=100, b=100, r=100, l=100),
scene = dict(
xaxis_title='Longitude',
yaxis_title='Latitude',
zaxis_title='Pressure',
zaxis_autorange = "reversed",
))
map_lines = go.Scatter3d(x = self.X_map,
y = self.Y_map,
z = [1000]*len(self.X_map),
mode='lines')

fig = dict(data=[volume, map_lines], layout=plotly_layout)
self.data_cube[0] = pn.pane.Plotly(fig)

def get_map_X_Y(self,):
with open("data/INDIA_STATES.json") as json_file:
jdata = json_file.read()
geoJSON = json.loads(jdata)
pts=[]#list of points defining boundaries of polygons
for feature in geoJSON['features']:
if feature['geometry']['type']=='Polygon':
pts.extend(feature['geometry']['coordinates'][0])
pts.append([None, None])#mark the end of a polygon

elif feature['geometry']['type']=='MultiPolygon':
for polyg in feature['geometry']['coordinates']:
pts.extend(polyg[0])
pts.append([None, None])#end of polygon
else: raise ValueError("geometry type irrelevant for map")
return zip(*pts)

@property
def kwargs(self):
return {self.name: self.select.value}
40 changes: 40 additions & 0 deletions xrviz/selection3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import panel as pn
import plotly.express as px
from .sigslot import SigSlot

class Selection3d(SigSlot):
def __init__(self):
super().__init__()
self.skip_lon = pn.widgets.TextInput(name='skip_lon', value = "20", width=140)
self.skip_lat = pn.widgets.TextInput(name='skip_lat', value = "20", width=140)
self.skip_lev = pn.widgets.TextInput(name='skip_lev', value = "2", width=140)

self.frame_height_3d = pn.widgets.IntSlider(name='frame_height_3d', value=600, start=100,
end=1200)
self.frame_width_3d = pn.widgets.IntSlider(name='frame_width_3d', value=600, start=100,
end=1200)
self.surface_count = pn.widgets.IntSlider(name='surface_count', value=30, start=0,
end=100)
self.cmap_3d = pn.widgets.Select(name='cmap_3d', value='inferno',
options=px.colors.named_colorscales())
self.opacity = pn.widgets.FloatSlider(name='opacity', start=0, end=1,
step=0.01, value=0.2, width=300)

# self.lower_limit = pn.widgets.TextInput(name='cmap lower limit', width=140)
# self.upper_limit = pn.widgets.TextInput(name='cmap upper limit', width=140)
TEXT = """Customize the Cube."""

self.panel = pn.Column(
pn.pane.Markdown(TEXT, margin=(0, 10)),
pn.Row(self.skip_lon, self.skip_lat, self.skip_lev),
pn.Row(self.frame_height_3d, self.frame_width_3d),
pn.Row(self.opacity, self.surface_count),
pn.Row(self.cmap_3d),
name='Selection_3d'
)

@property
def kwargs(self):
out = {widget.name: widget.value
for row in self.panel[1:] for widget in row}
return out
20 changes: 20 additions & 0 deletions xrviz/utils3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import numba

@numba.jit
def is_sorted(a):
for i in range(a.size-1):
if a[i+1] < a[i] :
return False
return True

def check_levs(is_sorted,levs):
if is_sorted:
return levs
else:
return levs[::-1]

def get_vals(is_sorted,arr,skip_lon,skip_lat,skip_lev):
if is_sorted:
return arr[::skip_lon,::skip_lat,::skip_lev]
else:
return (arr[::skip_lon,::skip_lat,::skip_lev])[::,::,::-1]

0 comments on commit f6dc17e

Please sign in to comment.