I'm trying to update data in a mayavi 3D plot. Some of the changes to the data don't affect the data shape, so the mlab_source.set() method can be used (which updates the underlying data and refreshes the display without resetting the camera, regenerating the VTK pipeline, regenerating the underlying data structure, etc.). This is the best possible case for animation or quick plot updates.
If the underlying data changes shape, the documentation recommends using the mlab_source.reset() method, which while not recreating the entire pipeline or messing up the current scene's camera, does cause the data structure to be rebuilt, costing some performance overhead. This is causing crashes for me.
The worst way to go is completely deleting the plot source and generating a new one with a new call to mlab.mesh() or whatever function was used to plot the data. This recreates a new VTK pipeline, new data structure, and resets the scene's view (loses current zoom and camera settings, which can make smooth interactivity impossible depending on the application).
I've illustrated a simple example from my application in which a Sphere class can have its properties manipulated (position, size, and resolution). While changing position and size causes the coordinates to refresh, the data remains the same size. However, changing the resolution affects the number of latitude and longitude subdivisions used to represent the sphere, which changes the number of coordinates. When attempting to use the "reset" function, the interpreter crashes completely. I'm pretty sure this is a C level segfault in the VTK code based on similar errors around the web. This thread seems to indicate the core developers were dealing with the problem almost 5 years ago, but I can't tell if it was truly solved.
I am on Mayavi 4.3.1 which I got along with the Enthought Tool Suite with Python(x, y). I'm on Windows 7 64-bit with Python 2.7.5. I'm using PySide, but I removed those calls and let mlab work by itself for this example.
Here's my example that shows mlab_source.set() working but crashes on mlab_source.reset(). Any thoughts on why it's crashing? Can others duplicate it? I'm pretty sure there are other ways to update the data through the source (TVTK?) object, but I can't find it in the docs and the dozens of traits related attributes are very difficult to wade through.
Any help is appreciated!
#Numpy Imports
from time import sleep
import numpy as np
from numpy import sin, cos, pi
class Sphere(object):
    """
    Sphere described by a center, a radius and an arc resolution
    - The surface coordinates are only recalculated when one of those
      inputs has changed since the last time they were requested
    """

    def __init__(self, c=None, r=None, n=None):
        #Initial defaults
        self._coordinates = None
        self._c = np.array([0.0, 0.0, 0.0])
        self._r = 1.0
        self._n = 20
        self._hash = []
        self._required_inputs = [('c', list),
                                 ('r', float)]
        #Assign inputs through the property setters so they get validated
        self.c = self._c if c is None else c
        self.r = self._r if r is None else r
        self.n = self._n if n is None else n

    @property
    def c(self):
        """
        Center point of sphere
        - Point is specified as a cartesian coordinate triplet, [x, y, z]
        - Coordinates are stored as a numpy array
        - Coordinates input as a list or tuple will be coerced to a numpy array
        """
        return self._c

    @c.setter
    def c(self, val):
        if isinstance(val, (list, tuple)):
            val = np.asarray(val)
        self._c = val

    @property
    def r(self):
        """
        Radius of sphere
        """
        return self._r

    @r.setter
    def r(self, val):
        if val < 0:
            raise ValueError("Sphere radius input must be positive")
        self._r = val

    @property
    def n(self):
        """
        Resolution of curvature
        - Number of points used to represent circles and arcs
        - For a sphere, n is the number of subdivisions per hemisphere (in both latitude and longitude)
        """
        return self._n

    @n.setter
    def n(self, val):
        #n is used as a divisor (pi / n), so zero has to be rejected as well
        if val <= 0:
            raise ValueError("Sphere n-value for specifying arc/circle resolution must be positive")
        self._n = val

    @property
    def coordinates(self):
        """
        Returns x, y, z coordinate arrays to visualize the shape in 3D
        """
        self._lazy_update()
        return self._coordinates

    def _lazy_update(self):
        """
        Only update the coordinates data if necessary
        """
        #Recalculate only when the inputs differ from the ones used last time
        if self._get_hash() != self._hash:
            self._update_coordinates()

    def _get_hash(self):
        """
        Get the sphere's inputs as an immutable data structure
        """
        return (tuple(self._c), (self._r, self._n))

    def _update_coordinates(self):
        """
        Calculate 3D coordinates to represent the sphere
        """
        c, r, n = self._c, self._r, self._n
        #Get the angular distance between latitude and longitude lines
        dphi, dtheta = pi / n, pi / n
        #Count the grid lines explicitly instead of relying on a float-step
        #range, whose length depends on rounding of (stop - start) / step
        rows = int(np.ceil(n)) + 1
        cols = int(np.ceil(2 * n)) + 1
        #Generate a latitude and longitude grid from integer indices
        i, j = np.mgrid[0:rows, 0:cols]
        phi, theta = i * dphi, j * dtheta
        #Map the latitude longitude grid into cartesian x, y, z coordinates
        x = c[0] + r * cos(phi) * sin(theta)
        y = c[1] + r * sin(phi) * sin(theta)
        z = c[2] + r * cos(theta)
        #Store the coordinates
        self._coordinates = x, y, z
        #Remember which inputs these coordinates were calculated from
        self._hash = self._get_hash()
if __name__ == '__main__':
    from mayavi import mlab

    def push_to_plot(ms, shape):
        """
        Send the current coordinates of ``shape`` to the mlab source ``ms``
        - ms.set() is used when the new arrays have the same shape as the plotted ones
        - ms.reset() is used when the array shape has changed
        """
        #New coordinates
        x, y, z = shape.coordinates
        #Currently plotted coordinates
        plotted = (ms.x, ms.y, ms.z)
        #Verify that new x, y, z are all same shape as old x, y, z
        data_is_same_shape = all(old.shape == new.shape
                                 for old, new in zip(plotted, (x, y, z)))
        if data_is_same_shape:
            print("Updating same-shaped data")
            ms.set(x=x, y=y, z=z)
        else:
            print("Updating with different shaped data")
            ms.reset(x=x, y=y, z=z)

    #Make a sphere
    sphere = Sphere()
    #Plot the sphere
    source = mlab.mesh(*sphere.coordinates, representation='wireframe')
    #Get the mlab_source
    ms = source.mlab_source

    #Increase the sphere's radius by 2
    sphere.r *= 2
    #Pause to see the old sphere
    sleep(2)
    #Data shape shouldn't have changed, so this goes through ms.set()
    push_to_plot(ms, sphere)

    #Increase the arc resolution
    sphere.n = 50
    #Pause to see the bigger sphere
    sleep(2)
    #Data shape should have changed this time, so this goes through ms.reset()
    #This is where the segfault / crash occurs
    push_to_plot(ms, sphere)

    mlab.show()
EDIT:
I just verified that all of these mlab_source tests pass for me which includes testing reset on an MGridSource. This does show some possible workarounds like accessing source.mlab_source.dataset.points ... maybe there's a way to update the data manually?
EDIT 2:
I tried this:
#Stack the flattened x, y, z grids into an (N, 3) array, one row per point
p = np.array([x.flatten(), y.flatten(), z.flatten()]).T
#Hand the new point array to the dataset behind the mlab source
ms.dataset.points = p
#Replace the point scalars with zeros shaped like the new coordinate grid
ms.dataset.point_data.scalars = np.zeros(x.shape)
#Flag the points as modified (presumably so the pipeline notices the change -- TODO confirm)
ms.dataset.points.modified()
#Regenerate the data structure
ms.reset(x=x, y=y, z=z)
It appears that modifying the TVTK Polydata object directly partly works. It appears that it's updating the points without also auto-fixing the connectivity, which is why I have to also run the mlab_source.reset(). I assume the reset() can work now because the data coming in has the same number of points and the mlab_source handles auto-generating the connectivity data. It still crashes when reducing the number of points, maybe because connectivity data exists for points that don't exist? I'm still very frustrated with this.
EDIT 3:
I've implemented the brute force method of just generating a new surface from mlab.mesh(). To prevent resetting the view I disable rendering and store the camera settings, then restore the camera settings after mlab.mesh() and then re-enable rendering. Seems to work quick enough - still wish underlying data could be updated with reset()
Here's the entire class I use to manage plotting objects (responds to GUI signals after an edit has been made).
class PlottablePrimitive(QtCore.QObject):
    """
    Manages the mlab plot of a single shape object

    - ``shape`` must expose a ``coordinates`` attribute returning x, y, z arrays
    - ``scene`` is the scene whose camera and rendering flag are managed
    - ``mlab`` is the object whose ``mesh()`` is used to create the plot
    """

    def __init__(self, parent=None, shape=None, scene=None, mlab=None):
        super(PlottablePrimitive, self).__init__(parent=parent)
        self._shape = shape
        self._scene = scene
        self._mlab = mlab
        self._source = None
        self._color = [0.706, 0.510, 0.196]
        self._visible = True
        self._opacity = 1.0
        #Camera attributes saved/restored around a full re-plot
        self._camera = {'position': None,
                        'focal_point': None,
                        'view_angle': None,
                        'view_up': None,
                        'clipping_range': None}

    @property
    def shape(self):
        return self._shape

    @shape.setter
    def shape(self, val):
        self._shape = val

    @property
    def color(self):
        return self._color

    @color.setter
    def color(self, color):
        self._color = color
        if self._source is not None:
            surface = self._source.children[0].children[0].children[0]
            #Turn off scalar coloring so the solid color is what gets drawn
            surface.actor.mapper.scalar_visibility = False
            surface.actor.property.color = tuple(color)

    def plot(self):
        """Create the mesh for the current shape coordinates."""
        x, y, z = self._shape.coordinates
        self._source = self._mlab.mesh(x, y, z)

    def update_plot(self, method='new_source'):
        """
        Refresh the plot after the shape has been edited

        - Same-shaped data is updated in place with mlab_source.set()
        - Differently shaped data is handled according to ``method``:
          'new_source' (default) deletes the plot and creates a new mesh
          while keeping the camera, 'reset' calls mlab_source.reset(), and
          'tvtk' edits the dataset directly before calling reset()
        - Raises ValueError for any other ``method``
        """
        #Nothing plotted yet - there is no source to update
        if self._source is None:
            self.plot()
            return
        ms = self._source.mlab_source
        x, y, z = self._shape.coordinates
        plotted = (ms.x, ms.y, ms.z)
        data_is_same_shape = all(old.shape == new.shape
                                 for old, new in zip(plotted, (x, y, z)))
        if data_is_same_shape:
            print("Same Data Shape... updating")
            #Update the data in-place
            ms.set(x=x, y=y, z=z)
            return
        print("New Data Shape... resetting data")
        if method == 'tvtk':
            #Modify TVTK directly
            p = np.array([x.flatten(), y.flatten(), z.flatten()]).T
            ms.dataset.points = p
            ms.dataset.point_data.scalars = np.zeros(x.shape)
            ms.dataset.points.modified()
            #Regenerate the data structure
            ms.reset(x=x, y=y, z=z)
        elif method == 'reset':
            #Regenerate the data structure
            ms.reset(x=x, y=y, z=z)
        elif method == 'new_source':
            #Save camera settings
            self._save_camera()
            #Disable rendering
            self._scene.disable_render = True
            try:
                #Delete old plot
                self.delete_plot()
                #Generate new mesh
                self._source = self._mlab.mesh(x, y, z)
                #Reset camera
                self._restore_camera()
            finally:
                #Always re-enable rendering, even if the re-plot failed
                self._scene.disable_render = False
        else:
            raise ValueError("Unknown update method: %r" % (method,))

    def _save_camera(self):
        """Copy the tracked camera attributes from the scene camera."""
        camera = self._scene.camera
        for setting in self._camera:
            self._camera[setting] = getattr(camera, setting)

    def _restore_camera(self):
        """Write the previously saved camera attributes back to the scene camera."""
        camera = self._scene.camera
        for setting, value in self._camera.items():
            #Settings that were never saved are left untouched
            if value is not None:
                setattr(camera, setting, value)

    def delete_plot(self):
        """Remove the current plot source, if there is one."""
        if self._source is not None:
            self._source.remove()
            self._source = None
Related
I have been trying to run this script but I keep getting an indentation error at the end
of the backprop(x,y) function. I would really appreciate ANY help!!
import cPickle
import gzip
def load_data():
    """Load the three pickled data sets stored in ``mnist.pkl.gz``.

    Returns a tuple ``(training_data, validation_data, test_data)`` exactly
    as it was unpickled from the archive in the current working directory.
    """
    # NOTE: unpickling can execute arbitrary code; only load archives from a
    # trusted source.
    # The with-block closes the file even when unpickling raises.
    with gzip.open('mnist.pkl.gz', 'rb') as f:
        training_data, validation_data, test_data = cPickle.load(f)
    return (training_data, validation_data, test_data)
import numpy as np
class Network(object):
    """Feed-forward network with one bias vector and weight matrix per layer.

    ``layers`` lists the number of neurons in each layer, input layer first.
    """

    def __init__(self, layers):
        self.layers = layers
        # One (y, 1) bias column for every layer after the input layer
        self.biases = [np.random.randn(y, 1) for y in layers[1:]]
        # One (y, x) weight matrix per pair of consecutive layers
        self.weights = [np.transpose(np.random.randn(x, y))
                        for x, y in zip(layers[:-1], layers[1:])]
        self.num_layers = len(layers)

    def backprop(self, x, y):
        """Return ``(nabla_b, nabla_w)``, the per-layer gradients for one sample.

        Both are lists of arrays shaped like ``self.biases`` and
        ``self.weights``. Relies on ``sigmoid``, ``sigmoid_prime`` and
        ``self.cost_derivative``, which are defined outside this class body.
        """
        nabla_b = [np.zeros(b.shape) for b in self.biases]
        nabla_w = [np.zeros(w.shape) for w in self.weights]
        # feedforward
        activation = x
        activations = [x]  # list to store all the activations, layer by layer
        zs = []  # list to store all the z vectors, layer by layer
        for b, w in zip(self.biases, self.weights):
            z = np.dot(w, activation) + b
            zs.append(z)
            activation = sigmoid(z)
            activations.append(activation)
        # backward pass
        delta = (self.cost_derivative(activations[-1], y) *
                 sigmoid_prime(zs[-1]))  # set first delta
        nabla_b[-1] = delta  # set last dC/db to delta vector
        nabla_w[-1] = np.dot(delta, activations[-2].transpose())
        # calculate nabla_b, nabla_w for the rest of the layers,
        # walking backwards with negative indices
        for l in range(2, self.num_layers):
            z = zs[-l]
            sp = sigmoid_prime(z)
            delta = np.dot(self.weights[-l + 1].transpose(), delta) * sp
            nabla_b[-l] = delta
            nabla_w[-l] = np.dot(delta, activations[-l - 1].transpose())
        # the return belongs to the method body, one level outside the loop
        return (nabla_b, nabla_w)
The problem was fixed by selecting the "edit" drop-down menu of Notepad++, choosing Blank Operations, and finally, clicking 'TAB to spaces'; obviously, this should be done after selecting the portion of code that's triggering the error.
I'm creating a contour plot with a list of V values as the x-axis and a list of T values as the y-axis (the V and T values are float numbers with 2 digits after the decimal point but all sorted of course). I created a data matrix and populated it with the data correlating with the V-T coordinates.
If it helps anything, this is what the contour plot would look like:
I'm trying to override the format_coord method to also display the data along with the x-y (V-T) coordinates when the cursor moves
I can't post all my code here, but here are the relevant parts:
fig= Figure()
a = fig.add_subplot(111)
contour_plot = a.contourf(self.pre_formating[0],self.pre_formating[1],datapoint) #Plot the contour on the axes
# Coordinate formatter taking the cursor position (x, y); it is installed on the axes as a.format_coord below
def fmt(x, y):
'''Overrides the original matplotlib method to also display z value when moving the cursor
'''
V_lst = self.pre_formating[0] #List of V points
T_lst = self.pre_formating[1] #List of T points
Zflat = datapoint.flatten() #Flatten out the data matrix
print 'first print line'
# get closest point with known v,t values
# NOTE(review): `distance` is presumably scipy.spatial.distance; its cdist needs two 2-D inputs, but [x,y] is 1-D
# NOTE(review): np.stack(..., axis=-1) pairs V_lst[i] with T_lst[i] only; it does not build every (V, T) grid combination
dist = distance.cdist([x,y],np.stack([V_lst, T_lst],axis=-1))
print 'second print line'
# Index of the smallest distance, used to look up the flattened data value
closest_idx = np.argmin(dist)
z = Zflat[closest_idx]
return 'x={x:.5f} y={y:.5f} z={z:.5f}'.format(x=x, y=y, z=z)
a.format_coord = fmt
The above code does not work (when I move the cursor nothing shows up, even the x,y value. The 'first print line' gets printed but the 'second print line' doesn't, so I think the problem is with the 'dist' line).
But when I change the 'dist' line to
dist = np.linalg.norm(np.vstack([V_lst - x, T_lst - y]), axis=0)
Everything works (x,y,data shows up) for a 37x37 matrix but not for a 37x46 matrix (37T, 46V) and I don't know why.
What should I do to make my code work?
Thank you for helping!
Hi guys so I found out how to solve this in case anyone needs it. There were 2 problems:
1/ I was wrong in my way of generating a list of V,T coordinates
2/ distance.cdist requires everything in the form of a 2d array.
So this is the final solution:
def fmt(x, y):
    '''Format the cursor position together with the data value of the nearest grid point.

    Replaces the axes' default coordinate formatter so the z value is shown
    next to x and y while the cursor moves.
    '''
    v_values = self.pre_formating[0]  # List of V points
    t_values = self.pre_formating[1]  # List of T points
    # Every (V, T) grid point, with V varying fastest inside each T
    grid_points = []
    for t in t_values:
        for v in v_values:
            grid_points.append([v, t])
    flat_data = datapoint.flatten()  # Data matrix as a flat sequence
    # Distances from the cursor to every grid point (both arguments are 2-D)
    cursor_distances = distance.cdist([[x, y]], grid_points)
    nearest = np.argmin(cursor_distances)
    return 'x={x:.5f} y={y:.5f} z={z:.5f}'.format(x=x, y=y, z=flat_data[nearest])
My code runs fine with one rectangle, but as soon as I add a second rectangle, it says "'point' object is not callable" despite successfully calling it for the first one. I have run enough tests with different variations on the rectangles to conclude that the only cause is the fact that it is now trying to create more than one rectangle. Can anyone please help?
Here's the start of the code, which is used to define different elements and their parameters.
import matplotlib.pyplot as plt
import numpy
# Registry of created elements; element.__init__ appends every new instance here
elementset = []
# x and y coordinates collected from the elements' plot points, passed to plt.plot
pointxs = []
pointys = []
class point(object):
    """General point in 2d space, with stored x value and y value.
    Created and used in elements to give them shape.
    """

    def __init__(self, x, y):
        self.x = float(x)
        self.y = float(y)
        self.isAnchor = False

    def __repr__(self):
        """Return "(x, y)"; whole values print without a decimal part."""
        return "(%s, %s)" % (self._format(self.x), self._format(self.y))

    @staticmethod
    def _format(value):
        """Format one coordinate without truncating a fractional part."""
        # "%d" alone would silently drop the fraction (0.5 -> "0")
        return "%d" % value if value.is_integer() else repr(value)
class element(object):
    """Most general class used to define any element to be used in
    the cross-section being worked with. Used as a basis for more
    specific classes. Has a coordinate value and a number of points that
    need to be generated for the general case.
    """

    def __init__(self, anchor_x, anchor_y):
        self.num_points = 0
        self.anchor_x = float(anchor_x)
        self.anchor_y = float(anchor_y)
        # Every element registers itself in the module-level elementset list
        elementset.append(self)

    def getinfo(self):
        """Used for debugging, prints all general info for the element.
        Never called in the actual code."""
        # Single-argument print(...) calls behave the same on Python 2 and 3
        print("Number of points: " + str(self.num_points))
        print("x coordinate: " + str(self.anchor_x))
        print("y coordinate: " + str(self.anchor_y))

    def debug(self):
        """Print the element's general info (delegates to getinfo)."""
        self.getinfo()
class rectangle(element):
    """Axis-aligned rectangle: every side is either vertical or horizontal.
    The shared anchor and point-count attributes are set up by the
    element class through super().
    """

    def __init__(self, anchor_x, anchor_y, width, height):
        super(rectangle, self).__init__(anchor_x, anchor_y)
        self.title = "rectangle"
        self.num_points = 4
        self.width = float(width)
        self.height = float(height)
        self.generate()
        self.calculate()

    def generate(self):
        """Build the four corner points of the rectangle.
        The anchor point is the bottom left corner."""
        left, bottom = self.anchor_x, self.anchor_y
        right = left + self.width
        top = bottom + self.height
        self.anchor = point(left, bottom)
        self.pointxpos = point(right, bottom)
        self.pointxypos = point(right, top)
        self.pointypos = point(left, top)
        self.points = [self.anchor, self.pointxpos, self.pointxypos, self.pointypos]
        # Same corners with the anchor repeated at the end to close the outline
        self.plotpoints = self.points + [self.anchor]
And here is the function that calls these (with only 1 rectangle defined):
ar = rectangle(0,0,50,20)
# NOTE: this loop variable rebinds the module-level name `element`, shadowing the element class
for element in elementset:
if isinstance(element,rectangle):
# generate() builds the corners by calling point(...)
element.generate()
# NOTE: this loop variable rebinds the module-level name `point`, shadowing the point class;
# once it holds an instance, a later generate() call fails with "'point' object is not callable"
for point in element.plotpoints:
pointxs.append(point.x)
pointys.append(point.y)
plt.plot(pointxs,pointys, linewidth=3)
elif isinstance(element,square):
pass #placeholder
elif isinstance(element,circle):
pass #placeholder
elif isinstance(element,semicircle):
pass #placeholder
plt.show()
This is successful, plotting a 50x20 rectangle with the bottom left corner at (0,0).
But if I were to add another element below ar:
ar = rectangle(0,0,50,20)
br = rectangle(50,20,10,10)
it throws "'point' object is not callable".
I'm honestly stumped by this, so thank you so much in advance for any given help.
I wonder why this is even successful for the first rectangle. The reason may be that you don't show a true minimal complete verifiable example.
In any case, the problem is as simple as that: Never use the same name for a class as for a (loop-) variable.
I.e. the following works:
ar = rectangle(0,0,50,20)
br = rectangle(23,16,22,13)
# NOTE(review): element.__init__ already appends each new instance to elementset,
# so these two appends register ar and br a second time -- confirm this is intended
elementset.append(ar)
elementset.append(br)
# Loop variables are named differently from the element and point classes, so neither class is shadowed
for elefant in elementset:
if isinstance(elefant,rectangle):
elefant.generate()
for wombat in elefant.plotpoints:
pointxs.append(wombat.x)
pointys.append(wombat.y)
# pointxs / pointys are never cleared, so they keep accumulating across elements
plt.plot(pointxs,pointys, linewidth=3)
...apart from the fact that all objects will be connected by a line, but that is probably a different question.
I am new to matplotlib animation and am trying to animate a scatter plot where points moving towards the right will turn red gradually while points moving towards the left will turn blue gradually. The code doesn't work perfectly as it doesn't change the color of the points gradually. When I pause the animation and maximize it, the gradual change in color suddenly appears, when I play it, it is again the same. Here is the animation link. The final image should be something like this:
But the animation doesn't show gradual change of colors as you can see in the video.
Here is the code, I'd really appreciate your help. Thanks
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import pandas as pd
# Animated scatter plot: points start at x=0 and are moved step by step towards
# random target x positions while their colour value c is raised or lowered.
class AnimatedScatter(object):
"""An animated scatter plot using matplotlib.animations.FuncAnimation."""
def __init__(self, numpoints=5):
self.numpoints = numpoints
# Generator yielding the data array for each frame
self.stream = self.data_stream()
# Setup the figure and axes...
self.fig, self.ax = plt.subplots()
# Then setup FuncAnimation.
self.ani = animation.FuncAnimation(self.fig, self.update, interval=500,
init_func=self.setup_plot, blit=True,repeat=False)
# Mouse clicks on the figure toggle the global pause flag (see onClick)
self.fig.canvas.mpl_connect('button_press_event',self.onClick)
#self.ani.save("animation.mp4")
def setup_plot(self):
"""Initial drawing of the scatter plot."""
t=next(self.stream)
# Columns of the streamed array: x position, y position, colour value
x, y, c = t[:,0],t[:,1],t[:,2]
# NOTE(review): no vmin/vmax is passed, so the colour scale is presumably normalised to the
# current min/max of c rather than to a fixed range -- confirm against the matplotlib scatter docs
self.scat = self.ax.scatter(x, y, c=c, s=50, animated=True)
self.ax.axis([-15, 15, -10, 10])
# For FuncAnimation's sake, we need to return the artist we'll be using
# Note that it expects a sequence of artists, thus the trailing comma.
return self.scat,
# Generator: builds random targets once, then yields the same data array on every iteration
def data_stream(self):
#f=pd.read_csv("crc_viz.csv")
columns = ['TbyN','CbyS']
#f=f[['TbyN','CbyS']]
# NOTE: index is assigned but never used afterwards
index=range(1,self.numpoints+1)
# Random column vectors: x = 10*(1 - 2*u) and y = 5*(1 - 2*u) with u from np.random.random
x=10*(np.ones((self.numpoints,1))-2*np.random.random((self.numpoints,1)))
y = 5*(np.ones((self.numpoints,1))-2*np.random.random((self.numpoints,1)))
f=np.column_stack((x,y))
f=pd.DataFrame(f,columns=columns)
print f
# new_cbys holds the sign (-1 / 1) of CbyS, i.e. the direction each point moves in
f['new_cbys'] = f['CbyS']
# NOTE(review): chained indexing assignment; pandas may warn or not write through -- confirm
f['new_cbys'][f['new_cbys']<0] = -1
f['new_cbys'][f['new_cbys']>0] = 1
f=f[:self.numpoints]
# Target x position of every point
cbys=np.array(list(f['CbyS']))
sign = np.array(list(f['new_cbys']))
# All points start at x = 0 with colour value 0.5; y is taken from the TbyN column
x = np.array([0]*self.numpoints)
y = np.array(f['TbyN'])
c = np.array([0.5]*self.numpoints)
# NOTE: t is assigned but never used afterwards
t = [(255,0,0) for i in range(self.numpoints)]
data=np.column_stack((x,y,c))
# x and c are slices of data, so the in-place updates below change data itself
x = data[:, 0]
c = data[:,2]
while True:
#print xy
#print cbys
# pause is the module-level flag toggled by onClick
if not pause:
for i in range(len(x)):
# Positive direction: step x by +0.1 and c by +0.05 until within 0.1 of the target, then snap to it
if sign[i]==1:
if x[i]<cbys[i]-0.1:
x[i]+=0.1
c[i]+=0.05
else:
x[i]=cbys[i]
# Negative direction: step x by -0.1 and c by -0.05 until within 0.1 of the target, then snap to it
elif sign[i]==-1:
if x[i]>cbys[i]+0.1:
x[i]-=0.1
c[i]-=0.05
else:
x[i]=cbys[i]
print c
#print data
#print c
yield data
# Mouse-click callback: flips the global pause flag
def onClick(self,event):
global pause
pause ^=True
def update(self, i):
"""Update the scatter plot."""
data = next(self.stream)
print data[:,2]
# Set x and y data...
self.scat.set_offsets(data[:, :2])
# Set colors..
self.scat.set_array(data[:,2])
return self.scat,
# Writes the animation to myMovie.mp4 with the ffmpeg binary at the hard-coded path, then shows the figure
def save(self):
plt.rcParams['animation.ffmpeg_path'] = 'C:\\ffmpeg\\bin\\ffmpeg.exe'
self.mywriter = animation.FFMpegWriter()
self.ani.save("myMovie.mp4",writer=self.mywriter)
self.show()
def show(self):
#mng = plt.get_current_fig_manager()
#mng.window.state('zoomed')
plt.show()
# Module-level flag read by AnimatedScatter.data_stream and toggled by onClick
pause = False

if __name__ == '__main__':
    scatter_demo = AnimatedScatter(10)
    scatter_demo.show()
    #scatter_demo.save()
The problem you have is that the scatter plot is redrawn in every iteration, renormalizing the colors to the minimal and maximal value of c. So even at the start there will already be a dot corresponding to the minimal and one corresponding to the maximal color in the colormap.
The solution would be to use a color normalization which is absolute from the start. The easiest way to do this is using the vmin and vmax keyword arguments.
ax.scatter(x, y, c=c, vmin=-1.5, vmax=2)
(This means that a value of c=-1.5 is the lowest color in the colormap and c=2 corresponds to the highest.)
Now it may be a bit hard to find the appropriate values, as the values are constantly changing in an infinite loop, so you need to find out appropriate values yourself depending on the use case.
I was trying to plot a loss curve, but it always looks abnormal (it loops back on itself like a circle; I don't really know how to describe it properly in English). I have found many topics about questions like this but still can't solve it. My tensorflow version is 0.10.0.
import tensorflow as tf
from tensorflow.core.util.event_pb2 import SessionLog
import os
# initialize variables/model parameters
# define the training loop operations
def inputs():
    # build the training data: X as [weight, age] rows and Y as the expected
    # blood fat content for each row, both converted to float tensors
    weight_age = [
        [84, 46], [73, 20], [65, 52], [70, 30], [76, 57],
        [69, 25], [63, 28], [72, 36], [79, 57], [75, 44],
        [27, 24], [89, 31], [65, 52], [57, 23], [59, 60],
        [69, 48], [60, 34], [79, 51], [75, 50], [82, 34],
        [59, 46], [67, 23], [85, 37], [55, 40], [63, 30],
    ]
    blood_fat_content = [
        354, 190, 405, 263, 451, 302, 288, 385, 402, 365,
        209, 290, 346, 254, 395, 434, 220, 374, 308, 220,
        311, 181, 274, 303, 244,
    ]
    features = tf.to_float(weight_age)
    targets = tf.to_float(blood_fat_content)
    return features, targets
def inference(X):
    # linear model: multiply the data X by the global weights W, then add the global bias b
    weighted = tf.matmul(X, W)
    return weighted + b
def loss(X, Y):
    # compute loss over training data X and expected outputs Y:
    # the sum of squared differences between Y and the model's predictions
    Y_predicted = inference(X)
    # inference() returns the result of a matmul, which is 2-D; with a flat Y
    # the difference would broadcast to an N x N matrix and the sum would
    # cover N*N terms. Flatten both sides so the terms pair up one to one.
    expected = tf.reshape(Y, [-1])
    predicted = tf.reshape(Y_predicted, [-1])
    return tf.reduce_sum(tf.squared_difference(expected, predicted))
def train(total_loss, learning_rate=1e-7):
    # train / adjust model parameters according to computed total loss
    # returns the gradient-descent op that minimizes total_loss;
    # learning_rate defaults to the value that used to be hard-coded here
    return tf.train.GradientDescentOptimizer(learning_rate).minimize(total_loss)
def evaluate(sess, X, Y):
    # evaluate the resulting trained model by printing its prediction
    # for two fixed samples (X and Y are accepted but not used)
    for sample in ([[80., 25.]], [[60., 25.]]):
        print (sess.run(inference(sample)))
# Build the linear model in its own graph, train it, and log the loss as a summary
g1 = tf.Graph()
with tf.Session(graph=g1) as sess:
# Model parameters read as globals by inference(): a 2x1 weight matrix and a scalar bias
W = tf.Variable(tf.zeros([2,1]), name="weights")
b = tf.Variable(0., name="bias")
tf.initialize_all_variables().run()
X, Y = inputs()
print (sess.run(W))
total_loss = loss(X, Y)
train_op = train(total_loss)
# Register the loss as a scalar summary and merge all summaries into one op
tf.scalar_summary("loss", total_loss)
summaries = tf.merge_all_summaries()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
# Summaries are written to the 'linear' directory
# NOTE(review): event files from earlier runs in the same directory are presumably shown together by TensorBoard -- confirm
summary_writer = tf.train.SummaryWriter('linear', g1)
summary_writer.add_session_log(session_log= SessionLog(status=SessionLog.START), global_step=1)
# actual training loop
training_steps = 100
tolerance = 100
total_loss_last = 0
initial_step = 0
# Create a saver.
saver = tf.train.Saver()
# verify if we don't have a checkpoint saved already
# NOTE: os.path.dirname('my_model') is the empty string, and the name uses an underscore
# while the saver.save calls below use 'my-model'
ckpt = tf.train.get_checkpoint_state(os.path.dirname('my_model'))
if ckpt and ckpt.model_checkpoint_path:
# Restores from checkpoint
saver.restore(sess, ckpt.model_checkpoint_path)
# The step number is parsed from the text after the last '-' in the checkpoint path
initial_step = int(ckpt.model_checkpoint_path.rsplit('-', 1)[1])
# summary_writer.add_session_log(SessionLog(status=SessionLog.START), global_step=initial_step)
for step in range(initial_step, training_steps):
sess.run([train_op])
if step%20 == 0:
saver.save(sess, 'my-model', global_step=step)
# Absolute change of the loss compared with the previously recorded value
gap = abs(sess.run(total_loss) - total_loss_last)
total_loss_last = sess.run(total_loss)
summary_writer.add_summary(sess.run(summaries), step)
# for debugging and learning purposes, see how the loss gets decremented thru training steps
if step % 10 == 0:
print ("loss: ", sess.run([total_loss]))
print("step: ", step)
# Stop once the loss changed by less than the tolerance
if gap < tolerance:
break
# evaluation...
evaluate(sess, X, Y)
coord.request_stop()
coord.join(threads)
saver.save(sess, 'my-model', global_step=training_steps)
summary_writer.flush()
sess.close()