Custom resource in TensorFlow - C++

For various reasons, I need to implement a custom resource for TensorFlow. I took inspiration from the lookup table implementations. If I understood correctly, I need to implement three TF operations:
creation of my resource
initialization of the resource (e.g. populate the hash table in case of the lookup table)
implementation of the find / lookup / query step.
To facilitate the implementation, I'm relying on tensorflow/core/framework/resource_op_kernel.h. I get the following error:
[F tensorflow/core/lib/core/refcount.h:90] Check failed: ref_.load() == 0 (1 vs. 0)
[1] 29701 abort python test.py
Here is the full code to reproduce the issue:
#include <vector>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_op_kernel.h"

using namespace tensorflow;

/** CUSTOM RESOURCE **/
class MyVector : public ResourceBase {
 public:
  string DebugString() override { return "MyVector"; }

 private:
  std::vector<int> vec_;
};
/** CREATE VECTOR **/
REGISTER_OP("CreateMyVector")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Output("resource: resource")
    .SetIsStateful();
class MyVectorOp : public ResourceOpKernel<MyVector> {
 public:
  explicit MyVectorOp(OpKernelConstruction* ctx) : ResourceOpKernel(ctx) {}

 private:
  Status CreateResource(MyVector** resource) override {
    *resource = CHECK_NOTNULL(new MyVector);
    if (*resource == nullptr) {
      return errors::ResourceExhausted("Failed to allocate");
    }
    return Status::OK();
  }

  Status VerifyResource(MyVector* vec) override { return Status::OK(); }
};

REGISTER_KERNEL_BUILDER(Name("CreateMyVector").Device(DEVICE_CPU), MyVectorOp);
and then, after compiling, the error can be reproduced with this Python snippet:
import tensorflow as tf

test_module = tf.load_op_library('./test.so')
my_vec = test_module.create_my_vector()

with tf.Session() as s:
    s.run(my_vec)
As a side question, I'd be interested in tutorials / guidelines on implementing custom resources. In particular, I'd like information about what needs to be implemented for checkpoints / graph export / serialization / etc.
Thanks a lot.

Add -DNDEBUG to your build flags.
This workaround is explained in TF issue 17316.
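For reference, this is roughly where the flag goes in the usual g++ invocation from the TensorFlow custom-op guide (a sketch; paths and the -std level depend on your setup):

TF_CFLAGS=$(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))')
TF_LFLAGS=$(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))')

# -DNDEBUG compiles out the debug check in refcount.h that triggers the abort
g++ -std=c++11 -shared test.cc -o test.so -fPIC -O2 -DNDEBUG $TF_CFLAGS $TF_LFLAGS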

Related

How to use custom logger with websocketpp?

I am creating a telemetry server using websocketpp, and have followed the example here. My application will be running as a Linux daemon which starts on boot, so I won't be able to write logs to standard out. I would therefore like to add a custom logger using spdlog, and understand that it can be done based on what's on this page. It looks like I need to use the websocketpp::log::stub interface to create my own custom logger. The issue is, the documentation on logging is quite limited, and I am not sure where to begin or how to incorporate it in the context of the telemetry server example linked above. I am not sure how to specify the logger when I define my server: typedef websocketpp::server<websocketpp::config::asio> server;.
How do I go about extending the stub class, and how do I initialize my server with this custom logger?
The only sample code I could find is in this thread here, but based on the linked comment that code is no longer relevant after v0.3.x+.
For anyone wanting sample code using spdlog as a custom logger, here it is:
Create a new file, customLogger.hpp, with the contents:
#pragma once

#include <websocketpp/logger/basic.hpp>
#include <websocketpp/common/cpp11.hpp>
#include <websocketpp/logger/levels.hpp>

#include "spdlog/logger.h"
#include "spdlog/sinks/rotating_file_sink.h"

namespace websocketpp {
namespace log {

/// Basic logger that outputs to a rotating spdlog file sink
template <typename concurrency, typename names>
class myLogger : public basic<concurrency, names> {
public:
    typedef basic<concurrency, names> base;

    /// Construct the logger
    /**
     * @param hint A channel type specific hint for how to construct the logger
     */
    myLogger<concurrency, names>(channel_type_hint::value hint =
                                     channel_type_hint::access)
        : basic<concurrency, names>(hint), m_channel_type_hint(hint) {
        auto max_size = 1048576 * 5;
        auto max_files = 3;
        auto rotating_sink = std::make_shared<spdlog::sinks::rotating_file_sink_mt>(
            "/var/logs/my_logger.log", max_size, max_files);
        m_logger = std::make_shared<spdlog::logger>("my_logger", rotating_sink);
        m_logger->flush_on(spdlog::level::info);
        m_logger->set_level(spdlog::level::level_enum::info);
    }

    /// Construct the logger
    /**
     * @param channels A set of channels to statically enable
     * @param hint A channel type specific hint for how to construct the logger
     */
    myLogger<concurrency, names>(level channels,
                                 channel_type_hint::value hint =
                                     channel_type_hint::access)
        : basic<concurrency, names>(channels, hint), m_channel_type_hint(hint) {
        auto max_size = 1048576 * 5;
        auto max_files = 3;
        auto rotating_sink = std::make_shared<spdlog::sinks::rotating_file_sink_mt>(
            "/var/logs/my_logger.log", max_size, max_files);
        m_logger = std::make_shared<spdlog::logger>("my_logger", rotating_sink);
        m_logger->flush_on(spdlog::level::info);
        m_logger->set_level(spdlog::level::level_enum::info);
    }

    /// Write a string message to the given channel
    /**
     * @param channel The channel to write to
     * @param msg The message to write
     */
    void write(level channel, std::string const & msg) {
        write(channel, msg.c_str());
    }

    /// Write a cstring message to the given channel
    /**
     * @param channel The channel to write to
     * @param msg The message to write
     */
    void write(level channel, char const * msg) {
        scoped_lock_type lock(base::m_lock);
        if (!this->dynamic_test(channel)) { return; }
        if (m_channel_type_hint == channel_type_hint::access) {
            m_logger->info(msg);
        } else {
            if (channel == elevel::devel) {
                m_logger->debug(msg);
            } else if (channel == elevel::library) {
                m_logger->debug(msg);
            } else if (channel == elevel::info) {
                m_logger->info(msg);
            } else if (channel == elevel::warn) {
                m_logger->warn(msg);
            } else if (channel == elevel::rerror) {
                m_logger->error(msg);
            } else if (channel == elevel::fatal) {
                m_logger->critical(msg);
            }
        }
    }

private:
    typedef typename base::scoped_lock_type scoped_lock_type;

    std::shared_ptr<spdlog::logger> m_logger;
    channel_type_hint::value m_channel_type_hint;
};

} // namespace log
} // namespace websocketpp
Next, create another file, customConfig.hpp, with the following content:
#pragma once

#include "./customLogger.hpp"

#include <websocketpp/config/asio_no_tls.hpp>
#include <websocketpp/server.hpp>
#include <websocketpp/extensions/permessage_deflate/enabled.hpp>

// Custom server config based on the bundled asio config
struct my_config : public websocketpp::config::asio {
    // Replace the default stream logger with the custom logger
    typedef websocketpp::log::myLogger<concurrency_type, websocketpp::log::elevel> elog_type;
    typedef websocketpp::log::myLogger<concurrency_type, websocketpp::log::alevel> alog_type;
};

typedef websocketpp::server<my_config> my_server;
Finally, when you want to create the server, you simply do my_server endpoint;
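Putting it all together, a minimal usage sketch (handler wiring omitted; the port number here is arbitrary):

#include "customConfig.hpp"

int main() {
    my_server endpoint;    // access and error channels now go through myLogger
    endpoint.init_asio();
    endpoint.listen(9002);
    endpoint.start_accept();
    endpoint.run();
}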
Building a custom logger has two steps. First, write a policy class with the appropriate interface; then create a custom config that uses that policy.
For the policy class, websocketpp::log::stub is a minimal implementation that doesn't actually do anything (it is primarily used for stubbing out logging in the unit tests), but it demonstrates and documents the interface that a logging class needs to implement. The logging class does not need to be a subclass of websocketpp::log::stub. You can look at other examples in the websocketpp/logger/* folder. The syslog logger in particular might be interesting as an example of a logging policy that outputs to something other than standard out.
To set up the custom config, create a config class. It can be standalone or a subclass of one of the standard ones, like websocketpp::config::asio, that just overrides a small number of things; your config might only override the loggers, for example. Once created, pass your config class as the endpoint template parameter instead of websocketpp::config::asio.
More details about what you can override at compile time via this config system can be found at https://docs.websocketpp.org/reference_8config.html. There is an example on this page that shows a custom config that replaces the default logger (among other changes).
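As a minimal sketch of that idea, a standalone config that silences logging entirely by swapping in the no-op stub policy might look like this (assuming the bundled asio config and stub logger headers):

#include <websocketpp/config/asio_no_tls.hpp>
#include <websocketpp/logger/stub.hpp>
#include <websocketpp/server.hpp>

// Sketch: disable all logging by substituting the no-op stub policy
struct quiet_config : public websocketpp::config::asio {
    typedef websocketpp::log::stub elog_type;
    typedef websocketpp::log::stub alog_type;
};

typedef websocketpp::server<quiet_config> quiet_server;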

Checking uniqueness of OperationId when using c.CustomOperationIds?

At the moment, if you use the following
c.CustomOperationIds(apiDesc =>
{
    return apiDesc.TryGetMethodInfo(out MethodInfo methodInfo) ? methodInfo.Name : null;
});
And if, by programmer mistake, you have two methods with the same name, you are not warned that you violate the OpenAPI spec.
Is there a way to add a check? I was thinking of either:
raising an error at the end of generation, like "2 operations with id {0}"
having a hook to access the already-defined operations when Swashbuckle calls the CustomOperationIds "selector"
Thanks for your time.
P.S.: using Swashbuckle.AspNetCore.SwaggerGen 5.3.1
After many tries, I've found a workaround: use an operation filter that throws an exception if an OperationId is already used:
using Microsoft.OpenApi.Models;
using Swashbuckle.AspNetCore.SwaggerGen;
using System;
using System.Collections.Generic;

namespace Service.Utils
{
    /// <summary>
    /// Guarantee that OperationId is not already used
    /// </summary>
    public class SwaggerUniqueOperationId : IOperationFilter
    {
        private readonly HashSet<string> ids = new HashSet<string>();

        public void Apply(OpenApiOperation operation, OperationFilterContext context)
        {
            if (operation.OperationId != null)
            {
                if (ids.Contains(operation.OperationId))
                    throw new NotSupportedException($"There are 2 operations with same OperationId {operation.OperationId}");
                ids.Add(operation.OperationId);
            }
        }
    }
}
This is not ideal at all because the error message is pretty vague and it is a runtime error, but it is better than producing an OpenAPI spec that violates the unique OperationId constraint...
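For completeness, the filter has to be registered with the generator; a sketch of the wiring, assuming the usual AddSwaggerGen call in Startup.ConfigureServices:

services.AddSwaggerGen(c =>
{
    c.CustomOperationIds(apiDesc =>
        apiDesc.TryGetMethodInfo(out MethodInfo methodInfo) ? methodInfo.Name : null);

    // Fail fast on duplicate OperationIds
    c.OperationFilter<SwaggerUniqueOperationId>();
});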

Builder pattern - configuration file reading

I'm facing a design problem. I want to separate object construction using the builder pattern, but the problem is that the objects have to be built from a configuration file.
So far I have decided that all objects created from the configuration will be stored in a DataContext class (a container for all the objects), because these objects' states will be updated from a transmission (so it's easier to have them in one place).
I'm using an external library for reading the XML file - and my question is how to hide it - is it better to inject it into the concrete builder class? Note that the builder class will have to create lots of objects and, at the end, connect them to each other.
Base class could look like that:
/*
 * IDataContextBuilder
 * base class for building the data context object
 * and sub objects
 */
class IDataContextBuilder {
public:
    /*
     * GetResult()
     * returns the result of the building process
     */
    virtual DataContext * GetResult() = 0;

    /*
     * Virtual destructor
     */
    virtual ~IDataContextBuilder() { }
};

class ConcreteDataContextBuilder : public IDataContextBuilder {
public:
    ConcreteDataContextBuilder(pugi::xml_node & rootNode);
    DataContext * GetResult() override;
};
How to implement it correctly? What could be better pattern to build classes from configuration files?
I don't see a problem with that, but maybe you could introduce another 'Director' class that receives a specific builder, loads the config files, and produces objects by calling the respective builder subclasses.
What I mean:
class DataContextDirector {
public:
    void SetBuilder(IDataContextBuilder* builder);
    void SetConfig(const std::string& configFilePath); // or whatever

    DataContext* ProduceObject() {
        // pseudo-code here:
        // myBuilder->setup(xmlNodeOfConfig);
        // return myBuilder->GetResult();
    }
};
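Client code would then look roughly like this (a sketch; the default builder constructor and the config path are hypothetical):

// Sketch: the director owns config loading; the builder only assembles objects
ConcreteDataContextBuilder builder;        // hypothetical default constructor
DataContextDirector director;
director.SetBuilder(&builder);
director.SetConfig("objects.xml");         // hypothetical config path
DataContext* context = director.ProduceObject();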

RavenDb : Force indexes to wait until not stale whilst unit testing

When unit testing with RavenDb, it is often the case that newly added data is retrieved or otherwise processed. This can lead to 'stale index' exceptions e.g.
Bulk operation cancelled because the index is stale and allowStale is false
According to a number of answers
How should stale indexes be handled during testing?
WaitForNonStaleResults per DocumentStore
RavenDb : Update a Denormalized Reference property value
The way to force the database (the IDocumentStore instance) to wait until its indexes are not stale before processing a query or batch operation is to use DefaultQueryingConsistency = ConsistencyOptions.QueryYourWrites during the IDocumentStore initialisation, like this:
public class InMemoryRavenSessionProvider : IRavenSessionProvider
{
    private static IDocumentStore documentStore;

    public static IDocumentStore DocumentStore
    {
        get { return (documentStore ?? (documentStore = CreateDocumentStore())); }
    }

    private static IDocumentStore CreateDocumentStore()
    {
        var store = new EmbeddableDocumentStore
        {
            RunInMemory = true,
            Conventions = new DocumentConvention
            {
                DefaultQueryingConsistency = ConsistencyOptions.QueryYourWrites,
                IdentityPartsSeparator = "-"
            }
        };
        store.Initialize();
        IndexCreation.CreateIndexes(typeof(RavenIndexes).Assembly, store);
        return store;
    }

    public IDocumentSession GetSession()
    {
        return DocumentStore.OpenSession();
    }
}
Unfortunately, the code above does not work. I am still receiving exceptions regarding stale indexes. These can be resolved by putting in dummy queries that include .Customize(x => x.WaitForNonStaleResultsAsOfLastWrite()).
This is fine, as long as these can be contained in the Unit Test, but what if they can't? I am finding that these WaitForNonStaleResults* calls are creeping into production code just so I can get unit-tests to pass.
So, is there a sure fire way, using the latest version of RavenDb, to force the indexes to freshen before allowing commands to be processed - for the purposes of unit testing only?
Edit 1
Here is a solution based on the answer given below that forces a wait until the index is not stale. I have written it as an extension method for the sake of unit-testing convenience:
public static class IDocumentSessionExt
{
    public static void ClearStaleIndexes(this IDocumentSession db)
    {
        while (db.Advanced.DatabaseCommands.GetStatistics().StaleIndexes.Length != 0)
        {
            Thread.Sleep(10);
        }
    }
}
And here is a Unit Test that was using the verbose WaitForNonStaleResultsAsOfLastWrite technique but now uses the neater extension method.
[Fact]
public void Should_return_list_of_Relationships_for_given_mentor()
{
    using (var db = Fake.Db())
    {
        var mentorId = Fake.Mentor(db).Id;
        Fake.Relationship(db, mentorId, Fake.Mentee(db).Id);
        Fake.Relationship(db, mentorId, Fake.Mentee(db).Id);
        Fake.Relationship(db, Fake.Mentor(db).Id, Fake.Mentee(db).Id);

        //db.Query<Relationship>()
        //  .Customize(x => x.WaitForNonStaleResultsAsOfLastWrite())
        //  .Count()
        //  .ShouldBe(3);

        db.ClearStaleIndexes();
        db.Query<Relationship>().Count().ShouldBe(3);
        MentorService.GetRelationships(db, mentorId).Count.ShouldBe(2);
    }
}
If you have a Map/Reduce index, DefaultQueryingConsistency = ConsistencyOptions.QueryYourWrites won't work. You need to use an alternative method.
In your unit tests, call code like this straight after you've inserted any data; it will force all indexes to update before you do anything else:
while (documentStore.DatabaseCommands.GetStatistics().StaleIndexes.Length != 0)
{
    Thread.Sleep(10);
}
Update: You can of course put it in an extension method if you want to:
public static class IDocumentSessionExt
{
    public static void ClearStaleIndexes(this IDocumentSession db)
    {
        while (db.Advanced.DatabaseCommands.GetStatistics().StaleIndexes.Length != 0)
        {
            Thread.Sleep(10);
        }
    }
}
Then you can say:
db.ClearStaleIndexes();
You can actually add a query listener on the DocumentStore to wait for non-stale results. This can be used just for unit tests, as it is on the document store and not each operation.
// Initialise the Store.
var documentStore = new EmbeddableDocumentStore
{
    RunInMemory = true
};
documentStore.Initialize();

// Force queries to wait for indexes to catch up. Unit Testing only :P
documentStore.RegisterListener(new NoStaleQueriesListener());

....

#region Nested type: NoStaleQueriesListener

public class NoStaleQueriesListener : IDocumentQueryListener
{
    #region Implementation of IDocumentQueryListener

    public void BeforeQueryExecuted(IDocumentQueryCustomization queryCustomization)
    {
        queryCustomization.WaitForNonStaleResults();
    }

    #endregion
}

#endregion
(Shamelessly stolen from RavenDB how to flush?)
Be aware that StaleIndexes also includes abandoned and disabled indexes - which will never get up to date.
So to avoid waiting indefinitely, use this property instead:
var staleIndices = store.DatabaseCommands.GetStatistics().CountOfStaleIndexesExcludingDisabledAndAbandoned;
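The wait loop from above then becomes (same pattern, safer counter):

while (store.DatabaseCommands.GetStatistics().CountOfStaleIndexesExcludingDisabledAndAbandoned != 0)
{
    Thread.Sleep(10);
}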

Execute @PostLoad after eagerly fetching?

Using JPA2/Hibernate, I've created an entity A that has a uni-directional mapping to an entity X (see below). Inside A, I also have a transient member "t" that I am trying to calculate using a @PostLoad method. The calculation requires access to the associated Xs:
@Entity
public class A {
    // ...

    @Transient
    int t;

    @OneToMany(orphanRemoval = false, fetch = FetchType.EAGER)
    private List<X> listOfX;

    @PostLoad
    public void calculateT() {
        t = 0;
        for (X x : listOfX)
            t = t + x.someMethod();
    }
}
However, when I try to load this entity, I get an "org.hibernate.LazyInitializationException: illegal access to loading collection" error:
at org.hibernate.collection.AbstractPersistentCollection.initialize(AbstractPersistentCollection.java:363)
at org.hibernate.collection.AbstractPersistentCollection.read(AbstractPersistentCollection.java:108)
at org.hibernate.collection.PersistentBag.get(PersistentBag.java:445)
at java.util.Collections$UnmodifiableList.get(Collections.java:1154)
at mypackage.A.calculateT(A.java:32)
Looking at Hibernate's code (AbstractPersistentCollection.java) while debugging, I found that:
1) My @PostLoad method is called BEFORE the "listOfX" member is initialized.
2) Hibernate's code has an explicit check to prevent initialization of an eagerly fetched collection during a @PostLoad:
protected final void initialize(boolean writing) {
    if (!initialized) {
        if (initializing) {
            throw new LazyInitializationException("illegal access to loading collection");
        }
        throwLazyInitializationExceptionIfNotConnected();
        session.initializeCollection(this, writing);
    }
}
The only way I can think of to fix this is to stop using @PostLoad and move the initialization code into the getT() accessor, adding a synchronized block. However, I want to avoid that.
So, is there a way to have eager fetching executed prior to @PostLoad being called? I don't know of a JPA facility to do that, so I'm hoping there's something I don't know.
Also, perhaps Hibernate's proprietary API has something to control this behaviour?
This might be too late, but Hibernate seems not to support the default JPA fetch type option
@OneToMany(orphanRemoval = false, fetch = FetchType.EAGER)
You must use the Hibernate-specific one:
@LazyCollection(LazyCollectionOption.FALSE)
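Applied to the mapping from the question, that would look something like this (a sketch; the annotations come from org.hibernate.annotations):

import org.hibernate.annotations.LazyCollection;
import org.hibernate.annotations.LazyCollectionOption;

// Hibernate-specific way to request eager loading of the collection
@OneToMany(orphanRemoval = false)
@LazyCollection(LazyCollectionOption.FALSE)
private List<X> listOfX;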
I don't know how to fix this directly, but I think a little refactoring might help; the idea would be to move the code to a @PostConstruct.
So, for example, your class would be:
@Entity
public class A {
    // ...

    @Transient
    int t;

    @OneToMany(orphanRemoval = false, fetch = FetchType.EAGER)
    private List<X> listOfX;

    @PostConstruct
    public void calculateT() {
        t = 0;
        for (X x : listOfX)
            t = t + x.someMethod();
    }
}
The server will call @PostConstruct as soon as it has completed initializing all the container services for the bean.
Updated link to the bug report:
https://hibernate.atlassian.net/browse/HHH-6043
This is fixed in Hibernate 4.1.8 and 4.3.0 or later.