仿TensorFlow懒执行编程和注解自编排
名为【仿TensorFlow懒执行编程和注解自编排】,在我的个人代办里面。 依稀记得是很久前学习 TensorFlow 的一个想法,年尾翻到,简要 finish 一下这个 idea,属于一种临时简要的探索实践。
背景
TensorFlow 里面有个图计算,先编排好计算流程,最后点执行直接渲染节点数据。我觉得这种功能可以用 Java 实现,于是就简单写了这个应用。
下面我们来简单看下图计算模型是怎么样的渲染流程:
图计算
如下有计算模型
getFive 渲染流程
getFive 的时候,会回溯计算 five 需要的值 second 和 fourth,
发现 second 的值需要通过 first、origin2 计算,origin2 是已知的值,
然后需要计算 first = origin1 + origin2,然后 second = first + origin2
依次回溯递推,依次将 first、second、third、fourth、five 全部计算出来,并赋值。
通过 getFive 的动作,我们发现原来的模型被渲染成了下面这样。
当节点被计算赋值后(变成绿色的那种)就可以直接被使用,如上,再 getSix 时候,直接可以根据 third 的值进行计算。
PS: 那这种模型是否可用于业务代码呢?
1、复杂业务代码组装(过程变量可复用,自动化执行需要的步骤)
2、构建业务上下文,用于复杂业务处理支撑
3、对外提供简化版业务上下文,调用方按需直接获取,依赖按需加载,降低编码复杂度
Java 实现并应用
本来是想模拟 lombok 实现静态代码生成注入,这样可读性和性能会更好,不过先简单实现,后续有时间再研究和优化。
代码只有一两百行,感兴趣参见: https://github.com/petterobam/TensorFlow-Context
依赖
<dependency>
<groupId>com.github.petterobam</groupId>
<artifactId>tensor-flow-context</artifactId>
<version>1.0-SNAPSHOT</version>
</dependency>
定义计算模型
PS:计算模型定义需要在 Spring 扫包范围
@Data
@EnableTensorFlow
public class TfGetExampleContext {
private Integer origin1 = 10;
private Integer origin2 = 20;
/**
* 本地方法
*/
@TensorFlowGet(type = ServiceType.Local, methodName = "add", params = {"origin1", "origin2"})
private Integer first;
@TensorFlowGet(type = ServiceType.Local, methodName = "add", params = {"first", "origin2"})
private Integer second;
@TensorFlowGet(type = ServiceType.Local, methodName = "sub", params = {"origin1", "origin2"})
private Integer third;
/**
* 静态方法
*/
@TensorFlowGet(type = ServiceType.Static, classType = Math.class, methodName = "negateExact", paramTypes = {int.class}, params = {"third"})
private Integer forth;
@TensorFlowGet(type = ServiceType.Local, methodName = "add", params = {"second", "forth"})
private Integer five;
/**
* Spring Bean 的方法
*/
@TensorFlowGet(type = ServiceType.Spring, classType = TfGetExampleService.class, methodName = "calculate1", params = {"third"})
private Integer six;
@TensorFlowGet(type = ServiceType.Spring, springName = "tfGetExampleService", methodName = "calculate2", params = {"third"})
private Integer seven;
/**
* 自身入参
*/
@TensorFlowGet(type = ServiceType.Spring, springName = "tfGetExampleService", methodName = "fetchOrigins", params = {"this"})
private Map<String, Object> originDatas;
public Integer add(Integer a, Integer b) {
return (a == null ? 0 : a) + (b == null ? 0 : b);
}
public Integer sub(Integer c, Integer b) {
return (c == null ? 0 : c) - (b == null ? 0 : b);
}
}
@Component
public class TfGetExampleService {
public Integer calculate1(Integer val) {
if (null == val) {
return 0;
}
return val * val * val;
}
public Integer calculate2(Integer val) {
if (null == val) {
return 0;
}
return val + val + val;
}
public Map<String, Object> fetchOrigins(TfGetExampleContext context) {
Map<String, Object> res = new HashMap<>();
res.put("origin1", context.getOrigin1());
res.put("origin2", context.getOrigin2());
return res;
}
}
如何使用?
示例: http://127.0.0.1:7600/test/tf/context?origin1=87&origin2=23
@RequestMapping("/test/tf/context")
public Object testTfContext(Integer origin1, Integer origin2) throws InstantiationException, IllegalAccessException {
Map<String, Object> res = new HashMap<>();
TfGetExampleContext context = TensorFlowUtil.fetchTensorFlowContext(TfGetExampleContext.class);
context.setOrigin1(origin1);
context.setOrigin2(origin2);
Map<String, Object> res1 = new HashMap<>();
res1.put("five", context.getFive());
res1.put("first", context.getFirst());
res1.put("second", context.getSecond());
res1.put("third", context.getThird());
res1.put("seven", context.getSeven());
res1.put("fourth", context.getForth());
res1.put("six", context.getSix());
res1.put("originDatas", context.getOriginDatas());
res.put("data1", res1);
TfGetExampleContext context2 = TensorFlowUtil.fetchTensorFlowContext(TfGetExampleContext.class);
context2.setOrigin2(origin1);
context2.setOrigin1(origin2);
Map<String, Object> res2 = new HashMap<>();
res2.put("five", context2.getFive());
res2.put("first", context2.getFirst());
res2.put("second", context2.getSecond());
res2.put("third", context2.getThird());
res2.put("seven", context2.getSeven());
res2.put("fourth", context2.getForth());
res2.put("six", context2.getSix());
res2.put("originDatas", context2.getOriginDatas());
res.put("data2", res2);
log.info("TensorFlowGet for Context finish!");
return res;
}
计算结果
{
"data2": {
"six": 262144,
"third": 64,
"originDatas": {
"origin2": 23,
"origin1": 87
},
"seven": 192,
"fourth": -64,
"five": 69,
"first": 110,
"second": 133
},
"data1": {
"six": -262144,
"third": -64,
"originDatas": {
"origin2": 87,
"origin1": 23
},
"seven": -192,
"fourth": 64,
"five": 261,
"first": 110,
"second": 197
}
}
示例 - data1
示例 - data2
邀请标记你的阅读体验😉 | →