Commit 39804417 authored by Jacob Potter's avatar Jacob Potter

Support non-nullable pointers.

This enables use of Djinni with statically-checked non-nullable pointers,
like https://github.com/dropbox/nn
parent 3c1ea5cb
...@@ -53,7 +53,11 @@ class CppMarshal(spec: Spec) extends Marshal(spec) { ...@@ -53,7 +53,11 @@ class CppMarshal(spec: Spec) extends Marshal(spec) {
} }
case DInterface => case DInterface =>
if (d.name != exclude) { if (d.name != exclude) {
List(ImportRef("<memory>"), DeclRef(s"class ${typename(d.name, d.body)};", Some(spec.cppNamespace))) val base = List(ImportRef("<memory>"), DeclRef(s"class ${typename(d.name, d.body)};", Some(spec.cppNamespace)))
spec.cppNnHeader match {
case Some(nnHdr) => ImportRef(nnHdr) :: base
case _ => base
}
} else { } else {
List(ImportRef("<memory>")) List(ImportRef("<memory>"))
} }
...@@ -93,9 +97,39 @@ class CppMarshal(spec: Spec) extends Marshal(spec) { ...@@ -93,9 +97,39 @@ class CppMarshal(spec: Spec) extends Marshal(spec) {
case p: MParam => idCpp.typeParam(p.name) case p: MParam => idCpp.typeParam(p.name)
} }
def expr(tm: MExpr): String = { def expr(tm: MExpr): String = {
spec.cppNnType match {
case Some(nnType) => {
// if we're using non-nullable pointers for interfaces, then special-case
// both optional and non-optional interface types
val args = if (tm.args.isEmpty) "" else tm.args.map(expr).mkString("<", ", ", ">")
tm.base match {
case d: MDef =>
d.defType match {
case DInterface => s"${nnType}<${withNs(namespace, idCpp.ty(d.name))}>"
case _ => base(tm.base) + args
}
case MOptional =>
tm.args.head.base match {
case d: MDef =>
d.defType match {
case DInterface => s"std::shared_ptr<${withNs(namespace, idCpp.ty(d.name))}>"
case _ => base(tm.base) + args
}
case _ => base(tm.base) + args
}
case _ => base(tm.base) + args
}
}
case None =>
if (isOptionalInterface(tm)) {
// otherwise, interfaces are always plain old shared_ptr
expr(tm.args.head)
} else {
val args = if (tm.args.isEmpty) "" else tm.args.map(expr).mkString("<", ", ", ">") val args = if (tm.args.isEmpty) "" else tm.args.map(expr).mkString("<", ", ", ">")
base(tm.base) + args base(tm.base) + args
} }
}
}
expr(tm) expr(tm)
} }
......
...@@ -250,7 +250,8 @@ class JNIGenerator(spec: Spec) extends Generator(spec) { ...@@ -250,7 +250,8 @@ class JNIGenerator(spec: Spec) extends Generator(spec) {
val ret = cppMarshal.fqReturnType(m.ret) val ret = cppMarshal.fqReturnType(m.ret)
val params = m.params.map(p => cppMarshal.fqParamType(p.ty) + " c_" + idCpp.local(p.ident)) val params = m.params.map(p => cppMarshal.fqParamType(p.ty) + " c_" + idCpp.local(p.ident))
writeJniTypeParams(w, typeParams) writeJniTypeParams(w, typeParams)
w.w(s"$ret $jniSelfWithParams::JavaProxy::${idCpp.method(m.ident)}${params.mkString("(", ", ", ")")}").braced { val methodNameAndSignature: String = s"${idCpp.method(m.ident)}${params.mkString("(", ", ", ")")}"
w.w(s"$ret $jniSelfWithParams::JavaProxy::$methodNameAndSignature").braced {
w.wl(s"auto jniEnv = ::djinni::jniGetThreadEnv();") w.wl(s"auto jniEnv = ::djinni::jniGetThreadEnv();")
w.wl(s"::djinni::JniLocalScope jscope(jniEnv, 10);") w.wl(s"::djinni::JniLocalScope jscope(jniEnv, 10);")
w.wl(s"const auto& data = ::djinni::JniClass<${withNs(Some(spec.jniNamespace), jniSelf)}>::get();") w.wl(s"const auto& data = ::djinni::JniClass<${withNs(Some(spec.jniNamespace), jniSelf)}>::get();")
...@@ -269,7 +270,18 @@ class JNIGenerator(spec: Spec) extends Generator(spec) { ...@@ -269,7 +270,18 @@ class JNIGenerator(spec: Spec) extends Generator(spec) {
w.w(")") w.w(")")
w.wl(";") w.wl(";")
w.wl(s"::djinni::jniExceptionCheck(jniEnv);") w.wl(s"::djinni::jniExceptionCheck(jniEnv);")
m.ret.fold()(r => w.wl(s"return ${jniMarshal.toCpp(r, "jret")};")) m.ret.fold()(ty => (spec.cppNnCheckExpression, isInterface(ty.resolved)) match {
case (Some(check), true) => {
// We have a non-optional interface, assert that we're getting a non-null value
val javaParams = m.params.map(p => javaMarshal.fqParamType(p.ty) + " " + idJava.local(p.ident))
val javaParamsString: String = javaParams.mkString("(", ",", ")")
val functionString: String = s"${javaMarshal.fqTypename(ident, i)}#$javaMethodName$javaParamsString"
w.wl(s"""DJINNI_ASSERT_MSG(jret, jniEnv, "Got unexpected null return value from function $functionString");""")
w.wl(s"return ${check}(${jniMarshal.toCpp(ty, "jret")});")
}
case _ =>
w.wl(s"return ${jniMarshal.toCpp(ty, "jret")};")
})
} }
} }
} }
...@@ -310,10 +322,27 @@ class JNIGenerator(spec: Spec) extends Generator(spec) { ...@@ -310,10 +322,27 @@ class JNIGenerator(spec: Spec) extends Generator(spec) {
nativeHook(nativeAddon + idJava.method(m.ident), m.static, m.params, m.ret, { nativeHook(nativeAddon + idJava.method(m.ident), m.static, m.params, m.ret, {
//w.wl(s"::${spec.jniNamespace}::JniLocalScope jscope(jniEnv, 10);") //w.wl(s"::${spec.jniNamespace}::JniLocalScope jscope(jniEnv, 10);")
if (!m.static) w.wl(s"const auto& ref = ::djinni::CppProxyHandle<$cppSelf>::get(nativeRef);") if (!m.static) w.wl(s"const auto& ref = ::djinni::CppProxyHandle<$cppSelf>::get(nativeRef);")
m.params.foreach(p => {
if (isInterface(p.ty.resolved) && spec.cppNnCheckExpression.nonEmpty) {
// We have a non-optional interface in nn mode, assert that we're getting a non-null value
val paramName = idJava.local(p.ident)
val javaMethodName = idJava.method(m.ident)
val javaParams = m.params.map(p => javaMarshal.fqParamType(p.ty) + " " + idJava.local(p.ident))
val javaParamsString: String = javaParams.mkString("(", ", ", ")")
val functionString: String = s"${javaMarshal.fqTypename(ident, i)}#$javaMethodName$javaParamsString"
w.wl( s"""DJINNI_ASSERT_MSG(j_$paramName, jniEnv, "Got unexpected null parameter '$paramName' to function $functionString");""")
}
})
val methodName = idCpp.method(m.ident) val methodName = idCpp.method(m.ident)
val ret = m.ret.fold("")(r => "auto r = ") val ret = m.ret.fold("")(r => "auto r = ")
val call = if (m.static) s"$cppSelf::$methodName(" else s"ref->$methodName(" val call = if (m.static) s"$cppSelf::$methodName(" else s"ref->$methodName("
writeAlignedCall(w, ret + call, m.params, ")", p => jniMarshal.toCpp(p.ty, "j_" + idJava.local(p.ident))) writeAlignedCall(w, ret + call, m.params, ")", p => {
val v = jniMarshal.toCpp(p.ty, "j_" + idJava.local(p.ident))
(spec.cppNnCheckExpression, isInterface(p.ty.resolved)) match {
case (Some(check), true) => s"$check($v)"
case _ => v
}
})
w.wl(";") w.wl(";")
m.ret.fold()(r => w.wl(s"return ::djinni::release(${jniMarshal.fromCpp(r, "r")});")) m.ret.fold()(r => w.wl(s"return ::djinni::release(${jniMarshal.fromCpp(r, "r")});"))
}) })
......
...@@ -31,7 +31,13 @@ class JNIMarshal(spec: Spec) extends Marshal(spec) { ...@@ -31,7 +31,13 @@ class JNIMarshal(spec: Spec) extends Marshal(spec) {
// Name for the autogenerated class containing field/method IDs and toJava()/fromJava() methods // Name for the autogenerated class containing field/method IDs and toJava()/fromJava() methods
def helperClass(name: String) = spec.jniClassIdentStyle(name) def helperClass(name: String) = spec.jniClassIdentStyle(name)
private def helperClass(tm: MExpr) = helperName(tm) + helperTemplates(tm) private def helperClass(tm: MExpr): String = {
if (isOptionalInterface(tm)) {
helperClass(tm.args.head)
} else {
helperName(tm) + helperTemplates(tm)
}
}
def references(m: Meta, exclude: String = ""): Seq[SymbolReference] = m match { def references(m: Meta, exclude: String = ""): Seq[SymbolReference] = m match {
case o: MOpaque => List(ImportRef(q(spec.jniBaseLibIncludePrefix + "Marshal.hpp"))) case o: MOpaque => List(ImportRef(q(spec.jniBaseLibIncludePrefix + "Marshal.hpp")))
...@@ -114,6 +120,7 @@ class JNIMarshal(spec: Spec) extends Marshal(spec) { ...@@ -114,6 +120,7 @@ class JNIMarshal(spec: Spec) extends Marshal(spec) {
tm.base match { tm.base match {
case MOptional => case MOptional =>
assert(tm.args.size == 1) assert(tm.args.size == 1)
assert(!isInterface(tm.args.head))
val argHelperClass = helperClass(tm.args.head) val argHelperClass = helperClass(tm.args.head)
s"<${spec.cppOptionalTemplate}, $argHelperClass>" s"<${spec.cppOptionalTemplate}, $argHelperClass>"
case MList | MSet => case MList | MSet =>
......
...@@ -39,18 +39,20 @@ class JavaMarshal(spec: Spec) extends Marshal(spec) { ...@@ -39,18 +39,20 @@ class JavaMarshal(spec: Spec) extends Marshal(spec) {
case _ => List() case _ => List()
} }
val interfaceNullityAnnotation = if (spec.cppNnType.nonEmpty) javaNonnullAnnotation else javaNullableAnnotation
def nullityAnnotation(ty: Option[TypeRef]): Option[String] = ty.map(nullityAnnotation).getOrElse(None) def nullityAnnotation(ty: Option[TypeRef]): Option[String] = ty.map(nullityAnnotation).getOrElse(None)
def nullityAnnotation(ty: TypeRef): Option[String] = { def nullityAnnotation(ty: TypeRef): Option[String] = {
ty.resolved.base match { ty.resolved.base match {
case MOptional => javaNullableAnnotation case MOptional => javaNullableAnnotation
case p: MPrimitive => None case p: MPrimitive => None
case m: MDef => m.defType match { case m: MDef => m.defType match {
case DInterface => javaNullableAnnotation case DInterface => interfaceNullityAnnotation
case DEnum => javaNonnullAnnotation case DEnum => javaNonnullAnnotation
case DRecord => javaNonnullAnnotation case DRecord => javaNonnullAnnotation
} }
case e: MExtern => e.defType match { case e: MExtern => e.defType match {
case DInterface => javaNullableAnnotation case DInterface => interfaceNullityAnnotation
case DRecord => if(e.java.reference) javaNonnullAnnotation else None case DRecord => if(e.java.reference) javaNonnullAnnotation else None
case DEnum => javaNonnullAnnotation case DEnum => javaNonnullAnnotation
} }
......
...@@ -31,6 +31,9 @@ object Main { ...@@ -31,6 +31,9 @@ object Main {
var cppOptionalTemplate: String = "std::optional" var cppOptionalTemplate: String = "std::optional"
var cppOptionalHeader: String = "<optional>" var cppOptionalHeader: String = "<optional>"
var cppEnumHashWorkaround : Boolean = true var cppEnumHashWorkaround : Boolean = true
var cppNnHeader: Option[String] = None
var cppNnType: Option[String] = None
var cppNnCheckExpression: Option[String] = None
var javaOutFolder: Option[File] = None var javaOutFolder: Option[File] = None
var javaPackage: Option[String] = None var javaPackage: Option[String] = None
var javaCppException: Option[String] = None var javaCppException: Option[String] = None
...@@ -119,6 +122,12 @@ object Main { ...@@ -119,6 +122,12 @@ object Main {
.text("The header to use for optional values (default: \"<optional>\")") .text("The header to use for optional values (default: \"<optional>\")")
opt[Boolean]("cpp-enum-hash-workaround").valueName("<true/false>").foreach(x => cppEnumHashWorkaround = x) opt[Boolean]("cpp-enum-hash-workaround").valueName("<true/false>").foreach(x => cppEnumHashWorkaround = x)
.text("Work around LWG-2148 by generating std::hash specializations for C++ enums (default: true)") .text("Work around LWG-2148 by generating std::hash specializations for C++ enums (default: true)")
opt[String]("cpp-nn-header").valueName("<header>").foreach(x => cppNnHeader = Some(x))
.text("The header to use for non-nullable pointers")
opt[String]("cpp-nn-type").valueName("<header>").foreach(x => cppNnType = Some(x))
.text("The type to use for non-nullable pointers (as a substitute for std::shared_ptr)")
opt[String]("cpp-nn-check-expression").valueName("<header>").foreach(x => cppNnCheckExpression = Some(x))
.text("The expression to use for building non-nullable pointers")
note("") note("")
opt[File]("jni-out").valueName("<out-folder>").foreach(x => jniOutFolder = Some(x)) opt[File]("jni-out").valueName("<out-folder>").foreach(x => jniOutFolder = Some(x))
.text("The folder for the JNI C++ output files (Generator disabled if unspecified).") .text("The folder for the JNI C++ output files (Generator disabled if unspecified).")
...@@ -270,6 +279,9 @@ object Main { ...@@ -270,6 +279,9 @@ object Main {
cppOptionalTemplate, cppOptionalTemplate,
cppOptionalHeader, cppOptionalHeader,
cppEnumHashWorkaround, cppEnumHashWorkaround,
cppNnHeader,
cppNnType,
cppNnCheckExpression,
jniOutFolder, jniOutFolder,
jniHeaderOutFolder, jniHeaderOutFolder,
jniIncludePrefix, jniIncludePrefix,
......
...@@ -18,17 +18,18 @@ class ObjcMarshal(spec: Spec) extends Marshal(spec) { ...@@ -18,17 +18,18 @@ class ObjcMarshal(spec: Spec) extends Marshal(spec) {
def nullability(tm: MExpr): Option[String] = { def nullability(tm: MExpr): Option[String] = {
val nonnull = Some("nonnull") val nonnull = Some("nonnull")
val nullable = Some("nullable") val nullable = Some("nullable")
val interfaceNullity = if (spec.cppNnType.nonEmpty) nonnull else nullable
tm.base match { tm.base match {
case MOptional => nullable case MOptional => nullable
case MPrimitive(_,_,_,_,_,_,_,_) => None case MPrimitive(_,_,_,_,_,_,_,_) => None
case d: MDef => d.defType match { case d: MDef => d.defType match {
case DEnum => None case DEnum => None
case DInterface => nullable case DInterface => interfaceNullity
case DRecord => nonnull case DRecord => nonnull
} }
case e: MExtern => e.defType match { case e: MExtern => e.defType match {
case DEnum => None case DEnum => None
case DInterface => nullable case DInterface => interfaceNullity
case DRecord => if(e.objc.pointer) nonnull else None case DRecord => if(e.objc.pointer) nonnull else None
} }
case _ => nonnull case _ => nonnull
......
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
package djinni package djinni
import djinni.ast.Record.DerivingType import java.io.StringWriter
import djinni.ast._ import djinni.ast._
import djinni.generatorTools._ import djinni.generatorTools._
import djinni.meta._ import djinni.meta._
...@@ -148,9 +149,29 @@ class ObjcppGenerator(spec: Spec) extends Generator(spec) { ...@@ -148,9 +149,29 @@ class ObjcppGenerator(spec: Spec) extends Generator(spec) {
writeObjcFuncDecl(m, w) writeObjcFuncDecl(m, w)
w.braced { w.braced {
w.w("try").bracedEnd(" DJINNI_TRANSLATE_EXCEPTIONS()") { w.w("try").bracedEnd(" DJINNI_TRANSLATE_EXCEPTIONS()") {
m.params.foreach(p => {
if (isInterface(p.ty.resolved) && spec.cppNnCheckExpression.nonEmpty) {
// We have a non-optional interface, assert that we're getting a non-null value
val paramName = idObjc.local(p.ident)
val stringWriter = new StringWriter()
writeObjcFuncDecl(m, new IndentWriter(stringWriter))
val singleLineFunctionDecl = stringWriter.toString.replaceAll("\n *", " ")
val exceptionReason = s"Got unexpected null parameter '$paramName' to function $objcSelf $singleLineFunctionDecl"
w.w(s"if ($paramName == nil)").braced {
w.wl(s"""throw std::invalid_argument("$exceptionReason");""")
}
}
})
val ret = m.ret.fold("")(_ => "auto r = ") val ret = m.ret.fold("")(_ => "auto r = ")
val call = ret + (if (!m.static) "_cppRefHandle.get()->" else cppSelf + "::") + idCpp.method(m.ident) + "(" val call = ret + (if (!m.static) "_cppRefHandle.get()->" else cppSelf + "::") + idCpp.method(m.ident) + "("
writeAlignedCall(w, call, m.params, ")", p => objcppMarshal.toCpp(p.ty, idObjc.local(p.ident.name))) writeAlignedCall(w, call, m.params, ")", p => {
val v = objcppMarshal.toCpp(p.ty, idObjc.local(p.ident.name))
(spec.cppNnCheckExpression, isInterface(p.ty.resolved)) match {
case (Some(check), true) => s"$check($v)"
case _ => v
}
})
w.wl(";") w.wl(";")
m.ret.fold()(r => w.wl(s"return ${objcppMarshal.fromCpp(r, "r")};")) m.ret.fold()(r => w.wl(s"return ${objcppMarshal.fromCpp(r, "r")};"))
} }
...@@ -177,7 +198,22 @@ class ObjcppGenerator(spec: Spec) extends Generator(spec) { ...@@ -177,7 +198,22 @@ class ObjcppGenerator(spec: Spec) extends Generator(spec) {
val call = s"[(ObjcType)Handle::get() ${idObjc.method(m.ident)}" val call = s"[(ObjcType)Handle::get() ${idObjc.method(m.ident)}"
writeAlignedObjcCall(w, ret + call, m.params, "]", p => (idObjc.field(p.ident), s"(${objcppMarshal.fromCpp(p.ty, "c_" + idCpp.local(p.ident))})")) writeAlignedObjcCall(w, ret + call, m.params, "]", p => (idObjc.field(p.ident), s"(${objcppMarshal.fromCpp(p.ty, "c_" + idCpp.local(p.ident))})"))
w.wl(";") w.wl(";")
m.ret.fold()(r => { w.wl(s"return ${objcppMarshal.toCpp(r, "r")};") }) m.ret.fold()(ty => (spec.cppNnCheckExpression, isInterface(ty.resolved)) match {
case (Some(check), true) => {
// We have a non-optional interface, assert that we're getting a non-null value
// and put it into a non-null pointer
val stringWriter = new StringWriter()
writeObjcFuncDecl(m, new IndentWriter(stringWriter))
val singleLineFunctionDecl = stringWriter.toString.replaceAll("\n *", " ")
val exceptionReason = s"Got unexpected null return value from function $objcSelf $singleLineFunctionDecl"
w.w(s"if (r == nil)").braced {
w.wl(s"""throw std::invalid_argument("$exceptionReason");""")
}
w.wl(s"return ${check}(${objcppMarshal.toCpp(ty, "r")});")
}
case _ =>
w.wl(s"return ${objcppMarshal.toCpp(ty, "r")};")
})
} }
} }
} }
......
...@@ -56,7 +56,13 @@ class ObjcppMarshal(spec: Spec) extends Marshal(spec) { ...@@ -56,7 +56,13 @@ class ObjcppMarshal(spec: Spec) extends Marshal(spec) {
} }
def helperClass(name: String) = idCpp.ty(name) def helperClass(name: String) = idCpp.ty(name)
private def helperClass(tm: MExpr): String = helperName(tm) + helperTemplates(tm) private def helperClass(tm: MExpr): String = {
if (isOptionalInterface(tm)) {
helperClass(tm.args.head)
} else {
helperName(tm) + helperTemplates(tm)
}
}
def privateHeaderName(ident: String): String = idObjc.ty(ident) + "+Private." + spec.objcHeaderExt def privateHeaderName(ident: String): String = idObjc.ty(ident) + "+Private." + spec.objcHeaderExt
......
...@@ -44,6 +44,9 @@ package object generatorTools { ...@@ -44,6 +44,9 @@ package object generatorTools {
cppOptionalTemplate: String, cppOptionalTemplate: String,
cppOptionalHeader: String, cppOptionalHeader: String,
cppEnumHashWorkaround: Boolean, cppEnumHashWorkaround: Boolean,
cppNnHeader: Option[String],
cppNnType: Option[String],
cppNnCheckExpression: Option[String],
jniOutFolder: Option[File], jniOutFolder: Option[File],
jniHeaderOutFolder: Option[File], jniHeaderOutFolder: Option[File],
jniIncludePrefix: String, jniIncludePrefix: String,
......
...@@ -98,4 +98,15 @@ val defaults: Map[String,MOpaque] = immutable.HashMap( ...@@ -98,4 +98,15 @@ val defaults: Map[String,MOpaque] = immutable.HashMap(
("list", MList), ("list", MList),
("set", MSet), ("set", MSet),
("map", MMap)) ("map", MMap))
def isInterface(ty: MExpr): Boolean = {
ty.base match {
case d: MDef => d.defType == DInterface
case _ => false
}
}
def isOptionalInterface(ty: MExpr): Boolean = {
ty.base == MOptional && ty.args.length == 1 && isInterface(ty.args.head)
}
} }
...@@ -162,15 +162,16 @@ void jniThrowCppFromJavaException(JNIEnv * env, jthrowable java_exception); ...@@ -162,15 +162,16 @@ void jniThrowCppFromJavaException(JNIEnv * env, jthrowable java_exception);
#endif #endif
void jniThrowAssertionError(JNIEnv * env, const char * file, int line, const char * check); void jniThrowAssertionError(JNIEnv * env, const char * file, int line, const char * check);
#define DJINNI_ASSERT(check, env) \ #define DJINNI_ASSERT_MSG(check, env, message) \
do { \ do { \
djinni::jniExceptionCheck(env); \ djinni::jniExceptionCheck(env); \
const bool check__res = bool(check); \ const bool check__res = bool(check); \
djinni::jniExceptionCheck(env); \ djinni::jniExceptionCheck(env); \
if (!check__res) { \ if (!check__res) { \
djinni::jniThrowAssertionError(env, __FILE__, __LINE__, #check); \ djinni::jniThrowAssertionError(env, __FILE__, __LINE__, message); \
} \ } \
} while(false) } while(false)
#define DJINNI_ASSERT(check, env) DJINNI_ASSERT_MSG(check, env, #check)
/* /*
* Helper for JniClassInitializer. * Helper for JniClassInitializer.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment